From b744ef61ea29d4ceffcd50fbfdbf2bc9c107ccb4 Mon Sep 17 00:00:00 2001 From: Larry Xue Date: Mon, 24 Aug 2020 15:48:35 -0400 Subject: [PATCH] added locals/globals, fixed nested loops --- README.md | 18 ++++++- c2logic/compiler.py | 115 ++++++++++++++++++++++++++++------------ c2logic/instructions.py | 2 + examples/func_calls.c | 10 ++-- 4 files changed, 105 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 74c2427..572cb92 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,23 @@ Run the command line tool using: `c2logic filename -O optimization_level` -where `filename` is a string and `optimization_level` is an integer. +where `filename` is a string and `optimization_level` is an optional integer +Optimization Level: + +0. completely unoptimized +1. the default + - modify variables without using a temporary +2. turns on some potentially unsafe optimizations + - augmented assignment and pre/postincrement/decrement don't modify \_\_rax + - returning from main becomes equivalent to `end` + +Locals are rewritten as __. Globals are unchanged. + +Special Variables: + +- \_\_rax: similar to x86 rax +- \_\_rbx: stores left hand side of binary ops to avoid clobbering by the right side +- \__retaddr_\*: stores return address of func call When developing your script, you can include `c2logic/builtins.h` located in the python include directory(location depends on system, mine is at `~/.local/include/python3.8/`) diff --git a/c2logic/compiler.py b/c2logic/compiler.py index 1b77320..8569b14 100644 --- a/c2logic/compiler.py +++ b/c2logic/compiler.py @@ -1,4 +1,5 @@ import site +import dataclasses from dataclasses import dataclass from pycparser import c_ast, parse_file @@ -16,8 +17,20 @@ from .instructions import ( class Function(): name: str params: list - instructions: list + + instructions: list = dataclasses.field(default_factory=list) + locals: list = dataclasses.field(init=False) + start: int = dataclasses.field(init=False) + + def __post_init__(self): + self.locals = self.params[:] + self.start = None + self.called = False + +@dataclass +class Loop(): start: int + end_jumps: list = dataclasses.field(default_factory=list) """ @dataclass @@ -27,33 +40,37 @@ class Variable(): """ class Compiler(c_ast.NodeVisitor): - """special variables: - __rax: similar to x86 rax - __rbx: stores left hand side of binary ops to avoid clobbering by the right side - __retaddr: stores return address of func call - """ def __init__(self, opt_level=0): self.opt_level = opt_level self.functions: dict = None self.curr_function: Function = None - self.loop_start: int = None - self.loop_end_jumps: list = None - #self.cond_jump_offset: int = None + self.globals: list = None + self.loops: list = None + #self.loop_start: int = None + #self.loop_end_jumps: list = None + self.loop_end: int = None def compile(self, filename: str): self.functions = {} self.curr_function = None - self.loop_start = None - self.cond_jump_offset = None + self.globals = [] + self.loops = [] + #self.loop_start = None + #self.loop_end_jumps = None + self.loop_end = None ast = parse_file( filename, use_cpp=True, cpp_args=["-I", site.getuserbase() + "/include/python3.8"] ) self.visit(ast) - #TODO actually handle functions properly + init_call = FunctionCall("main") - preamble = [Set("__retaddr_main", "2"), init_call, End()] + if self.opt_level >= 2: + preamble = [init_call] + else: + preamble = [Set("__retaddr_main", "2"), init_call, End()] offset = len(preamble) + #set function starts for function in self.functions.values(): function.start = offset @@ -89,6 +106,13 @@ class Compiler(c_ast.NodeVisitor): def curr_offset(self): return len(self.curr_function.instructions) - 1 + def get_varname(self, varname): + if varname in self.curr_function.locals: + return f"_{varname}_{self.curr_function.name}" + elif varname not in self.globals: + raise NameError(f"Unknown variable {varname}") + return varname + def can_avoid_indirection(self, var="__rax"): top = self.peek() return self.opt_level >= 1 and isinstance(top, Set) and top.dest == var @@ -114,17 +138,26 @@ class Compiler(c_ast.NodeVisitor): self.push(RelativeJump(None, JumpCondition("==", "__rax", "0"))) def start_loop(self, cond): - self.loop_start = self.curr_offset() + 1 + self.loops.append(Loop(self.curr_offset() + 1)) self.visit(cond) self.push_body_jump() - self.loop_end_jumps = [self.curr_offset()] # also used for breaks + self.loops[-1].end_jumps = [self.curr_offset()] # also used for breaks def end_loop(self): - self.push(RelativeJump(self.loop_start, JumpCondition.always)) - for offset in self.loop_end_jumps: - self.curr_function.instructions[offset].offset = self.curr_offset() + 1 - self.loop_start = None - self.loop_end_jumps = None + loop = self.loops.pop() + self.push(RelativeJump(loop.start, JumpCondition.always)) + self.loop_end = self.curr_offset() + 1 + for offset in loop.end_jumps: + self.curr_function.instructions[offset].offset = self.loop_end + + def push_ret(self): + if self.opt_level >= 2 and self.curr_function.name == "main": + top = self.peek() + if isinstance(top, Set) and top.dest == "__rax": + self.pop() + self.push(End()) + else: + self.push(Return(self.curr_function.name)) def optimize_psuedofunc_args(self, args): if self.opt_level >= 1: @@ -141,25 +174,32 @@ class Compiler(c_ast.NodeVisitor): func_decl = node.decl.type params = [param_decl.name for param_decl in func_decl.args.params] - self.curr_function = Function(func_name, params, [], None) + self.curr_function = Function(func_name, params) self.visit(node.body) #implicit return - #needed unconditionally in case loop/if body is at end of function - self.push(Set("__rax", "null")) - self.push(Return(self.curr_function.name)) + #needed if loop/if body is at end of function or hasn't returned yet + if self.loop_end == self.curr_offset() + 1 or not isinstance(self.peek(), Return): + self.push(Set("__rax", "null")) + self.push_ret() self.functions[func_name] = self.curr_function def visit_Decl(self, node): if isinstance(node.type, TypeDecl): # variable declaration #TODO fix local/global split + varname = node.name + if self.curr_function is None: # globals + self.globals.append(varname) + else: + self.curr_function.locals.append(varname) + varname = f"_{varname}_{self.curr_function.name}" if node.init is not None: self.visit(node.init) - self.set_to_rax(node.name) + self.set_to_rax(varname) elif isinstance(node.type, FuncDecl): if node.name not in builtins + func_unary_ops + func_binary_ops: #create placeholder function for forward declarations self.functions[node.name] = Function( - node.name, [param_decl.name for param_decl in node.type.args.params], [], None + node.name, [param_decl.name for param_decl in node.type.args.params] ) elif isinstance(node.type, Struct): if node.type.name != "MindustryObject": @@ -173,7 +213,7 @@ class Compiler(c_ast.NodeVisitor): def visit_Assignment(self, node): self.visit(node.rvalue) - varname = node.lvalue.name + varname = self.get_varname(node.lvalue.name) if node.op == "=": #normal assignment self.set_to_rax(varname) else: #augmented assignment(+=,-=,etc) @@ -182,12 +222,17 @@ class Compiler(c_ast.NodeVisitor): self.push(BinaryOp(varname, varname, self.pop().src, node.op[:-1])) else: self.push(BinaryOp(varname, varname, "__rax", node.op[:-1])) + if self.opt_level < 2: + self.push(Set("__rax", varname)) def visit_Constant(self, node): # literals self.push(Set("__rax", node.value)) def visit_ID(self, node): # identifier - self.push(Set("__rax", node.name)) + varname = node.name + if varname not in self.functions: + varname = self.get_varname(varname) + self.push(Set("__rax", varname)) def visit_BinaryOp(self, node): self.visit(node.left) @@ -204,12 +249,14 @@ class Compiler(c_ast.NodeVisitor): def visit_UnaryOp(self, node): if node.op == "p++" or node.op == "p--": #postincrement/decrement varname = node.expr.name - self.push(Set("__rax", varname)) + if self.opt_level < 2: + self.push(Set("__rax", varname)) self.push(BinaryOp(varname, varname, "1", node.op[1])) elif node.op == "++" or node.op == "--": varname = node.expr.name self.push(BinaryOp(varname, varname, "1", node.op[0])) - self.push(Set("__rax", varname)) + if self.opt_level < 2: + self.push(Set("__rax", varname)) elif node.op == "!": self.visit(node.expr) if self.opt_level >= 1 and isinstance(self.peek(), BinaryOp): @@ -266,14 +313,14 @@ class Compiler(c_ast.NodeVisitor): def visit_Break(self, node): self.push(RelativeJump(None, JumpCondition.always)) - self.loop_end_jumps.append(self.curr_offset()) + self.loops[-1].end_jumps.append(self.curr_offset()) def visit_Continue(self, node): - self.push(RelativeJump(self.loop_start, JumpCondition.always)) + self.push(RelativeJump(self.loops[-1].start, JumpCondition.always)) def visit_Return(self, node): self.visit(node.expr) - self.push(Return(self.curr_function.name)) + self.push_ret() def visit_FuncCall(self, node): name = node.name.name @@ -371,7 +418,7 @@ class Compiler(c_ast.NodeVisitor): raise ValueError(f"{name} is not a function") for param, arg in zip(func.params, args): self.visit(arg) - self.set_to_rax(param) + self.set_to_rax(f"{param}_{name}") self.push(Set("__retaddr_" + name, self.curr_offset() + 3)) self.push(FunctionCall(name)) diff --git a/c2logic/instructions.py b/c2logic/instructions.py index 80c5050..2cc05c9 100644 --- a/c2logic/instructions.py +++ b/c2logic/instructions.py @@ -48,6 +48,8 @@ class JumpCondition: def from_binaryop(cls, binop: BinaryOp): return cls(binop.op, binop.left, binop.right) + always: "JumpCondition" = None + def __str__(self): return f"{condition_ops[self.op]} {self.left} {self.right}" diff --git a/examples/func_calls.c b/examples/func_calls.c index 935c27a..1905f74 100644 --- a/examples/func_calls.c +++ b/examples/func_calls.c @@ -1,12 +1,12 @@ #include "c2logic/builtins.h" extern struct MindustryObject message1; -/*double factorial(int x) { +double factorial(int x) { int ret = 1; for (int i = 2; i <= x; i++) { - ret *= x; + ret *= i; } return ret; -}*/ +} double a(void) { return 5; @@ -14,8 +14,8 @@ double a(void) { void main(void) { printd(a()); print("\n"); - /*printd(factorial(4)); + printd(factorial(4)); print("\n"); - printd(factorial(5));*/ + printd(factorial(5)); printflush(message1); } \ No newline at end of file