|
|
|
@ -21,12 +21,12 @@ class Function(): |
|
|
|
|
|
|
|
instructions: list = dataclasses.field(default_factory=list) |
|
|
|
locals: list = dataclasses.field(init=False) |
|
|
|
start: int = dataclasses.field(init=False) |
|
|
|
start: int = dataclasses.field(default=None, init=False) |
|
|
|
callees: set = dataclasses.field(init=False, default_factory=set) |
|
|
|
callers: set = dataclasses.field(init=False, default_factory=set) |
|
|
|
|
|
|
|
def __post_init__(self): |
|
|
|
self.locals = self.params[:] |
|
|
|
self.start = None |
|
|
|
self.called = False |
|
|
|
|
|
|
|
@dataclass |
|
|
|
class Loop(): |
|
|
|
@ -46,6 +46,7 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.functions: dict = None |
|
|
|
self.curr_function: Function = None |
|
|
|
self.globals: list = None |
|
|
|
#TODO replace this with "blocks" attr on Function |
|
|
|
self.loops: list = None |
|
|
|
self.loop_end: int = None |
|
|
|
|
|
|
|
@ -55,15 +56,14 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.globals = [] |
|
|
|
self.loops = [] |
|
|
|
self.loop_end = None |
|
|
|
ast = parse_file( |
|
|
|
filename, |
|
|
|
use_cpp=True, |
|
|
|
cpp_args=["-I", get_include_path()] |
|
|
|
) |
|
|
|
ast = parse_file(filename, use_cpp=True, cpp_args=["-I", get_include_path()]) |
|
|
|
self.visit(ast) |
|
|
|
|
|
|
|
init_call = FunctionCall("main") |
|
|
|
#remove uncalled functions |
|
|
|
if self.opt_level >= 2: |
|
|
|
self.functions["main"].callers.add("__start") |
|
|
|
self.remove_uncalled_funcs() |
|
|
|
init_call = FunctionCall("main") |
|
|
|
if self.opt_level >= 3: |
|
|
|
preamble = [init_call] |
|
|
|
else: |
|
|
|
preamble = [Set("__retaddr_main", "2"), init_call, End()] |
|
|
|
@ -92,6 +92,30 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
) |
|
|
|
return "\n".join(out) |
|
|
|
|
|
|
|
def remove_uncalled_funcs(self): |
|
|
|
to_remove = set() |
|
|
|
for name, function in list(self.functions.items()): |
|
|
|
print(function) |
|
|
|
if name in to_remove: |
|
|
|
continue |
|
|
|
callers = set() |
|
|
|
if not self.is_called(function, callers): |
|
|
|
to_remove.add(name) |
|
|
|
to_remove |= callers |
|
|
|
for name in to_remove: |
|
|
|
del self.functions[name] |
|
|
|
|
|
|
|
def is_called(self, function, callers): |
|
|
|
if function.name in callers: #avoid infinite loops |
|
|
|
return False |
|
|
|
for func_name in function.callers: |
|
|
|
if func_name == "__start": |
|
|
|
return True |
|
|
|
callers.add(function.name) |
|
|
|
if self.is_called(self.functions[func_name], callers): |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
#utilities |
|
|
|
def push(self, instruction: Instruction): |
|
|
|
self.curr_function.instructions.append(instruction) |
|
|
|
@ -148,7 +172,7 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.curr_function.instructions[offset].offset = self.loop_end |
|
|
|
|
|
|
|
def push_ret(self): |
|
|
|
if self.opt_level >= 2 and self.curr_function.name == "main": |
|
|
|
if self.opt_level >= 3 and self.curr_function.name == "main": |
|
|
|
top = self.peek() |
|
|
|
if isinstance(top, Set) and top.dest == "__rax": |
|
|
|
self.pop() |
|
|
|
@ -156,7 +180,7 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
else: |
|
|
|
self.push(Return(self.curr_function.name)) |
|
|
|
|
|
|
|
def optimize_psuedofunc_args(self, args): |
|
|
|
def optimize_builtin_args(self, args): |
|
|
|
if self.opt_level >= 1: |
|
|
|
for i, arg in reversed(list(enumerate(args))): |
|
|
|
if self.can_avoid_indirection(arg): |
|
|
|
@ -191,15 +215,18 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.visit(arg) |
|
|
|
self.set_to_rax(f"__{name}_arg{i}") |
|
|
|
argnames.append(f"__{name}_arg{i}") |
|
|
|
return self.optimize_psuedofunc_args(argnames) |
|
|
|
return self.optimize_builtin_args(argnames) |
|
|
|
|
|
|
|
#visitors |
|
|
|
def visit_FuncDef(self, node): # function definitions |
|
|
|
func_name = node.decl.name |
|
|
|
func_decl = node.decl.type |
|
|
|
params = [param_decl.name for param_decl in func_decl.args.params] |
|
|
|
|
|
|
|
self.curr_function = Function(func_name, params) |
|
|
|
if func_name in self.functions: |
|
|
|
self.curr_function = self.functions[func_name] |
|
|
|
else: |
|
|
|
func_decl = node.decl.type |
|
|
|
print(func_decl) |
|
|
|
params = [param_decl.name for param_decl in func_decl.args.params] |
|
|
|
self.curr_function = Function(func_name, params) |
|
|
|
self.visit(node.body) |
|
|
|
#implicit return |
|
|
|
#needed if loop/if body is at end of function or hasn't returned yet |
|
|
|
@ -247,7 +274,7 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.push(BinaryOp(varname, varname, self.pop().src, node.op[:-1])) |
|
|
|
else: |
|
|
|
self.push(BinaryOp(varname, varname, "__rax", node.op[:-1])) |
|
|
|
if self.opt_level < 2: |
|
|
|
if self.opt_level < 3: |
|
|
|
self.push(Set("__rax", varname)) |
|
|
|
|
|
|
|
def visit_Constant(self, node): # literals |
|
|
|
@ -274,13 +301,13 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
def visit_UnaryOp(self, node): |
|
|
|
if node.op == "p++" or node.op == "p--": #postincrement/decrement |
|
|
|
varname = self.get_varname(node.expr.name) |
|
|
|
if self.opt_level < 2: |
|
|
|
if self.opt_level < 3: |
|
|
|
self.push(Set("__rax", varname)) |
|
|
|
self.push(BinaryOp(varname, varname, "1", node.op[1])) |
|
|
|
elif node.op == "++" or node.op == "--": |
|
|
|
varname = self.get_varname(node.expr.name) |
|
|
|
self.push(BinaryOp(varname, varname, "1", node.op[0])) |
|
|
|
if self.opt_level < 2: |
|
|
|
if self.opt_level < 3: |
|
|
|
self.push(Set("__rax", varname)) |
|
|
|
elif node.op == "!": |
|
|
|
self.visit(node.expr) |
|
|
|
@ -383,7 +410,7 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
self.visit(arg) |
|
|
|
self.set_to_rax(f"__radar_arg{i}") |
|
|
|
argnames.append(f"__radar_arg{i}") |
|
|
|
argnames = self.optimize_psuedofunc_args(argnames) |
|
|
|
argnames = self.optimize_builtin_args(argnames) |
|
|
|
self.push(Radar("__rax", *argnames)) #pylint: disable=no-value-for-parameter |
|
|
|
elif name == "sensor": |
|
|
|
self.visit(args[0]) |
|
|
|
@ -419,6 +446,9 @@ class Compiler(c_ast.NodeVisitor): |
|
|
|
func = self.functions[name] |
|
|
|
except KeyError: |
|
|
|
raise ValueError(f"{name} is not a function") |
|
|
|
if self.opt_level >= 2: |
|
|
|
self.curr_function.callees.add(name) |
|
|
|
func.callers.add(self.curr_function.name) |
|
|
|
for param, arg in zip(func.params, args): |
|
|
|
self.visit(arg) |
|
|
|
self.set_to_rax(f"_{param}_{name}") |
|
|
|
@ -443,7 +473,7 @@ def main(): |
|
|
|
import argparse |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("file") |
|
|
|
parser.add_argument("-O", "--optimization-level", type=int, choices=range(3), default=1) |
|
|
|
parser.add_argument("-O", "--optimization-level", type=int, choices=range(4), default=1) |
|
|
|
parser.add_argument("-o", "--output", type=argparse.FileType('w'), default="-") |
|
|
|
args = parser.parse_args() |
|
|
|
print(Compiler(args.optimization_level).compile(args.file), file=args.output) |
|
|
|
|