Skip to content

Commit

Permalink
Extract enum info from the ast and store it in HeaderInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
natezb committed Feb 22, 2019
1 parent 9ae0ce8 commit efae58a
Showing 1 changed file with 54 additions and 3 deletions.
57 changes: 54 additions & 3 deletions nicelib/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,55 @@ def visit_PtrDecl(self, node):
return self.visit(node.type)


class EnumGrabber(c_ast.NodeVisitor):
def __init__(self):
self.enum_map = {}
self.outer_typedef = None

def run(self, root):
self.visit(root)
return list(self.enum_map.values())

def find_existing_enum(self, enum_node):
"""Return seen enum node that's either identical or shares the same tag name"""
for existing_enum_node, einfo in self.enum_map.items():
ee = existing_enum_node
ne = enum_node
if ee is ne or (ee.name and ne.name and ee.name == ne.name):
return existing_enum_node
return None

def visit_Typedef(self, node):
self.outer_typedef = node
self.visit(node.type)
self.outer_typedef = None

def visit_Enum(self, node):
existing_enum = self.find_existing_enum(node)
if existing_enum:
if self.outer_typedef:
existing_enum_info = self.enum_map[existing_enum]
existing_enum_info.typedef_names.append(self.outer_typedef.name)
else:
typedef_names = ([self.outer_typedef.name] if self.outer_typedef
else [])
value_names = (None if node.values is None
else [e.name for e in node.values.enumerators])
enum_info = EnumInfo(node.name, typedef_names, value_names)
self.enum_map[node] = enum_info


class EnumInfo(object):
def __init__(self, tag_name, typedef_names, value_names):
self.tag_name = tag_name
self.typedef_names = typedef_names
self.value_names = value_names

def __repr__(self):
return '<EnumInfo(tag_name={!r}, typedef_names={!r}, value_names={!r})>'.format(
self.tag_name, self.typedef_names, self.value_names)


class Generator(object):
def __init__(self, tokens, macros, macro_expand, token_hooks=(), string_hooks=(), ast_hooks=(),
debug_file=None):
Expand Down Expand Up @@ -1390,6 +1439,8 @@ def get_ext_chunks(chunk_tokens):
argname_grabber.visit(self.tree)
argnames = argname_grabber.argnames

einfo_list = EnumGrabber().run(self.tree)

# Generate cleaned C source
generator = cpp_generator.CPPGenerator()
header_src = generator.visit(self.tree)
Expand Down Expand Up @@ -1417,7 +1468,7 @@ def get_ext_chunks(chunk_tokens):
else:
macro_src.write("defs['{}'] = {}\n".format(macro.name, py_src))

return HeaderInfo(header_src, macro_src.getvalue(), self.tree, argnames)
return HeaderInfo(header_src, macro_src.getvalue(), self.tree, argnames, einfo_list)

def gen_py_src(self, macro):
if isinstance(macro, FuncMacro):
Expand Down Expand Up @@ -1796,12 +1847,12 @@ def process_source(source, predef_path=None, update_cb=None, ignored_headers=(),


class HeaderInfo(object):
def __init__(self, header_src, macro_src, ast, argnames):
def __init__(self, header_src, macro_src, ast, argnames, enums):
self.header_src = header_src
self.macro_src = macro_src
self.ast = ast
self.argnames = argnames
self.enums = None
self.enums = enums


def generate_bindings(header_info, outfile, prefix=(), add_ret_ignore=False, niceobj_prefix={},
Expand Down

0 comments on commit efae58a

Please sign in to comment.