@ -1,4 +1,5 @@
import site
import dataclasses
from dataclasses import dataclass
from pycparser import c_ast , parse_file
@ -16,8 +17,20 @@ from .instructions import (
class Function:
    """A function known to the compiler: name, parameters, emitted code.

    NOTE(review): the ``dataclasses.field`` defaults and ``__post_init__``
    below only take effect under a ``@dataclass`` decorator, which is
    presumably applied directly above this class (outside this view) --
    confirm.
    """

    name: str
    params: list
    # default_factory gives every instance its own instruction list.
    instructions: list = dataclasses.field(default_factory=list)
    # Not constructor arguments; both are assigned in __post_init__.
    locals: list = dataclasses.field(init=False)
    start: int = dataclasses.field(init=False)

    def __post_init__(self):
        """Initialise the attributes that are not constructor arguments."""
        # Parameters count as locals; copy so that later appends to
        # ``locals`` do not mutate the caller's ``params`` list.
        self.locals = self.params[:]
        # No start offset is known at construction time.
        self.start = None
        self.called = False
@dataclass
class Loop:
    """Bookkeeping for one loop while it is being compiled."""

    # Instruction offset that a jump back to the loop targets.
    start: int
    # Offsets of already-emitted jump instructions whose target still has to
    # be set to the loop's end; each Loop gets its own list.
    end_jumps: list = dataclasses.field(default_factory=list)
"""
@dataclass
@ -27,33 +40,37 @@ class Variable():
"""
class Compiler ( c_ast . NodeVisitor ) :
""" special variables:
__rax : similar to x86 rax
__rbx : stores left hand side of binary ops to avoid clobbering by the right side
__retaddr : stores return address of func call
"""
def __init__(self, opt_level=0):
    """Create a compiler; per-run state is filled in by compile().

    Args:
        opt_level: optimisation level. Code in this class compares it
            against the thresholds 1 and 2.
    """
    self.opt_level = opt_level
    self.functions: dict = None
    self.curr_function: Function = None
    self.globals: list = None
    # Stack of Loop records, innermost last. Replaces the former single
    # loop_start / loop_end_jumps pair.
    self.loops: list = None
    # Offset just past the most recently closed loop.
    self.loop_end: int = None
def compile ( self , filename : str ) :
self . functions = { }
self . curr_function = None
self . loop_start = None
self . cond_jump_offset = None
self . globals = [ ]
self . loops = [ ]
#self.loop_start = None
#self.loop_end_jumps = None
self . loop_end = None
ast = parse_file (
filename , use_cpp = True , cpp_args = [ " -I " , site . getuserbase ( ) + " /include/python3.8 " ]
)
self . visit ( ast )
#TODO actually handle functions properly
init_call = FunctionCall ( " main " )
if self . opt_level > = 2 :
preamble = [ init_call ]
else :
preamble = [ Set ( " __retaddr_main " , " 2 " ) , init_call , End ( ) ]
offset = len ( preamble )
#set function starts
for function in self . functions . values ( ) :
function . start = offset
@ -89,6 +106,13 @@ class Compiler(c_ast.NodeVisitor):
def curr_offset(self):
    """Return the index of the last instruction of the current function.

    Evaluates to -1 while the current function has no instructions.
    """
    return len(self.curr_function.instructions) - 1
def get_varname(self, varname):
    """Map a source-level variable name to the name used in emitted code.

    A local of the function being compiled becomes
    ``_<varname>_<function name>``; a global keeps its own name.

    Args:
        varname: the identifier as written in the source.

    Returns:
        The name to use in emitted instructions.

    Raises:
        NameError: if ``varname`` is neither a local of the current
            function nor a known global.
    """
    function = self.curr_function
    # No function is being compiled at file scope, so only globals can
    # match there; without this guard the lookup dereferenced None.
    if function is not None and varname in function.locals:
        return f"_{varname}_{function.name}"
    if varname not in self.globals:
        raise NameError(f"Unknown variable {varname}")
    return varname
def can_avoid_indirection(self, var="__rax"):
    """Report whether the last emitted instruction is a ``Set`` into ``var``.

    Always false below optimisation level 1.

    Args:
        var: destination name to look for; defaults to ``__rax``.
    """
    top = self.peek()
    return self.opt_level >= 1 and isinstance(top, Set) and top.dest == var
@ -114,17 +138,26 @@ class Compiler(c_ast.NodeVisitor):
self . push ( RelativeJump ( None , JumpCondition ( " == " , " __rax " , " 0 " ) ) )
def start_loop(self, cond):
    """Open a loop: record its start, then emit its condition and exit jump.

    Args:
        cond: AST node of the loop condition.
    """
    # The condition is emitted next, so a jump to this offset re-evaluates
    # it.
    self.loops.append(Loop(self.curr_offset() + 1))
    self.visit(cond)
    self.push_body_jump()
    # Record the instruction push_body_jump() just emitted so end_loop()
    # can set its target; break statements add their jumps to this list.
    self.loops[-1].end_jumps = [self.curr_offset()]
def end_loop(self):
    """Close the innermost loop: jump back to its start and patch its exits."""
    loop = self.loops.pop()
    self.push(RelativeJump(loop.start, JumpCondition.always))
    # Every recorded exit jump lands just past the back-jump pushed above.
    self.loop_end = self.curr_offset() + 1
    for offset in loop.end_jumps:
        self.curr_function.instructions[offset].offset = self.loop_end
def push_ret(self):
    """Emit the instruction that leaves the current function.

    At optimisation level 2 and above, ``main`` finishes with ``End``, and a
    ``Set`` into ``__rax`` sitting on top of the instruction list is removed
    first. In every other case a ``Return`` for the current function is
    emitted.
    """
    if self.opt_level >= 2 and self.curr_function.name == "main":
        top = self.peek()
        if isinstance(top, Set) and top.dest == "__rax":
            self.pop()
        self.push(End())
    else:
        self.push(Return(self.curr_function.name))
def optimize_psuedofunc_args ( self , args ) :
if self . opt_level > = 1 :
@ -141,25 +174,32 @@ class Compiler(c_ast.NodeVisitor):
func_decl = node . decl . type
params = [ param_decl . name for param_decl in func_decl . args . params ]
self . curr_function = Function ( func_name , params , [ ] , None )
self . curr_function = Function ( func_name , params )
self . visit ( node . body )
#implicit return
#needed unconditionally in case loop/if body is at end of function
#needed if loop/if body is at end of function or hasn't returned yet
if self . loop_end == self . curr_offset ( ) + 1 or not isinstance ( self . peek ( ) , Return ) :
self . push ( Set ( " __rax " , " null " ) )
self . push ( Return ( self . curr_function . name ) )
self . push_ret ( )
self . functions [ func_name ] = self . curr_function
def visit_Decl ( self , node ) :
if isinstance ( node . type , TypeDecl ) : # variable declaration
#TODO fix local/global split
varname = node . name
if self . curr_function is None : # globals
self . globals . append ( varname )
else :
self . curr_function . locals . append ( varname )
varname = f " _{varname}_{self.curr_function.name} "
if node . init is not None :
self . visit ( node . init )
self . set_to_rax ( node . name )
self . set_to_rax ( var name)
elif isinstance ( node . type , FuncDecl ) :
if node . name not in builtins + func_unary_ops + func_binary_ops :
#create placeholder function for forward declarations
self . functions [ node . name ] = Function (
node . name , [ param_decl . name for param_decl in node . type . args . params ] , [ ] , None
node . name , [ param_decl . name for param_decl in node . type . args . params ]
)
elif isinstance ( node . type , Struct ) :
if node . type . name != " MindustryObject " :
@ -173,7 +213,7 @@ class Compiler(c_ast.NodeVisitor):
def visit_Assignment ( self , node ) :
self . visit ( node . rvalue )
varname = node . lvalue . name
varname = self . get_varname ( node . lvalue . name )
if node . op == " = " : #normal assignment
self . set_to_rax ( varname )
else : #augmented assignment(+=,-=,etc)
@ -182,12 +222,17 @@ class Compiler(c_ast.NodeVisitor):
self . push ( BinaryOp ( varname , varname , self . pop ( ) . src , node . op [ : - 1 ] ) )
else :
self . push ( BinaryOp ( varname , varname , " __rax " , node . op [ : - 1 ] ) )
if self . opt_level < 2 :
self . push ( Set ( " __rax " , varname ) )
def visit_Constant(self, node):
    """Handle a literal: emit a ``Set`` of ``__rax`` to ``node.value``."""
    self.push(Set("__rax", node.value))
def visit_ID(self, node):
    """Handle an identifier: emit a ``Set`` of ``__rax`` to its name.

    Names of known functions are used unchanged; any other name is resolved
    through get_varname().
    """
    varname = node.name
    if varname not in self.functions:
        varname = self.get_varname(varname)
    self.push(Set("__rax", varname))
def visit_BinaryOp ( self , node ) :
self . visit ( node . left )
@ -204,11 +249,13 @@ class Compiler(c_ast.NodeVisitor):
def visit_UnaryOp ( self , node ) :
if node . op == " p++ " or node . op == " p-- " : #postincrement/decrement
varname = node . expr . name
if self . opt_level < 2 :
self . push ( Set ( " __rax " , varname ) )
self . push ( BinaryOp ( varname , varname , " 1 " , node . op [ 1 ] ) )
elif node . op == " ++ " or node . op == " -- " :
varname = node . expr . name
self . push ( BinaryOp ( varname , varname , " 1 " , node . op [ 0 ] ) )
if self . opt_level < 2 :
self . push ( Set ( " __rax " , varname ) )
elif node . op == " ! " :
self . visit ( node . expr )
@ -266,14 +313,14 @@ class Compiler(c_ast.NodeVisitor):
def visit_Break(self, node):
    """Emit an unconditional jump out of the innermost loop.

    The jump is pushed without a target and registered on the innermost
    Loop record so that end_loop() can fill the target in.
    """
    self.push(RelativeJump(None, JumpCondition.always))
    self.loops[-1].end_jumps.append(self.curr_offset())
def visit_Continue(self, node):
    """Emit an unconditional jump to the start of the innermost loop."""
    self.push(RelativeJump(self.loops[-1].start, JumpCondition.always))
def visit_Return(self, node):
    """Compile a ``return`` statement.

    The returned expression, if any, is visited first. A bare ``return;``
    has no expression node, so ``null`` is loaded into ``__rax`` instead,
    the same value the implicit end-of-function return uses. push_ret()
    then emits the instruction that leaves the function.
    """
    if node.expr is None:
        # Visiting None would fail inside the node visitor.
        self.push(Set("__rax", "null"))
    else:
        self.visit(node.expr)
    self.push_ret()
def visit_FuncCall ( self , node ) :
name = node . name . name
@ -371,7 +418,7 @@ class Compiler(c_ast.NodeVisitor):
raise ValueError ( f " {name} is not a function " )
for param , arg in zip ( func . params , args ) :
self . visit ( arg )
self . set_to_rax ( param )
self . set_to_rax ( f " { param}_{name} " )
self . push ( Set ( " __retaddr_ " + name , self . curr_offset ( ) + 3 ) )
self . push ( FunctionCall ( name ) )