Skip to content

Commit

Permalink
Simplify extract_metadata.py by moving logic in tools/webassembly.py.…
Browse files Browse the repository at this point in the history
… NFC (emscripten-core#17255)

This creates a caching layer and some extra helper functions on the
module object which avoid the need to track start such as the number
of imports elements of a given type.
  • Loading branch information
sbc100 authored Jun 18, 2022
1 parent a8a5f77 commit 1706b8e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 19 deletions.
33 changes: 14 additions & 19 deletions tools/extract_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def find_segment_with_address(module, address, size=0):
if seg.size == size:
return (seg, 0)

raise AssertionError('unable to find segment for address: %s' % address)


def data_to_string(data):
data = data.decode('utf8')
Expand All @@ -82,14 +84,14 @@ def data_to_string(data):
return data


def get_asm_strings(module, globls, export_map, imported_globals):
def get_asm_strings(module, export_map):
if '__start_em_asm' not in export_map or '__stop_em_asm' not in export_map:
return {}

start = export_map['__start_em_asm']
end = export_map['__stop_em_asm']
start_global = globls[start.index - imported_globals]
end_global = globls[end.index - imported_globals]
start_global = module.get_global(start.index)
end_global = module.get_global(end.index)
start_addr = get_global_value(start_global)
end_addr = get_global_value(end_global)

Expand All @@ -110,29 +112,28 @@ def get_asm_strings(module, globls, export_map, imported_globals):
return asm_strings


def get_main_reads_params(module, export_map, imported_funcs):
def get_main_reads_params(module, export_map):
if settings.STANDALONE_WASM:
return 1

main = export_map.get('main') or export_map.get('__main_argc_argv')
if not main or main.kind != webassembly.ExternType.FUNC:
return 0

functions = module.get_functions()
main_func = functions[main.index - imported_funcs]
main_func = module.get_function(main.index)
if is_wrapper_function(module, main_func):
return 0
else:
return 1


def get_names_globals(globls, exports, imported_globals):
def get_named_globals(module, exports):
named_globals = {}
for export in exports:
if export.kind == webassembly.ExternType.GLOBAL:
if export.name in ('__start_em_asm', '__stop_em_asm') or export.name.startswith('__em_js__'):
continue
g = globls[export.index - imported_globals]
g = module.get_global(export.index)
named_globals[export.name] = str(get_global_value(g))
return named_globals

Expand Down Expand Up @@ -167,26 +168,20 @@ def extract_metadata(filename):
export_names = []
declares = []
invoke_funcs = []
imported_funcs = 0
imported_globals = 0
global_imports = []
em_js_funcs = {}
exports = module.get_exports()
imports = module.get_imports()
globls = module.get_globals()

for i in imports:
if i.kind == webassembly.ExternType.FUNC:
imported_funcs += 1
elif i.kind == webassembly.ExternType.GLOBAL:
imported_globals += 1
if i.kind == webassembly.ExternType.GLOBAL:
global_imports.append(i.field)

export_map = {e.name: e for e in exports}
for e in exports:
if e.kind == webassembly.ExternType.GLOBAL and e.name.startswith('__em_js__'):
name = e.name[len('__em_js__'):]
globl = globls[e.index - imported_globals]
globl = module.get_global(e.index)
string_address = get_global_value(globl)
em_js_funcs[name] = get_string_at(module, string_address)

Expand All @@ -208,14 +203,14 @@ def extract_metadata(filename):
# If main does not read its parameters, it will just be a stub that
# calls __original_main (which has no parameters).
metadata = {}
metadata['asmConsts'] = get_asm_strings(module, globls, export_map, imported_globals)
metadata['asmConsts'] = get_asm_strings(module, export_map)
metadata['declares'] = declares
metadata['emJsFuncs'] = em_js_funcs
metadata['exports'] = export_names
metadata['features'] = features
metadata['globalImports'] = global_imports
metadata['invokeFuncs'] = invoke_funcs
metadata['mainReadsParams'] = get_main_reads_params(module, export_map, imported_funcs)
metadata['namedGlobals'] = get_names_globals(globls, exports, imported_globals)
metadata['mainReadsParams'] = get_main_reads_params(module, export_map)
metadata['namedGlobals'] = get_named_globals(module, exports)
# print("Metadata parsed: " + pprint.pformat(metadata))
return metadata
66 changes: 66 additions & 0 deletions tools/webassembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,32 @@ def read_sleb(iobuf):
return leb128.i.decode_reader(iobuf)[0]


# TODO(sbc): Use the builtin functools.cache once we update to python 3.9
def cache(f):
results = {}

def helper(*args, **kwargs):
assert not kwargs
key = args
if key not in results:
results[key] = f(*args, **kwargs)
return results[key]

return helper


def once(f):
done = False

def helper(*args, **kwargs):
nonlocal done
if not done:
done = True
f(*args, **kwargs)

return helper


class Type(IntEnum):
I32 = 0x7f # -0x1
I64 = 0x7e # -0x2
Expand Down Expand Up @@ -141,6 +167,7 @@ def __init__(self, filename):
version = self.buf.read(4)
if magic != MAGIC or version != VERSION:
raise InvalidWasmError(f'{filename} is not a valid wasm file')
self._done_calc_indexes = False

def __del__(self):
if self.buf:
Expand Down Expand Up @@ -250,6 +277,7 @@ def parse_features_section(self):
feature_count -= 1
return features

@cache
def parse_dylink_section(self):
dylink_section = next(self.sections())
assert dylink_section.type == SecType.CUSTOM
Expand Down Expand Up @@ -314,6 +342,7 @@ def parse_dylink_section(self):

return Dylink(mem_size, mem_align, table_size, table_align, needed, export_info, import_info)

@cache
def get_exports(self):
export_section = self.get_section(SecType.EXPORT)
if not export_section:
Expand All @@ -330,6 +359,7 @@ def get_exports(self):

return exports

@cache
def get_imports(self):
import_section = self.get_section(SecType.IMPORT)
if not import_section:
Expand Down Expand Up @@ -362,6 +392,7 @@ def get_imports(self):

return imports

@cache
def get_globals(self):
global_section = self.get_section(SecType.GLOBAL)
if not global_section:
Expand All @@ -376,6 +407,7 @@ def get_globals(self):
globls.append(Global(global_type, mutable, init))
return globls

@cache
def get_functions(self):
code_section = self.get_section(SecType.CODE)
if not code_section:
Expand All @@ -393,12 +425,14 @@ def get_functions(self):
def get_section(self, section_code):
return next((s for s in self.sections() if s.type == section_code), None)

@cache
def get_custom_section(self, name):
for section in self.sections():
if section.type == SecType.CUSTOM and section.name == name:
return section
return None

@cache
def get_segments(self):
segments = []
data_section = self.get_section(SecType.DATA)
Expand All @@ -416,6 +450,7 @@ def get_segments(self):
self.seek(offset + size)
return segments

@cache
def get_tables(self):
table_section = self.get_section(SecType.TABLE)
if not table_section:
Expand All @@ -434,6 +469,37 @@ def get_tables(self):
def has_name_section(self):
return self.get_custom_section('name') is not None

@once
def _calc_indexes(self):
self.num_imported_funcs = 0
self.num_imported_globals = 0
self.num_imported_memories = 0
self.num_imported_tables = 0
self.num_imported_tags = 0
for i in self.get_imports():
if i.kind == ExternType.FUNC:
self.num_imported_funcs += 1
elif i.kind == ExternType.GLOBAL:
self.num_imported_globals += 1
elif i.kind == ExternType.MEMORY:
self.num_imported_memories += 1
elif i.kind == ExternType.TABLE:
self.num_imported_tables += 1
elif i.kind == ExternType.TAG:
self.num_imported_tags += 1
else:
assert False, 'unhandled export type: %s' % i.kind

def get_function(self, idx):
self._calc_indexes()
assert idx >= self.num_imported_funcs
return self.get_functions()[idx - self.num_imported_funcs]

def get_global(self, idx):
self._calc_indexes()
assert idx >= self.num_imported_globals
return self.get_globals()[idx - self.num_imported_globals]


def parse_dylink_section(wasm_file):
module = Module(wasm_file)
Expand Down

0 comments on commit 1706b8e

Please sign in to comment.