diff --git a/sqlelf/elf.py b/sqlelf/elf.py index 3a11ff5..734c0cd 100644 --- a/sqlelf/elf.py +++ b/sqlelf/elf.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Flag, auto -from typing import Any, Callable, Iterator, Optional, Sequence, cast +from typing import Any, Callable, Iterator, Sequence, cast import apsw import apsw.ext @@ -36,7 +36,7 @@ def make_generator( return Generator(columns, apsw.ext.VTColumnAccess.By_Name, generator) -class GeneratorFlag(Flag): +class CacheFlag(Flag): NONE = 0 DYNAMIC_ENTRIES = auto() HEADERS = auto() @@ -48,28 +48,43 @@ class GeneratorFlag(Flag): VERSION_DEFINITIONS = auto() @classmethod - def ALL(cls: type[GeneratorFlag]) -> GeneratorFlag: + def ALL(cls: type[CacheFlag]) -> CacheFlag: retval = cls.NONE for member in cls.__members__.values(): retval |= member return retval -@dataclass -class MakeGeneratorResponse: - """A response from a generator factory. +def register_generator( + connection: apsw.Connection, + generator: Generator, + table_name: str, + generator_flag: CacheFlag, + cache_flags: CacheFlag, +) -> None: + """Register a virtual table generator. - Contains everything needed to register the virtual table.""" + This method does a bit of duplicate work which checks if we need to cache + the given generator. - generator: Generator - table_name: str - flag: GeneratorFlag - sql: Optional[str] = None + If so we rename the table with a prefix 'raw' and then create a temp table""" + original_table_name = table_name + if generator_flag in cache_flags: + print("here") + table_name = f"raw_{table_name}" + apsw.ext.make_virtual_module(connection, table_name, generator) -def make_dynamic_entries_generator( - binaries: list[lief.Binary], -) -> MakeGeneratorResponse: + if generator_flag in cache_flags: + connection.execute( + f"""CREATE TEMP TABLE {original_table_name} + AS SELECT * FROM {table_name};""" + ) + + +def register_dynamic_entries_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the .dynamic section virtual table.""" def dynamic_entries_generator() -> Iterator[dict[str, Any]]: @@ -80,17 +95,23 @@ def dynamic_entries_generator() -> Iterator[dict[str, Any]]: for entry in binary.dynamic_entries: # type: ignore yield {"path": binary_name, "tag": entry.tag.name, "value": entry.value} - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "tag", "value"], - dynamic_entries_generator, - ), + generator = Generator.make_generator( + ["path", "tag", "value"], + dynamic_entries_generator, + ) + + register_generator( + connection, + generator, "elf_dynamic_entries", - GeneratorFlag.DYNAMIC_ENTRIES, + CacheFlag.DYNAMIC_ENTRIES, + cache_flags, ) -def make_headers_generator(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_headers_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF headers virtual table,""" def headers_generator() -> Iterator[dict[str, Any]]: @@ -103,17 +124,23 @@ def headers_generator() -> Iterator[dict[str, Any]]: "entry": binary.header.entrypoint, } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "type", "machine", "version", "entry"], - headers_generator, - ), + generator = Generator.make_generator( + ["path", "type", "machine", "version", "entry"], + headers_generator, + ) + + register_generator( + connection, + generator, "elf_headers", - GeneratorFlag.HEADERS, + CacheFlag.HEADERS, + cache_flags, ) -def make_instructions_generator(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_instructions_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the instructions virtual table. This table includes dissasembled instructions from the executable sections""" @@ -147,15 +174,17 @@ def instructions_generator() -> Iterator[dict[str, Any]]: "operands": op_str, } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "section", "mnemonic", "address", "operands"], - instructions_generator, - ), - "raw_elf_instructions", - GeneratorFlag.INSTRUCTIONS, - """CREATE TEMP TABLE elf_instructions - AS SELECT * FROM raw_elf_instructions;""", + generator = Generator.make_generator( + ["path", "section", "mnemonic", "address", "operands"], + instructions_generator, + ) + + register_generator( + connection, + generator, + "elf_instructions", + CacheFlag.INSTRUCTIONS, + cache_flags, ) @@ -171,7 +200,9 @@ def arch(binary: lief.Binary) -> int: raise RuntimeError(f"Unknown machine type for {binary.name}") -def make_sections_generator(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_sections_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF sections virtual table.""" def sections_generator() -> Iterator[dict[str, Any]]: @@ -189,13 +220,17 @@ def sections_generator() -> Iterator[dict[str, Any]]: "content": bytes(section.content), } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "name", "offset", "size", "type", "content"], - sections_generator, - ), + generator = Generator.make_generator( + ["path", "name", "offset", "size", "type", "content"], + sections_generator, + ) + + register_generator( + connection, + generator, "elf_sections", - GeneratorFlag.SECTIONS, + CacheFlag.SECTIONS, + cache_flags, ) @@ -206,7 +241,9 @@ def coerce_section_name(name: str | None) -> str | None: return name -def make_strings_generator(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_strings_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF strings virtual table. This goes through all string tables in the ELF binary and splits them on null bytes. @@ -240,13 +277,17 @@ def strings_generator() -> Iterator[dict[str, Any]]: "offset": offset + 1, } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "section", "value", "offset"], - strings_generator, - ), + generator = Generator.make_generator( + ["path", "section", "value", "offset"], + strings_generator, + ) + + register_generator( + connection, + generator, "elf_strings", - GeneratorFlag.STRINGS, + CacheFlag.STRINGS, + cache_flags, ) @@ -263,7 +304,9 @@ def split_with_index(str: str, delimiter: str) -> list[tuple[int, str]]: return result -def make_symbols_generator(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_symbols_generator( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF symbols virtual table.""" def symbols_generator() -> Iterator[dict[str, Any]]: @@ -310,32 +353,40 @@ def symbols_generator() -> Iterator[dict[str, Any]]: "value": symbol.value, } - return MakeGeneratorResponse( - Generator.make_generator( - [ - "path", - "name", - "demangled_name", - "imported", - "exported", - "section", - "size", - "version", - "type", - "value", - ], - symbols_generator, - ), - "raw_elf_symbols", - GeneratorFlag.SYMBOLS, - """CREATE TEMP TABLE elf_symbols - AS SELECT * FROM raw_elf_symbols; - CREATE INDEX elf_symbols_path_idx ON elf_symbols (path); - CREATE INDEX elf_symbols_name_idx ON elf_symbols (name);""", + generator = Generator.make_generator( + [ + "path", + "name", + "demangled_name", + "imported", + "exported", + "section", + "size", + "version", + "type", + "value", + ], + symbols_generator, ) + register_generator( + connection, + generator, + "elf_symbols", + CacheFlag.SYMBOLS, + cache_flags, + ) -def make_version_requirements(binaries: list[lief.Binary]) -> MakeGeneratorResponse: + if CacheFlag.SYMBOLS in cache_flags: + connection.execute( + """CREATE INDEX elf_symbols_path_idx ON elf_symbols (path); + CREATE INDEX elf_symbols_name_idx ON elf_symbols (name);""" + ) + + +def register_version_requirements( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF version requirements virtual table. This should match the values found in .gnu.version_r section. @@ -356,17 +407,23 @@ def version_requirements_generator() -> Iterator[dict[str, Any]]: "name": aux_requirement.name, } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "file", "name"], - version_requirements_generator, - ), + generator = Generator.make_generator( + ["path", "file", "name"], + version_requirements_generator, + ) + + register_generator( + connection, + generator, "elf_version_requirements", - GeneratorFlag.VERSION_REQUIREMENTS, + CacheFlag.VERSION_REQUIREMENTS, + cache_flags, ) -def make_version_definitions(binaries: list[lief.Binary]) -> MakeGeneratorResponse: +def register_version_definitions( + binaries: list[lief.Binary], connection: apsw.Connection, cache_flags: CacheFlag +) -> None: """Create the ELF version requirements virtual table. This should match the values found in .gnu.version_d section. @@ -387,13 +444,17 @@ def version_definitions_generator() -> Iterator[dict[str, Any]]: "flags": flags, } - return MakeGeneratorResponse( - Generator.make_generator( - ["path", "name", "flags"], - version_definitions_generator, - ), + generator = Generator.make_generator( + ["path", "name", "flags"], + version_definitions_generator, + ) + + register_generator( + connection, + generator, "elf_version_definitions", - GeneratorFlag.VERSION_DEFINITIONS, + CacheFlag.VERSION_DEFINITIONS, + cache_flags, ) @@ -419,7 +480,7 @@ def symbols(binary: lief.Binary) -> Sequence[lief.ELF.Symbol]: def register_virtual_tables( connection: apsw.Connection, binaries: list[lief.Binary], - flags: GeneratorFlag = GeneratorFlag.ALL(), + cache_flags: CacheFlag = CacheFlag.INSTRUCTIONS | CacheFlag.SYMBOLS, ) -> None: """Register the virtual table modules. @@ -430,25 +491,15 @@ def register_virtual_tables( connection: the connection to register the virtual tables on binaries: the list of binaries to analyze flags: the bitwise flags which controls which virtual table to enable""" - generator_factories = [ - make_dynamic_entries_generator, - make_headers_generator, - make_instructions_generator, - make_sections_generator, - make_strings_generator, - make_symbols_generator, - make_version_requirements, - make_version_definitions, + register_table_functions = [ + register_dynamic_entries_generator, + register_headers_generator, + register_instructions_generator, + register_sections_generator, + register_strings_generator, + register_symbols_generator, + register_version_requirements, + register_version_definitions, ] - for factory in generator_factories: - generator_response = factory(binaries) - - if generator_response.flag not in flags: - continue - - apsw.ext.make_virtual_module( - connection, generator_response.table_name, generator_response.generator - ) - - if generator_response.sql: - connection.execute(generator_response.sql) + for register_function in register_table_functions: + register_function(binaries, connection, cache_flags) diff --git a/sqlelf/sql.py b/sqlelf/sql.py index 76a6bc7..2adc3a6 100644 --- a/sqlelf/sql.py +++ b/sqlelf/sql.py @@ -62,7 +62,7 @@ def find_libraries(binary: lief.Binary) -> Dict[str, str]: def make_sql_engine( filenames: list[str], recursive: bool = False, - flags: elf.GeneratorFlag = elf.GeneratorFlag.ALL(), + cache_flags: elf.CacheFlag = elf.CacheFlag.INSTRUCTIONS | elf.CacheFlag.SYMBOLS, ) -> SQLEngine: """Create a SQL engine from a list of binaries @@ -75,7 +75,7 @@ def make_sql_engine( filenames: the list of binaries to analyze -- should be absolute path recursive: whether to recursively load all shared libraries needed by each binary - flags: the flags to use when generating the virtual tables + cache_flags: bit flag that controls which tables to cache """ binaries: list[lief.Binary] = [ lief.parse(filename) for filename in filenames if lief.is_elf(filename) @@ -98,5 +98,5 @@ def make_sql_engine( ) binaries = binaries + [lief.parse(library) for library in shared_libraries_set] - elf.register_virtual_tables(connection, binaries, flags) + elf.register_virtual_tables(connection, binaries, cache_flags) return SQLEngine(connection) diff --git a/tests/test_examples.py b/tests/test_examples.py index 164f75c..017a770 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -9,7 +9,7 @@ def test_symbol_resolutions() -> None: # TODO(fzakaria): Make sure this binary # is always present in the CI environment. sql_engine = sql.make_sql_engine( - ["/usr/bin/ruby"], recursive=True, flags=elf.GeneratorFlag.SYMBOLS + ["/usr/bin/ruby"], recursive=True, cache_flags=elf.CacheFlag.SYMBOLS ) result = sql_engine.execute( """