You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

452 lines
14 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. import os
  2. import sysconfig
  3. import dataclasses
  4. from dataclasses import dataclass
  5. from pycparser import c_ast, parse_file
  6. from pycparser.c_ast import (
  7. Compound, Constant, DeclList, Enum, FileAST, FuncDecl, Struct, TypeDecl
  8. )
  9. from .consts import builtins, draw_funcs, func_binary_ops, func_unary_ops
  10. from .instructions import (
  11. BinaryOp, Draw, DrawFlush, Enable, End, FunctionCall, GetLink, Instruction, JumpCondition,
  12. Print, PrintFlush, Radar, RawAsm, Read, RelativeJump, Return, Sensor, Set, Shoot, UnaryOp, Write
  13. )
  14. @dataclass
  15. class Function():
  16. name: str
  17. params: list
  18. instructions: list = dataclasses.field(default_factory=list)
  19. locals: list = dataclasses.field(init=False)
  20. start: int = dataclasses.field(init=False)
  21. def __post_init__(self):
  22. self.locals = self.params[:]
  23. self.start = None
  24. self.called = False
  25. @dataclass
  26. class Loop():
  27. start: int
  28. end_jumps: list = dataclasses.field(default_factory=list)
  29. """
  30. @dataclass
  31. class Variable():
  32. type: str
  33. name: str
  34. """
  35. class Compiler(c_ast.NodeVisitor):
  36. def __init__(self, opt_level=0):
  37. self.opt_level = opt_level
  38. self.functions: dict = None
  39. self.curr_function: Function = None
  40. self.globals: list = None
  41. self.loops: list = None
  42. self.loop_end: int = None
  43. def compile(self, filename: str):
  44. self.functions = {}
  45. self.curr_function = None
  46. self.globals = []
  47. self.loops = []
  48. self.loop_end = None
  49. ast = parse_file(
  50. filename,
  51. use_cpp=True,
  52. cpp_args=["-I", get_include_path()]
  53. )
  54. self.visit(ast)
  55. init_call = FunctionCall("main")
  56. if self.opt_level >= 2:
  57. preamble = [init_call]
  58. else:
  59. preamble = [Set("__retaddr_main", "2"), init_call, End()]
  60. offset = len(preamble)
  61. #set function starts
  62. for function in self.functions.values():
  63. function.start = offset
  64. offset += len(function.instructions)
  65. #rewrite relative jumps and func calls
  66. init_call.func_start = self.functions["main"].start
  67. for function in self.functions.values():
  68. instructions = function.instructions
  69. for instruction in instructions:
  70. if isinstance(instruction, RelativeJump):
  71. instruction.func_start = function.start
  72. elif isinstance(instruction, FunctionCall):
  73. instruction.func_start = self.functions[instruction.func_name].start
  74. elif isinstance(instruction, Set) and instruction.dest.startswith("__retaddr"):
  75. instruction.src += function.start
  76. out = ["\n".join(map(str, preamble))]
  77. out.extend(
  78. "\n".join(map(str, function.instructions)) for function in self.functions.values()
  79. )
  80. return "\n".join(out)
  81. #utilities
  82. def push(self, instruction: Instruction):
  83. self.curr_function.instructions.append(instruction)
  84. def pop(self):
  85. return self.curr_function.instructions.pop()
  86. def peek(self):
  87. return self.curr_function.instructions[-1]
  88. def curr_offset(self):
  89. return len(self.curr_function.instructions) - 1
  90. def get_varname(self, varname):
  91. if varname in self.curr_function.locals:
  92. return f"_{varname}_{self.curr_function.name}"
  93. elif varname not in self.globals:
  94. raise NameError(f"Unknown variable {varname}")
  95. return varname
  96. def can_avoid_indirection(self, var="__rax"):
  97. top = self.peek()
  98. return self.opt_level >= 1 and isinstance(top, Set) and top.dest == var
  99. def set_to_rax(self, varname: str):
  100. top = self.peek()
  101. if self.opt_level >= 1 and hasattr(top, "dest") and top.dest == "__rax":
  102. #avoid indirection through __rax
  103. self.curr_function.instructions[-1].dest = varname
  104. else:
  105. self.push(Set(varname, "__rax"))
  106. def push_body_jump(self):
  107. """ jump over loop/if body when cond is false """
  108. if self.opt_level >= 1 and isinstance(self.peek(), BinaryOp):
  109. try:
  110. self.push(RelativeJump(None, JumpCondition.from_binaryop(self.pop().inverse())))
  111. except KeyError:
  112. self.push(RelativeJump(None, JumpCondition("==", "__rax", "0")))
  113. else:
  114. self.push(RelativeJump(None, JumpCondition("==", "__rax", "0")))
  115. def start_loop(self, cond):
  116. self.loops.append(Loop(self.curr_offset() + 1))
  117. self.visit(cond)
  118. self.push_body_jump()
  119. self.loops[-1].end_jumps = [self.curr_offset()] # also used for breaks
  120. def end_loop(self):
  121. loop = self.loops.pop()
  122. self.push(RelativeJump(loop.start, JumpCondition.always))
  123. self.loop_end = self.curr_offset() + 1
  124. for offset in loop.end_jumps:
  125. self.curr_function.instructions[offset].offset = self.loop_end
  126. def push_ret(self):
  127. if self.opt_level >= 2 and self.curr_function.name == "main":
  128. top = self.peek()
  129. if isinstance(top, Set) and top.dest == "__rax":
  130. self.pop()
  131. self.push(End())
  132. else:
  133. self.push(Return(self.curr_function.name))
  134. def optimize_psuedofunc_args(self, args):
  135. if self.opt_level >= 1:
  136. for i, arg in reversed(list(enumerate(args))):
  137. if self.can_avoid_indirection(arg):
  138. args[i] = self.pop().src
  139. else:
  140. break
  141. return args
  142. def get_unary_builtin_arg(self, args):
  143. self.visit(args[0])
  144. if self.can_avoid_indirection():
  145. return self.pop().src
  146. else:
  147. return "__rax"
  148. def get_binary_builtin_args(self, args, name):
  149. left_name = f"__{name}_arg0"
  150. self.visit(args[0])
  151. self.set_to_rax(left_name)
  152. self.visit(args[1])
  153. left = left_name
  154. right = "__rax"
  155. if self.can_avoid_indirection():
  156. right = self.pop().src
  157. if self.can_avoid_indirection(left_name):
  158. left = self.pop().src
  159. return left, right
  160. def get_multiple_builtin_args(self, args, name):
  161. argnames = []
  162. for i, arg in enumerate(args):
  163. self.visit(arg)
  164. self.set_to_rax(f"__{name}_arg{i}")
  165. argnames.append(f"__{name}_arg{i}")
  166. return self.optimize_psuedofunc_args(argnames)
  167. #visitors
  168. def visit_FuncDef(self, node): # function definitions
  169. func_name = node.decl.name
  170. func_decl = node.decl.type
  171. params = [param_decl.name for param_decl in func_decl.args.params]
  172. self.curr_function = Function(func_name, params)
  173. self.visit(node.body)
  174. #implicit return
  175. #needed if loop/if body is at end of function or hasn't returned yet
  176. if self.loop_end == self.curr_offset() + 1 or not isinstance(self.peek(), Return):
  177. self.push(Set("__rax", "null"))
  178. self.push_ret()
  179. self.functions[func_name] = self.curr_function
  180. def visit_Decl(self, node):
  181. if isinstance(node.type, TypeDecl): # variable declaration
  182. #TODO fix local/global split
  183. varname = node.name
  184. if self.curr_function is None: # globals
  185. self.globals.append(varname)
  186. else:
  187. self.curr_function.locals.append(varname)
  188. varname = f"_{varname}_{self.curr_function.name}"
  189. if node.init is not None:
  190. self.visit(node.init)
  191. self.set_to_rax(varname)
  192. elif isinstance(node.type, FuncDecl):
  193. if node.name not in builtins + func_unary_ops + func_binary_ops:
  194. #create placeholder function for forward declarations
  195. self.functions[node.name] = Function(
  196. node.name, [param_decl.name for param_decl in node.type.args.params]
  197. )
  198. elif isinstance(node.type, Struct):
  199. if node.type.name != "MindustryObject":
  200. #TODO structs
  201. raise NotImplementedError(node)
  202. elif isinstance(node.type, Enum):
  203. #TODO enums
  204. raise NotImplementedError(node)
  205. else:
  206. raise NotImplementedError(node)
  207. def visit_Assignment(self, node):
  208. self.visit(node.rvalue)
  209. varname = self.get_varname(node.lvalue.name)
  210. if node.op == "=": #normal assignment
  211. self.set_to_rax(varname)
  212. else: #augmented assignment(+=,-=,etc)
  213. if self.can_avoid_indirection():
  214. #avoid indirection through __rax
  215. self.push(BinaryOp(varname, varname, self.pop().src, node.op[:-1]))
  216. else:
  217. self.push(BinaryOp(varname, varname, "__rax", node.op[:-1]))
  218. if self.opt_level < 2:
  219. self.push(Set("__rax", varname))
  220. def visit_Constant(self, node): # literals
  221. self.push(Set("__rax", node.value))
  222. def visit_ID(self, node): # identifier
  223. varname = node.name
  224. if varname not in self.functions:
  225. varname = self.get_varname(varname)
  226. self.push(Set("__rax", varname))
  227. def visit_BinaryOp(self, node):
  228. self.visit(node.left)
  229. self.set_to_rax("__rbx")
  230. self.visit(node.right)
  231. left = "__rbx"
  232. right = "__rax"
  233. if self.can_avoid_indirection():
  234. right = self.pop().src
  235. if self.can_avoid_indirection("__rbx"):
  236. left = self.pop().src
  237. self.push(BinaryOp("__rax", left, right, node.op))
  238. def visit_UnaryOp(self, node):
  239. if node.op == "p++" or node.op == "p--": #postincrement/decrement
  240. varname = self.get_varname(node.expr.name)
  241. if self.opt_level < 2:
  242. self.push(Set("__rax", varname))
  243. self.push(BinaryOp(varname, varname, "1", node.op[1]))
  244. elif node.op == "++" or node.op == "--":
  245. varname = self.get_varname(node.expr.name)
  246. self.push(BinaryOp(varname, varname, "1", node.op[0]))
  247. if self.opt_level < 2:
  248. self.push(Set("__rax", varname))
  249. elif node.op == "!":
  250. self.visit(node.expr)
  251. if self.opt_level >= 1 and isinstance(self.peek(), BinaryOp):
  252. try:
  253. self.push(self.pop().inverse())
  254. except KeyError:
  255. self.push(BinaryOp("__rax", "__rax", "0", "=="))
  256. else:
  257. self.push(BinaryOp("__rax", "__rax", "0", "=="))
  258. else:
  259. self.visit(node.expr)
  260. self.push(UnaryOp("__rax", "__rax", node.op))
  261. def visit_For(self, node):
  262. self.visit(node.init)
  263. self.start_loop(node.cond)
  264. self.visit(node.stmt) # loop body
  265. self.visit(node.next)
  266. self.end_loop()
  267. def visit_While(self, node):
  268. self.start_loop(node.cond)
  269. self.visit(node.stmt)
  270. self.end_loop()
  271. def visit_DoWhile(self, node):
  272. #jump over the condition on the first iterattion
  273. self.push(RelativeJump(None, JumpCondition.always))
  274. init_jump_offset = self.curr_offset()
  275. self.start_loop(node.cond)
  276. self.curr_function.instructions[init_jump_offset].offset = len(
  277. self.curr_function.instructions
  278. )
  279. self.visit(node.stmt)
  280. self.end_loop()
  281. def visit_If(self, node):
  282. self.visit(node.cond)
  283. self.push_body_jump()
  284. cond_jump_offset = self.curr_offset()
  285. self.visit(node.iftrue)
  286. #jump over else body from end of if body
  287. if node.iffalse is not None:
  288. self.push(RelativeJump(None, JumpCondition.always))
  289. cond_jump_offset2 = self.curr_offset()
  290. self.curr_function.instructions[cond_jump_offset].offset = len(
  291. self.curr_function.instructions
  292. )
  293. if node.iffalse is not None:
  294. self.visit(node.iffalse)
  295. self.curr_function.instructions[cond_jump_offset2].offset = len(
  296. self.curr_function.instructions
  297. )
  298. def visit_Break(self, node): #pylint: disable=unused-argument
  299. self.push(RelativeJump(None, JumpCondition.always))
  300. self.loops[-1].end_jumps.append(self.curr_offset())
  301. def visit_Continue(self, node): #pylint: disable=unused-argument
  302. self.push(RelativeJump(self.loops[-1].start, JumpCondition.always))
  303. def visit_Return(self, node):
  304. self.visit(node.expr)
  305. self.push_ret()
  306. def visit_FuncCall(self, node):
  307. name = node.name.name
  308. if node.args is not None:
  309. args = node.args.exprs
  310. else:
  311. args = []
  312. #TODO avoid duplication in builtin calls
  313. builtins_dict = {
  314. "print": Print,
  315. "printd": Print,
  316. "printflush": PrintFlush,
  317. "enable": Enable,
  318. "shoot": Shoot,
  319. "get_link": GetLink,
  320. "read": lambda cell, index: Read("__rax", cell, index),
  321. "write": Write,
  322. "drawflush": DrawFlush
  323. }
  324. if name in builtins_dict:
  325. self.push(builtins_dict[name](*self.get_multiple_builtin_args(args, name)))
  326. elif name == "asm":
  327. arg = args[0]
  328. if not isinstance(arg, Constant) or arg.type != "string":
  329. raise TypeError("Non-string argument to asm", node)
  330. self.push(RawAsm(arg.value[1:-1]))
  331. elif name == "radar":
  332. argnames = []
  333. for i, arg in enumerate(args):
  334. if 1 <= i <= 4:
  335. if not isinstance(arg, Constant) or arg.type != "string":
  336. raise TypeError("Non-string argument to radar", node)
  337. self.push(Set("__rax", arg.value[1:-1]))
  338. else:
  339. self.visit(arg)
  340. self.set_to_rax(f"__radar_arg{i}")
  341. argnames.append(f"__radar_arg{i}")
  342. argnames = self.optimize_psuedofunc_args(argnames)
  343. self.push(Radar("__rax", *argnames)) #pylint: disable=no-value-for-parameter
  344. elif name == "sensor":
  345. self.visit(args[0])
  346. self.set_to_rax("__sensor_arg0")
  347. arg = args[1]
  348. if not isinstance(arg, Constant) or arg.type != "string":
  349. raise TypeError("Non-string argument to sensor", node)
  350. self.push(Set("__rax", arg.value[1:-1]))
  351. left = "__sensor_arg0"
  352. right = "__rax"
  353. if self.can_avoid_indirection():
  354. right = self.pop().src
  355. if self.can_avoid_indirection("__sensor_arg0"):
  356. left = self.pop().src
  357. self.push(Sensor("__rax", left, right))
  358. elif name == "end":
  359. self.push(End())
  360. elif name in draw_funcs:
  361. argnames = self.get_multiple_builtin_args(args, name)
  362. cmd = draw_funcs[name]
  363. self.push(Draw(cmd, *argnames))
  364. elif name in func_binary_ops:
  365. left, right = self.get_binary_builtin_args(args, "binary")
  366. self.push(BinaryOp("__rax", left, right, name))
  367. elif name in func_unary_ops:
  368. self.visit(args[0])
  369. if self.can_avoid_indirection():
  370. self.push(UnaryOp("__rax", self.pop().src, name))
  371. else:
  372. self.push(UnaryOp("__rax", "__rax", name))
  373. else:
  374. try:
  375. func = self.functions[name]
  376. except KeyError:
  377. raise ValueError(f"{name} is not a function")
  378. for param, arg in zip(func.params, args):
  379. self.visit(arg)
  380. self.set_to_rax(f"_{param}_{name}")
  381. self.push(Set("__retaddr_" + name, self.curr_offset() + 3))
  382. self.push(FunctionCall(name))
  383. def generic_visit(self, node):
  384. if isinstance(node, (FileAST, Compound, DeclList)):
  385. super().generic_visit(node)
  386. else:
  387. raise NotImplementedError(node)
  388. def get_include_path():
  389. if os.name == "posix":
  390. return sysconfig.get_path("include", "posix_user")
  391. elif os.name == "nt":
  392. return sysconfig.get_path("include", "nt")
  393. else:
  394. raise ValueError(f"Unknown os {os.name}")
  395. def main():
  396. import argparse
  397. parser = argparse.ArgumentParser()
  398. parser.add_argument("file")
  399. parser.add_argument("-O", "--optimization-level", type=int, choices=range(3), default=1)
  400. parser.add_argument("-o", "--output", type=argparse.FileType('w'), default="-")
  401. args = parser.parse_args()
  402. print(Compiler(args.optimization_level).compile(args.file), file=args.output)
  403. if __name__ == "__main__":
  404. main()