Browse Source

add uncalled function removal

rlbr-dev
Larry Xue 5 years ago
parent
commit
5d353ad942
  1. 4
      README.md
  2. 74
      c2logic/compiler.py
  3. 25
      examples/dead_code.c

4
README.md

@ -19,7 +19,9 @@ Optimization Level:
0. completely unoptimized.
1. the default
- modify variables without using a temporary
2. turns on some potentially unsafe optimizations
2. more optimizations
- remove uncalled functions
3. turns on some potentially unsafe optimizations
- augmented assignment and pre/postincrement/decrement don't modify `__rax`
- returning from main becomes equivalent to `end`

74
c2logic/compiler.py

@ -21,12 +21,12 @@ class Function():
instructions: list = dataclasses.field(default_factory=list)
locals: list = dataclasses.field(init=False)
start: int = dataclasses.field(init=False)
start: int = dataclasses.field(default=None, init=False)
callees: set = dataclasses.field(init=False, default_factory=set)
callers: set = dataclasses.field(init=False, default_factory=set)
def __post_init__(self):
self.locals = self.params[:]
self.start = None
self.called = False
@dataclass
class Loop():
@ -46,6 +46,7 @@ class Compiler(c_ast.NodeVisitor):
self.functions: dict = None
self.curr_function: Function = None
self.globals: list = None
#TODO replace this with "blocks" attr on Function
self.loops: list = None
self.loop_end: int = None
@ -55,15 +56,14 @@ class Compiler(c_ast.NodeVisitor):
self.globals = []
self.loops = []
self.loop_end = None
ast = parse_file(
filename,
use_cpp=True,
cpp_args=["-I", get_include_path()]
)
ast = parse_file(filename, use_cpp=True, cpp_args=["-I", get_include_path()])
self.visit(ast)
init_call = FunctionCall("main")
#remove uncalled functions
if self.opt_level >= 2:
self.functions["main"].callers.add("__start")
self.remove_uncalled_funcs()
init_call = FunctionCall("main")
if self.opt_level >= 3:
preamble = [init_call]
else:
preamble = [Set("__retaddr_main", "2"), init_call, End()]
@ -92,6 +92,30 @@ class Compiler(c_ast.NodeVisitor):
)
return "\n".join(out)
def remove_uncalled_funcs(self):
to_remove = set()
for name, function in list(self.functions.items()):
print(function)
if name in to_remove:
continue
callers = set()
if not self.is_called(function, callers):
to_remove.add(name)
to_remove |= callers
for name in to_remove:
del self.functions[name]
def is_called(self, function, callers):
if function.name in callers: #avoid infinite loops
return False
for func_name in function.callers:
if func_name == "__start":
return True
callers.add(function.name)
if self.is_called(self.functions[func_name], callers):
return True
return False
#utilities
def push(self, instruction: Instruction):
self.curr_function.instructions.append(instruction)
@ -148,7 +172,7 @@ class Compiler(c_ast.NodeVisitor):
self.curr_function.instructions[offset].offset = self.loop_end
def push_ret(self):
if self.opt_level >= 2 and self.curr_function.name == "main":
if self.opt_level >= 3 and self.curr_function.name == "main":
top = self.peek()
if isinstance(top, Set) and top.dest == "__rax":
self.pop()
@ -156,7 +180,7 @@ class Compiler(c_ast.NodeVisitor):
else:
self.push(Return(self.curr_function.name))
def optimize_psuedofunc_args(self, args):
def optimize_builtin_args(self, args):
if self.opt_level >= 1:
for i, arg in reversed(list(enumerate(args))):
if self.can_avoid_indirection(arg):
@ -191,15 +215,18 @@ class Compiler(c_ast.NodeVisitor):
self.visit(arg)
self.set_to_rax(f"__{name}_arg{i}")
argnames.append(f"__{name}_arg{i}")
return self.optimize_psuedofunc_args(argnames)
return self.optimize_builtin_args(argnames)
#visitors
def visit_FuncDef(self, node): # function definitions
func_name = node.decl.name
func_decl = node.decl.type
params = [param_decl.name for param_decl in func_decl.args.params]
self.curr_function = Function(func_name, params)
if func_name in self.functions:
self.curr_function = self.functions[func_name]
else:
func_decl = node.decl.type
print(func_decl)
params = [param_decl.name for param_decl in func_decl.args.params]
self.curr_function = Function(func_name, params)
self.visit(node.body)
#implicit return
#needed if loop/if body is at end of function or hasn't returned yet
@ -247,7 +274,7 @@ class Compiler(c_ast.NodeVisitor):
self.push(BinaryOp(varname, varname, self.pop().src, node.op[:-1]))
else:
self.push(BinaryOp(varname, varname, "__rax", node.op[:-1]))
if self.opt_level < 2:
if self.opt_level < 3:
self.push(Set("__rax", varname))
def visit_Constant(self, node): # literals
@ -274,13 +301,13 @@ class Compiler(c_ast.NodeVisitor):
def visit_UnaryOp(self, node):
if node.op == "p++" or node.op == "p--": #postincrement/decrement
varname = self.get_varname(node.expr.name)
if self.opt_level < 2:
if self.opt_level < 3:
self.push(Set("__rax", varname))
self.push(BinaryOp(varname, varname, "1", node.op[1]))
elif node.op == "++" or node.op == "--":
varname = self.get_varname(node.expr.name)
self.push(BinaryOp(varname, varname, "1", node.op[0]))
if self.opt_level < 2:
if self.opt_level < 3:
self.push(Set("__rax", varname))
elif node.op == "!":
self.visit(node.expr)
@ -383,7 +410,7 @@ class Compiler(c_ast.NodeVisitor):
self.visit(arg)
self.set_to_rax(f"__radar_arg{i}")
argnames.append(f"__radar_arg{i}")
argnames = self.optimize_psuedofunc_args(argnames)
argnames = self.optimize_builtin_args(argnames)
self.push(Radar("__rax", *argnames)) #pylint: disable=no-value-for-parameter
elif name == "sensor":
self.visit(args[0])
@ -419,6 +446,9 @@ class Compiler(c_ast.NodeVisitor):
func = self.functions[name]
except KeyError:
raise ValueError(f"{name} is not a function")
if self.opt_level >= 2:
self.curr_function.callees.add(name)
func.callers.add(self.curr_function.name)
for param, arg in zip(func.params, args):
self.visit(arg)
self.set_to_rax(f"_{param}_{name}")
@ -443,7 +473,7 @@ def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("file")
parser.add_argument("-O", "--optimization-level", type=int, choices=range(3), default=1)
parser.add_argument("-O", "--optimization-level", type=int, choices=range(4), default=1)
parser.add_argument("-o", "--output", type=argparse.FileType('w'), default="-")
args = parser.parse_args()
print(Compiler(args.optimization_level).compile(args.file), file=args.output)

25
examples/dead_code.c

@ -0,0 +1,25 @@
#include "c2logic/builtins.h"
// expected output on -O2 is only d should be in the compiled output
void a(void);
void b(void);
void c(void);
void d(void);
void a(void) {
print("a");
b();
}
void b(void) {
print("b");
c();
a();
d();
}
void c(void) {
print("c");
}
void d(void) {
print("d");
}
void main(void) {
d();
}
Loading…
Cancel
Save