diff --git a/README.md b/README.md index 1845cd2..fc4dfe1 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,9 @@ Optimization Level: 0. completely unoptimized. 1. the default - modify variables without using a temporary -2. turns on some potentially unsafe optimizations +2. more optimizations + - remove uncalled functions +3. turns on some potentially unsafe optimizations - augmented assignment and pre/postincrement/decrement don't modify `__rax` - returning from main becomes equivalent to `end` diff --git a/c2logic/compiler.py b/c2logic/compiler.py index a7bdb87..2ddb903 100644 --- a/c2logic/compiler.py +++ b/c2logic/compiler.py @@ -21,12 +21,12 @@ class Function(): instructions: list = dataclasses.field(default_factory=list) locals: list = dataclasses.field(init=False) - start: int = dataclasses.field(init=False) + start: int = dataclasses.field(default=None, init=False) + callees: set = dataclasses.field(init=False, default_factory=set) + callers: set = dataclasses.field(init=False, default_factory=set) def __post_init__(self): self.locals = self.params[:] - self.start = None - self.called = False @dataclass class Loop(): @@ -46,6 +46,7 @@ class Compiler(c_ast.NodeVisitor): self.functions: dict = None self.curr_function: Function = None self.globals: list = None + #TODO replace this with "blocks" attr on Function self.loops: list = None self.loop_end: int = None @@ -55,15 +56,14 @@ class Compiler(c_ast.NodeVisitor): self.globals = [] self.loops = [] self.loop_end = None - ast = parse_file( - filename, - use_cpp=True, - cpp_args=["-I", get_include_path()] - ) + ast = parse_file(filename, use_cpp=True, cpp_args=["-I", get_include_path()]) self.visit(ast) - - init_call = FunctionCall("main") + #remove uncalled functions if self.opt_level >= 2: + self.functions["main"].callers.add("__start") + self.remove_uncalled_funcs() + init_call = FunctionCall("main") + if self.opt_level >= 3: preamble = [init_call] else: preamble = [Set("__retaddr_main", "2"), init_call, End()] @@ -92,6 +92,30 @@ class Compiler(c_ast.NodeVisitor): ) return "\n".join(out) + def remove_uncalled_funcs(self): + to_remove = set() + for name, function in list(self.functions.items()): + print(function) + if name in to_remove: + continue + callers = set() + if not self.is_called(function, callers): + to_remove.add(name) + to_remove |= callers + for name in to_remove: + del self.functions[name] + + def is_called(self, function, callers): + if function.name in callers: #avoid infinite loops + return False + for func_name in function.callers: + if func_name == "__start": + return True + callers.add(function.name) + if self.is_called(self.functions[func_name], callers): + return True + return False + #utilities def push(self, instruction: Instruction): self.curr_function.instructions.append(instruction) @@ -148,7 +172,7 @@ class Compiler(c_ast.NodeVisitor): self.curr_function.instructions[offset].offset = self.loop_end def push_ret(self): - if self.opt_level >= 2 and self.curr_function.name == "main": + if self.opt_level >= 3 and self.curr_function.name == "main": top = self.peek() if isinstance(top, Set) and top.dest == "__rax": self.pop() @@ -156,7 +180,7 @@ class Compiler(c_ast.NodeVisitor): else: self.push(Return(self.curr_function.name)) - def optimize_psuedofunc_args(self, args): + def optimize_builtin_args(self, args): if self.opt_level >= 1: for i, arg in reversed(list(enumerate(args))): if self.can_avoid_indirection(arg): @@ -191,15 +215,18 @@ class Compiler(c_ast.NodeVisitor): self.visit(arg) self.set_to_rax(f"__{name}_arg{i}") argnames.append(f"__{name}_arg{i}") - return self.optimize_psuedofunc_args(argnames) + return self.optimize_builtin_args(argnames) #visitors def visit_FuncDef(self, node): # function definitions func_name = node.decl.name - func_decl = node.decl.type - params = [param_decl.name for param_decl in func_decl.args.params] - - self.curr_function = Function(func_name, params) + if func_name in self.functions: + self.curr_function = self.functions[func_name] + else: + func_decl = node.decl.type + print(func_decl) + params = [param_decl.name for param_decl in func_decl.args.params] + self.curr_function = Function(func_name, params) self.visit(node.body) #implicit return #needed if loop/if body is at end of function or hasn't returned yet @@ -247,7 +274,7 @@ class Compiler(c_ast.NodeVisitor): self.push(BinaryOp(varname, varname, self.pop().src, node.op[:-1])) else: self.push(BinaryOp(varname, varname, "__rax", node.op[:-1])) - if self.opt_level < 2: + if self.opt_level < 3: self.push(Set("__rax", varname)) def visit_Constant(self, node): # literals @@ -274,13 +301,13 @@ class Compiler(c_ast.NodeVisitor): def visit_UnaryOp(self, node): if node.op == "p++" or node.op == "p--": #postincrement/decrement varname = self.get_varname(node.expr.name) - if self.opt_level < 2: + if self.opt_level < 3: self.push(Set("__rax", varname)) self.push(BinaryOp(varname, varname, "1", node.op[1])) elif node.op == "++" or node.op == "--": varname = self.get_varname(node.expr.name) self.push(BinaryOp(varname, varname, "1", node.op[0])) - if self.opt_level < 2: + if self.opt_level < 3: self.push(Set("__rax", varname)) elif node.op == "!": self.visit(node.expr) @@ -383,7 +410,7 @@ class Compiler(c_ast.NodeVisitor): self.visit(arg) self.set_to_rax(f"__radar_arg{i}") argnames.append(f"__radar_arg{i}") - argnames = self.optimize_psuedofunc_args(argnames) + argnames = self.optimize_builtin_args(argnames) self.push(Radar("__rax", *argnames)) #pylint: disable=no-value-for-parameter elif name == "sensor": self.visit(args[0]) @@ -419,6 +446,9 @@ class Compiler(c_ast.NodeVisitor): func = self.functions[name] except KeyError: raise ValueError(f"{name} is not a function") + if self.opt_level >= 2: + self.curr_function.callees.add(name) + func.callers.add(self.curr_function.name) for param, arg in zip(func.params, args): self.visit(arg) self.set_to_rax(f"_{param}_{name}") @@ -443,7 +473,7 @@ def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument("file") - parser.add_argument("-O", "--optimization-level", type=int, choices=range(3), default=1) + parser.add_argument("-O", "--optimization-level", type=int, choices=range(4), default=1) parser.add_argument("-o", "--output", type=argparse.FileType('w'), default="-") args = parser.parse_args() print(Compiler(args.optimization_level).compile(args.file), file=args.output) diff --git a/examples/dead_code.c b/examples/dead_code.c new file mode 100644 index 0000000..d755364 --- /dev/null +++ b/examples/dead_code.c @@ -0,0 +1,25 @@ +#include "c2logic/builtins.h" +// expected output on -O2 is only d should be in the compiled output +void a(void); +void b(void); +void c(void); +void d(void); +void a(void) { + print("a"); + b(); +} +void b(void) { + print("b"); + c(); + a(); + d(); +} +void c(void) { + print("c"); +} +void d(void) { + print("d"); +} +void main(void) { + d(); +} \ No newline at end of file