#------------------------------------------------------------------------------
# pycparser: cython_generator.py
#
# Cython code generator from pycparser AST nodes.
#
# Copyright (C) 2008-2012, Eli Bendersky and Brett Hartshorn
# License: BSD
#------------------------------------------------------------------------------
from . import c_ast

# When true, typedef'd structs are emitted as "cdef class" definitions,
# functions whose first parameter is a pointer to such a type are collected
# as methods of that class, and struct member access is always written '.'.
CLASSIFY = True

# Placeholder emitted as the last line of every generated class.
# visit_FileAST replaces MARKER + <class name> with the collected methods.
MARKER = '#######$$$$$'


class CythonGenerator(object):
    """ Generates Cython source text from a pycparser AST.

        Uses the same visitor pattern as c_ast.NodeVisitor, but modified to
        return a value from each visit method, using string accumulation in
        generic_visit.
    """

    # Visit methods that accept the extra ``cdef`` keyword argument.
    _CDEF_AWARE_VISITORS = ('visit_Decl', 'visit_ParamList')

    def __init__(self):
        self.output = ''

        # Statements start with indentation of self.indent_level spaces, using
        # the _make_indent method
        #
        self.indent_level = 0

        # Maps the name of every generated class (typedef name) to the list
        # of method texts collected for it by visit_FuncDef.
        self.classes = {}

    def visit_FileAST(self, n):
        """ Generates the text for a whole translation unit.

            Every external declaration is visited, the methods collected in
            self.classes are spliced in at the class markers, and finally
            trailing ';' characters and blank lines are removed.
        """
        s = ''
        for ext in n.ext:
            if isinstance(ext, c_ast.FuncDef):
                s += self.visit(ext)
            else:
                s += self.visit(ext) + ';\n'

        # Substitute the longest names first, so that the marker of a class
        # 'Foo' can never match inside the marker of a class 'FooBar'.
        for class_name in sorted(self.classes, key=len, reverse=True):
            method_lines = [
                '  ' + line
                for func in self.classes[class_name]
                for line in func.splitlines()]
            s = s.replace(MARKER + class_name, '\n'.join(method_lines))

        result = []
        for line in s.splitlines():
            if line.endswith(';'):
                line = line[:-1]
            if line.strip():
                result.append(line)
        return '\n'.join(result)

    def _make_indent(self):
        """ Returns the whitespace prefix for the current indent_level.
        """
        return ' ' * self.indent_level

    def visit(self, node, cdef=True):
        """ Dispatches to visit_<ClassName>, falling back to generic_visit.

            cdef is forwarded only to the visit methods listed in
            _CDEF_AWARE_VISITORS.
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        if method in self._CDEF_AWARE_VISITORS:
            return visitor(node, cdef=cdef)
        return visitor(node)

    @staticmethod
    def _child_nodes(node):
        """ Yields the child nodes of 'node'.

            Accepts children() results that are either plain nodes or
            (name, node) pairs.
        """
        for child in node.children():
            yield child[1] if isinstance(child, tuple) else child

    def generic_visit(self, node):
        """ Concatenates the text of all children; '' for a None node.
        """
        if node is None:
            return ''
        return ''.join(self.visit(c) for c in self._child_nodes(node))

    def visit_Constant(self, n):
        return n.value

    def visit_ID(self, n):
        return n.name

    def visit_ArrayRef(self, n):
        arrref = self._parenthesize_unless_simple(n.name)
        return arrref + '[' + self.visit(n.subscript) + ']'

    def visit_StructRef(self, n):
        sref = self._parenthesize_unless_simple(n.name)
        # With CLASSIFY both '.' and '->' are written as '.'. The node itself
        # is left untouched; only the generated text differs.
        # TODO check if struct is in self.classes or 'self'
        ref_type = '.' if CLASSIFY else n.type
        return sref + ref_type + self.visit(n.field)

    def visit_FuncCall(self, n):
        fref = self._parenthesize_unless_simple(n.name)
        if fref == 'printf':
            fref = 'print'
        return fref + '(' + self.visit(n.args) + ')'

    def visit_UnaryOp(self, n):
        operand = self._parenthesize_unless_simple(n.expr)
        if n.op == 'p++':
            return '%s += 1' % operand
        elif n.op == 'p--':
            return '%s -= 1' % operand
        elif n.op == 'sizeof':
            # Always parenthesize the argument of sizeof since it can be
            # a name.
            return 'sizeof(%s)' % self.visit(n.expr)
        else:
            return '%s%s' % (n.op, operand)

    def visit_BinaryOp(self, n):
        lval_str = self._parenthesize_unless_simple(n.left)
        rval_str = self._parenthesize_unless_simple(n.right)
        return '%s %s %s' % (lval_str, n.op, rval_str)

    def visit_Assignment(self, n):
        rval_str = self._parenthesize_if(
            n.rvalue,
            lambda d: isinstance(d, c_ast.Assignment))
        return '%s %s %s' % (self.visit(n.lvalue), n.op, rval_str)

    def visit_IdentifierType(self, n):
        return ' '.join(n.names)

    def visit_Decl(self, n, no_type=False, cdef=True):
        """ Generates a declaration, prefixed with 'cdef ' unless the
            declared type is directly an Enum or a Struct.

            no_type is used when a Decl is part of a DeclList, where the type
            is explicitly only for the first declaration in a list.
            cdef is accepted so that visit() can forward it; it is not
            consulted here.
        """
        s = n.name if no_type else self._generate_decl(n)
        if not isinstance(n.type, (c_ast.Enum, c_ast.Struct)):
            s = 'cdef ' + s
        if n.bitsize:
            s += ' : ' + self.visit(n.bitsize)
        if n.init:
            if isinstance(n.init, c_ast.InitList):
                s += ' = {' + self.visit(n.init) + '}'
            elif isinstance(n.init, c_ast.ExprList):
                s += ' = (' + self.visit(n.init) + ')'
            else:
                s += ' = ' + self.visit(n.init)
        return s

    def visit_DeclList(self, n):
        s = self.visit(n.decls[0])
        if len(n.decls) > 1:
            s += ', ' + ', '.join(self.visit_Decl(decl, no_type=True)
                                  for decl in n.decls[1:])
        return s

    def visit_Typedef(self, n):
        """ Generates a typedef.

            A typedef of a struct that has members becomes a class (when
            CLASSIFY is set), a 'ctypedef struct' (when the struct is
            anonymous or shares the typedef name) or a struct followed by a
            'ctypedef' alias. Every other typedef becomes a one-line
            'ctypedef'.
        """
        a = self._generate_type(n.type)
        type_lines = a.splitlines()
        if len(type_lines) > 1:
            # Multi-line (struct/union body): the declared name sits alone
            # on the last line, so drop it. A single-line type is kept
            # whole; dropping its only line would erase the typedef.
            a = '\n'.join(type_lines[:-1])
        a = a.replace(';', '')

        s = ''
        if n.storage:
            s = ' '.join(n.storage) + ' '
            s = s.replace('typedef ', 'ctypedef ')

        simple = True
        sname = None
        tname = None
        classify = False
        # A struct without members only names a type declared elsewhere, so
        # it is handled like any other plain alias.
        if (isinstance(n.type, c_ast.TypeDecl) and
                isinstance(n.type.type, c_ast.Struct) and
                n.type.type.decls):
            sname = n.type.type.name
            tname = n.type.declname
            classify = bool(CLASSIFY)
            simple = (not sname) or tname == sname

        if classify:
            return self.struct_to_class(a, sname, tname)
        if simple:
            if a.startswith('cdef '):
                a = a[5:]
            return s + a
        return a + '\nctypedef %s %s\n' % (sname, tname)

    def struct_to_class(self, txt, struct_name, type_name):
        """ Turns generated struct text into a 'cdef class' definition.

            txt is the struct text (header line followed by one member per
            line). The class is named type_name, gets one public attribute
            per member and an __init__ that assigns all of them. The final
            line is MARKER + type_name, which visit_FileAST later replaces
            with the methods of the class. Registers type_name in
            self.classes.
        """
        txt_lines = txt.splitlines()
        head = txt_lines[0].replace(' struct ', ' class ')
        if not struct_name:
            head = head.replace(':', '%s:' % type_name)
        elif type_name != struct_name:
            head = head.replace('%s:' % struct_name, '%s:' % type_name)

        lines = [head]
        args = ['self']  # no need to type self
        init = []
        for line in txt_lines[1:]:
            member = line.strip()
            if not member:
                continue
            lines.append('  cdef public %s' % member)
            args.append(member)
            member_name = member.split()[-1]
            init.append('    self.%s = %s' % (member_name, member_name))
        if not init:
            # A def needs a body even when there is nothing to assign.
            init.append('    pass')

        ## __init__ requires "def" not "cdef"
        lines.append('  def __init__(%s):' % (','.join(args)))
        lines.extend(init)
        lines.append(MARKER + type_name)

        assert type_name not in self.classes
        self.classes[type_name] = []  # list of methods
        return '\n'.join(lines)

    def visit_Cast(self, n):
        s = '(' + self._generate_type(n.to_type) + ')'
        return s + ' ' + self._parenthesize_unless_simple(n.expr)

    def visit_ExprList(self, n):
        visited_subexprs = []
        for expr in n.exprs:
            if isinstance(expr, c_ast.ExprList):
                visited_subexprs.append('{' + self.visit(expr) + '}')
            else:
                visited_subexprs.append(self.visit(expr))
        return ', '.join(visited_subexprs)

    def visit_InitList(self, n):
        visited_subexprs = []
        for expr in n.exprs:
            if isinstance(expr, c_ast.InitList):
                visited_subexprs.append('(' + self.visit(expr) + ')')
            else:
                visited_subexprs.append(self.visit(expr))
        return ', '.join(visited_subexprs)

    def visit_Enum(self, n):
        """ Generates one 'cdef int NAME = VALUE' line per enumerator,
            preceded by a '# enum [name]' comment line.

            An enumerator without an explicit value is one more than the
            previous enumerator (0 for the first), as in C.
        """
        s = '# enum'
        if n.name:
            s += ' ' + n.name

        items = []
        if n.values:
            base = None   # text of the most recent explicit value, if any
            offset = 0    # distance from that value (or from zero)
            for enumerator in n.values.enumerators:
                if enumerator.value:
                    value = self.visit(enumerator.value)
                    if self._is_simple_node(enumerator.value):
                        base = value
                    else:
                        base = '(' + value + ')'
                    offset = 0
                elif base is None:
                    value = str(offset)
                else:
                    value = '%s + %d' % (base, offset)
                items.append('cdef int ' + enumerator.name + ' = ' + value)
                offset += 1
        return s + '\n' + '\n'.join(items)

    def visit_FuncDef(self, n):
        """ Generates a 'cpdef' function.

            With CLASSIFY, a function whose first parameter is a pointer to
            a generated class is rewritten as a method, stored in
            self.classes and '' is returned instead of the function text.
        """
        classify = False
        funcdecl = n.decl.type
        if CLASSIFY and funcdecl.args and funcdecl.args.params:
            arg = funcdecl.args.params[0]
            if (isinstance(arg, c_ast.Decl) and
                    isinstance(arg.type, c_ast.PtrDecl)):
                ptr = arg.type
                if (isinstance(ptr.type, c_ast.TypeDecl) and
                        isinstance(ptr.type.type, c_ast.IdentifierType)):
                    name = ptr.type.type.names[0]
                    if name in self.classes:
                        self.function_to_method(n, name, arg.name)
                        classify = (name, arg.name)

        decl = self.visit(n.decl) + ':'
        ## to keep things simple just remove "cdef " from parameters
        decl = 'cpdef ' + decl.replace('cdef ', '')
        self.indent_level = 0
        body = self.visit(n.body)
        if n.param_decls:
            knrdecls = ';\n'.join(self.visit(p) for p in n.param_decls)
            result = decl + '\n' + knrdecls + ';\n' + body + '\n'
        else:
            result = decl + '\n' + body + '\n'

        if classify:
            cname, vname = classify
            result = result.replace('*' + vname, '')
            self.classes[cname].append(result)
            return ''
        return result

    def function_to_method(self, n, class_name, var_name):
        """ Rewrites the subtree of 'n' in place for use as a method.

            Every IdentifierType consisting only of class_name and every ID
            named var_name is renamed to 'self'.
        """
        for child in self._child_nodes(n):
            if isinstance(child, c_ast.IdentifierType):
                if len(child.names) == 1 and child.names[0] == class_name:
                    child.names[0] = 'self'
            elif isinstance(child, c_ast.ID) and child.name == var_name:
                child.name = 'self'
            else:
                self.function_to_method(child, class_name, var_name)

    def visit_Compound(self, n):
        # Blank lines stand in for the braces of the C block.
        s = self._make_indent() + '\n'
        self.indent_level += 2
        if n.block_items:
            s += ''.join(self._generate_stmt(stmt) for stmt in n.block_items)
        self.indent_level -= 2
        s += self._make_indent() + '\n'
        return s

    def visit_EmptyStatement(self, n):
        return ';'

    def visit_ParamList(self, n, cdef=False):
        return ', '.join(self.visit(param, cdef=cdef) for param in n.params)

    def visit_Return(self, n):
        s = 'return'
        if n.expr:
            s += ' ' + self.visit(n.expr)
        return s + ';'

    def visit_Break(self, n):
        return 'break;'

    def visit_Continue(self, n):
        return 'continue;'

    def visit_TernaryOp(self, n):
        s = self.visit(n.cond) + ' ? '
        s += self.visit(n.iftrue) + ' : '
        s += self.visit(n.iffalse)
        return s

    def visit_If(self, n):
        s = 'if ('
        if n.cond:
            s += self.visit(n.cond)
        s += '):\n'
        s += self._generate_stmt(n.iftrue, add_indent=True)
        if n.iffalse:
            s += self._make_indent() + 'else:\n'
            s += self._generate_stmt(n.iffalse, add_indent=True)
        return s

    def is_simple_for_loop(self, n):
        """ Returns True when the loop can be written as
            'for VAR from START <= VAR < BOUND'.

            That requires 'VAR = <constant>' as initialiser, 'VAR < ...' or
            'VAR <= ...' as condition and an increment of the same VAR.
        """
        init, cond, step = n.init, n.cond, n.next
        if not (isinstance(init, c_ast.Assignment) and init.op == '=' and
                isinstance(init.lvalue, c_ast.ID) and
                isinstance(init.rvalue, c_ast.Constant)):
            return False
        var = init.lvalue.name
        # The generated loop starts with 'START <= VAR', so only an upper
        # bound on the same variable fits the pattern.
        if not (isinstance(cond, c_ast.BinaryOp) and
                cond.op in ('<', '<=') and
                isinstance(cond.left, c_ast.ID) and
                cond.left.name == var):
            return False
        return (isinstance(step, c_ast.UnaryOp) and
                step.op in ('p++', '++') and
                isinstance(step.expr, c_ast.ID) and
                step.expr.name == var)

    def visit_For(self, n):
        if self.is_simple_for_loop(n):
            var = n.init.lvalue.name
            start = n.init.rvalue.value
            # The bound may be any expression, not just an identifier.
            bound = self._parenthesize_unless_simple(n.cond.right)
            s = 'for %s from %s <= %s' % (var, start, var)
            s += ' %s %s:\n' % (n.cond.op, bound)
        else:  ## not cython yet TODO
            s = 'for ('
            if n.init:
                s += self.visit(n.init)
            s += ';'
            if n.cond:
                s += ' ' + self.visit(n.cond)
            s += ';'
            if n.next:
                s += ' ' + self.visit(n.next)
            s += '):\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_While(self, n):
        s = 'while ('
        if n.cond:
            s += self.visit(n.cond)
        s += '):\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_DoWhile(self, n):
        s = 'do\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        s += self._make_indent() + 'while ('
        if n.cond:
            s += self.visit(n.cond)
        s += ');'
        return s

    def visit_Switch(self, n):
        s = 'switch (' + self.visit(n.cond) + ')\n'
        s += self._generate_stmt(n.stmt, add_indent=True)
        return s

    def visit_Case(self, n):
        s = 'case ' + self.visit(n.expr) + ':\n'
        for stmt in n.stmts:
            s += self._generate_stmt(stmt, add_indent=True)
        return s

    def visit_Default(self, n):
        s = 'default:\n'
        for stmt in n.stmts:
            s += self._generate_stmt(stmt, add_indent=True)
        return s

    def visit_Label(self, n):
        return n.name + ':\n' + self._generate_stmt(n.stmt)

    def visit_Goto(self, n):
        return 'goto ' + n.name + ';'

    def visit_EllipsisParam(self, n):
        return '...'

    def visit_Struct(self, n):
        return self._generate_struct_union(n, 'struct')

    def visit_Typename(self, n):
        return self._generate_type(n.type)

    def visit_Union(self, n):
        return self._generate_struct_union(n, 'union')

    def visit_NamedInitializer(self, n):
        s = ''
        for name in n.name:
            if isinstance(name, c_ast.ID):
                s += '.' + name.name
            elif isinstance(name, c_ast.Constant):
                s += '[' + name.value + ']'
        s += ' = ' + self.visit(n.expr)
        return s

    def _generate_struct_union(self, n, name):
        """ Generates code for structs and unions. name should be either
            'struct' or 'union'.

            A node with members yields 'cdef <name> <tag>:' followed by one
            member per line. A node without members only refers to a type
            declared elsewhere and yields the bare tag, because Cython names
            such a type without the struct/union keyword.
        """
        if not n.decls:
            return n.name or ''

        head = 'cdef ' + name + ' ' + (n.name or '') + ':'
        s = '\n'
        s += self._make_indent()
        self.indent_level += 2
        for decl in n.decls:
            s += self._generate_stmt(decl)
        self.indent_level -= 2
        s += self._make_indent()
        # Members inside a cdef struct are declared without 'cdef'.
        return head + s.replace('cdef ', '')

    def _generate_struct_union_as_class(self, n, name):
        """ Generates a 'cdef class' with one statement per member for a
            named struct or union. name should be either 'struct' or
            'union'; it is not used in the generated text.
        """
        assert n.name  ## TODO unnamed structs
        s = 'cdef class %s:' % n.name
        if n.decls:
            s += '\n'
            s += self._make_indent()
            self.indent_level += 2
            for decl in n.decls:
                s += self._generate_stmt(decl)
            self.indent_level -= 2
            s += self._make_indent()
        return s

    def _generate_stmt(self, n, add_indent=False):
        """ Generation from a statement node. This method exists as a wrapper
            for individual visit_* methods to handle different treatment of
            some statements in this context.

            add_indent indents this one statement by two extra spaces.
        """
        typ = type(n)
        if add_indent:
            self.indent_level += 2
        indent = self._make_indent()
        if add_indent:
            self.indent_level -= 2

        if typ in (
                c_ast.Decl, c_ast.Assignment, c_ast.Cast, c_ast.UnaryOp,
                c_ast.BinaryOp, c_ast.TernaryOp, c_ast.FuncCall,
                c_ast.ArrayRef, c_ast.StructRef, c_ast.Constant, c_ast.ID,
                c_ast.Typedef):
            # These can also appear in an expression context so no semicolon
            # is added to them automatically
            #
            return indent + self.visit(n) + ';\n'
        elif typ in (c_ast.Compound,):
            # No extra indentation required before the opening brace of a
            # compound - because it consists of multiple lines it has to
            # compute its own indentation.
            #
            return self.visit(n)
        else:
            return indent + self.visit(n) + '\n'

    def _generate_decl(self, n):
        """ Generation from a Decl node.
        """
        s = ''
        if n.funcspec:
            s = ' '.join(n.funcspec) + ' '
        if n.storage:
            s += ' '.join(n.storage) + ' '
        s += self._generate_type(n.type)
        return s

    def _generate_type(self, n, modifiers=None):
        """ Recursive generation from a type node. n is the type node.
            modifiers collects the PtrDecl, ArrayDecl and FuncDecl modifiers
            encountered on the way down to a TypeDecl, to allow proper
            generation from it. None means no modifiers so far.
        """
        if modifiers is None:
            modifiers = []
        typ = type(n)

        if typ == c_ast.TypeDecl:
            s = ''
            if n.quals:
                s += ' '.join(n.quals) + ' '
            s += self.visit(n.type)

            nstr = n.declname if n.declname else ''
            # Resolve modifiers.
            # Wrap in parens to distinguish pointer to array and pointer to
            # function syntax.
            #
            for i, modifier in enumerate(modifiers):
                if isinstance(modifier, c_ast.ArrayDecl):
                    if (i != 0 and
                            isinstance(modifiers[i - 1], c_ast.PtrDecl)):
                        nstr = '(' + nstr + ')'
                    nstr += '[' + self.visit(modifier.dim) + ']'
                elif isinstance(modifier, c_ast.FuncDecl):
                    if (i != 0 and
                            isinstance(modifiers[i - 1], c_ast.PtrDecl)):
                        nstr = '(' + nstr + ')'
                    nstr += '(' + self.visit(modifier.args) + ')'
                elif isinstance(modifier, c_ast.PtrDecl):
                    if modifier.quals:
                        nstr = '* %s %s' % (' '.join(modifier.quals), nstr)
                    else:
                        nstr = '*' + nstr
            if nstr:
                s += ' ' + nstr
            return s
        elif typ == c_ast.Decl:
            return self._generate_decl(n.type)
        elif typ == c_ast.Typename:
            return self._generate_type(n.type)
        elif typ == c_ast.IdentifierType:
            return ' '.join(n.names) + ' '
        elif typ in (c_ast.ArrayDecl, c_ast.PtrDecl, c_ast.FuncDecl):
            return self._generate_type(n.type, modifiers + [n])
        else:
            return self.visit(n)

    def _parenthesize_if(self, n, condition):
        """ Visits 'n' and returns its string representation, parenthesized
            if the condition function applied to the node returns True.
        """
        s = self.visit(n)
        if condition(n):
            return '(' + s + ')'
        else:
            return s

    def _parenthesize_unless_simple(self, n):
        """ Common use case for _parenthesize_if
        """
        return self._parenthesize_if(n, lambda d: not self._is_simple_node(d))

    def _is_simple_node(self, n):
        """ Returns True for nodes that are "simple" - i.e. nodes that always
            have higher precedence than operators.
        """
        return isinstance(n, (c_ast.Constant, c_ast.ID, c_ast.ArrayRef,
                              c_ast.StructRef, c_ast.FuncCall))