From fde5e419610d3079377f1fbacfcecab3f3031f43 Mon Sep 17 00:00:00 2001 From: Farid Zakaria Date: Fri, 22 Sep 2023 21:14:00 +0000 Subject: [PATCH] Simplify the code (#12) * Simplify the code. Based on feedback from @markrwilliams I consolidated a bunch of the files into a single file (elf.py) for readability. * added mypy support * more typings: everything is typed now! * remove duplicate assignment of columns in elf.py * mypy passes * PR feedback * typo fix --- Makefile | 1 + pyproject.toml | 16 ++- sqlelf/cli.py | 27 ++-- sqlelf/elf.py | 265 ++++++++++++++++++++++++++++++++++++++ sqlelf/elf/__init__.py | 0 sqlelf/elf/dynamic.py | 31 ----- sqlelf/elf/header.py | 28 ---- sqlelf/elf/instruction.py | 71 ---------- sqlelf/elf/section.py | 42 ------ sqlelf/elf/strings.py | 39 ------ sqlelf/elf/symbol.py | 93 ------------- sqlelf/ldd.py | 23 ---- sqlelf/sql.py | 47 ++++--- tests/test_cli.py | 17 ++- tests/test_ldd.py | 39 ------ tests/test_sql.py | 39 +++++- 16 files changed, 380 insertions(+), 398 deletions(-) create mode 100644 sqlelf/elf.py delete mode 100644 sqlelf/elf/__init__.py delete mode 100644 sqlelf/elf/dynamic.py delete mode 100644 sqlelf/elf/header.py delete mode 100644 sqlelf/elf/instruction.py delete mode 100644 sqlelf/elf/section.py delete mode 100644 sqlelf/elf/strings.py delete mode 100644 sqlelf/elf/symbol.py delete mode 100644 sqlelf/ldd.py delete mode 100644 tests/test_ldd.py diff --git a/Makefile b/Makefile index 4b4d754..92ebc84 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,7 @@ lint: ## Run pep8, black, mypy linters. flake8 sqlelf/ black --check sqlelf/ pyright + mypy --strict --install-types --non-interactive sqlelf tests .PHONY: test test: ## Run pytest primarily. diff --git a/pyproject.toml b/pyproject.toml index 3d7ea4d..7d3d97d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,15 @@ readme = "README.md" description = "Explore ELF objects through the power of SQL" license = { file = "LICENSE" } requires-python = ">=3.10,<4.0" - +keywords = [] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] dependencies = [ "capstone >= 5.0.1", "lief >=0.13.2", @@ -18,6 +26,11 @@ dependencies = [ "sh >= 2.0.6", ] +[project.urls] +Documentation = "https://github.com/fzakaria/sqlelf#readme" +Issues = "https://github.com/fzakaria/sqlelf/issues" +Source = "https://github.com/fzakaria/sqlelf" + [project.optional-dependencies] dev = [ "black >= 23.7.0", @@ -25,6 +38,7 @@ dev = [ "flake8 >= 6.1.0", "pyright >= 1.1.325", "pytest >= 7.4.0", + "mypy >= 1.0.0", ] [tool.setuptools] diff --git a/sqlelf/cli.py b/sqlelf/cli.py index c5cd880..5b2bd66 100644 --- a/sqlelf/cli.py +++ b/sqlelf/cli.py @@ -2,14 +2,23 @@ import os import os.path import sys +from dataclasses import dataclass, field from functools import reduce +from typing import TextIO import lief from sqlelf import sql as api_sql -def start(args=sys.argv[1:], stdin=sys.stdin): +@dataclass +class ProgramArguments: + filenames: list[str] = field(default_factory=list) + sql: list[str] = field(default_factory=list) + recursive: bool = False + + +def start(args: list[str] = sys.argv[1:], stdin: TextIO = sys.stdin) -> None: """ Start the main CLI @@ -37,7 +46,9 @@ def start(args=sys.argv[1:], stdin=sys.stdin): help="Load all shared libraries needed by each file using ldd", ) - args = parser.parse_args(args) + program_args: ProgramArguments = parser.parse_args( + args, namespace=ProgramArguments() + ) # Iterate through our arguments and if one of them is a directory explode it out filenames: list[str] = reduce( @@ -46,7 +57,7 @@ def start(args=sys.argv[1:], stdin=sys.stdin): lambda dir: [os.path.join(dir, f) for f in os.listdir(dir)] if os.path.isdir(dir) else [dir], - args.filenames, + program_args.filenames, ), ) # Filter the list of filenames to those that are ELF files only @@ -58,11 +69,11 @@ def start(args=sys.argv[1:], stdin=sys.stdin): binaries: list[lief.Binary] = [lief.parse(filename) for filename in filenames] - sql_engine = api_sql.make_sql_engine(binaries, recursive=args.recursive) + sql_engine = api_sql.make_sql_engine(binaries, recursive=program_args.recursive) shell = sql_engine.shell(stdin=stdin) - if args.sql: - for sql in args.sql: - shell.process_complete_line(sql) + if program_args.sql and len(program_args.filenames) > 0: + for sql in program_args.sql: + shell.process_complete_line(sql) # type: ignore[no-untyped-call] else: - shell.cmdloop() + shell.cmdloop() # type: ignore[no-untyped-call] diff --git a/sqlelf/elf.py b/sqlelf/elf.py new file mode 100644 index 0000000..05902ef --- /dev/null +++ b/sqlelf/elf.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Iterator, Sequence, cast + +import apsw +import apsw.ext +import capstone # type: ignore +import lief + + +@dataclass +class Generator: + """A generator for the virtual table SQLite module. + + This class is needed because apsw wants to assign columns and + column_access to the generator function itself.""" + + columns: Sequence[str] + column_access: apsw.ext.VTColumnAccess + callable: Callable[[], Iterator[dict[str, Any]]] + + def __call__(self) -> Iterator[dict[str, Any]]: + """Call the generator should return an iterator of dictionaries. + + The dictionaries should have keys that match the column names.""" + return self.callable() + + @staticmethod + def make_generator(generator: Callable[[], Iterator[dict[str, Any]]]) -> Generator: + """Create a generator from a callable that returns + an iterator of dictionaries.""" + columns, column_access = apsw.ext.get_column_names(next(generator())) + return Generator(columns, column_access, generator) + + +def make_dynamic_entries_generator(binaries: list[lief.Binary]) -> Generator: + """Create the .dynamic section virtual table.""" + + def dynamic_entries_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + # super important that these accessors are pulled out of the tight loop + # as they can be costly + binary_name = binary.name + for entry in binary.dynamic_entries: # type: ignore + yield {"path": binary_name, "tag": entry.tag.name, "value": entry.value} + + return Generator.make_generator(dynamic_entries_generator) + + +def make_headers_generator(binaries: list[lief.Binary]) -> Generator: + """Create the ELF headers virtual table,""" + + def headers_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + yield { + "path": binary.name, + "type": binary.header.file_type.name, + "machine": binary.header.machine_type.name, + "version": binary.header.identity_version.name, + "entry": binary.header.entrypoint, + } + + return Generator.make_generator(headers_generator) + + +def make_instructions_generator(binaries: list[lief.Binary]) -> Generator: + """Create the instructions virtual table. + + This table includes dissasembled instructions from the executable sections""" + + def instructions_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + # super important that these accessors are pulled out of the tight loop + # as they can be costly + binary_name = binary.name + + for section in binary.sections: + if section.has(lief.ELF.SECTION_FLAGS.EXECINSTR): + data = bytes(section.content) + md = capstone.Cs(arch(binary), mode(binary)) + # keep in mind that producing details costs more memory, + # complicates the internal operations and slows down + # the engine a bit, so only do that if needed. + md.detail = False + + # super important that these accessors are pulled out + # of the tight loop as they can be costly + section_name = section.name + for address, size, mnemonic, op_str in md.disasm_lite( + data, section.virtual_address + ): + yield { + "path": binary_name, + "section": section_name, + "mnemonic": mnemonic, + "address": address, + "operands": op_str, + } + + return Generator.make_generator(instructions_generator) + + +def mode(binary: lief.Binary) -> int: + if binary.header.identity_class == lief.ELF.ELF_CLASS.CLASS64: + return cast(int, capstone.CS_MODE_64) + raise RuntimeError(f"Unknown mode for {binary.name}") + + +def arch(binary: lief.Binary) -> int: + if binary.header.machine_type == lief.ELF.ARCH.x86_64: + return cast(int, capstone.CS_ARCH_X86) + raise RuntimeError(f"Unknown machine type for {binary.name}") + + +def make_sections_generator(binaries: list[lief.Binary]) -> Generator: + """Create the ELF sections virtual table.""" + + def sections_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + # super important that these accessors are pulled out of the tight loop + # as they can be costly + binary_name = binary.name + for section in binary.sections: + yield { + "path": binary_name, + "name": section.name, + "offset": section.offset, + "size": section.size, + "type": section.type.name, + "content": bytes(section.content), + } + + return Generator.make_generator(sections_generator) + + +def coerce_section_name(name: str | None) -> str | None: + """Return a section name or undefined if the name is empty.""" + if name == "": + return "undefined" + return name + + +def make_strings_generator(binaries: list[lief.Binary]) -> Generator: + """Create the ELF strings virtual table. + + This goes through all string tables in the ELF binary and splits them on null bytes. + """ + + def strings_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + strtabs = [ + section + for section in binary.sections + if section.type == lief.ELF.SECTION_TYPES.STRTAB + ] + # super important that these accessors are pulled out of the tight loop + # as they can be costly + binary_name = binary.name + for strtab in strtabs: + # The first byte is always the null byte in the STRTAB + # Python also treats the final null in the string by creating + # an empty item so we chop it off. + # https://stackoverflow.com/a/18970869 + for string in str(strtab.content[1:-1], "utf-8").split("\x00"): + yield {"path": binary_name, "section": strtab.name, "value": string} + + return Generator.make_generator(strings_generator) + + +def make_symbols_generator(binaries: list[lief.Binary]) -> Generator: + """Create the ELF symbols virtual table.""" + + def symbols_generator() -> Iterator[dict[str, Any]]: + for binary in binaries: + # super important that these accessors are pulled out of the tight loop + # as they can be costly + binary_name = binary.name + for symbol in symbols(binary): + # The section index can be special numbers like 65521 or 65522 + # that refer to special sections so they can't be indexed + section_name: str | None = next( + ( + section.name + for shndx, section in enumerate(binary.sections) + if shndx == symbol.shndx + ), + None, + ) + + yield { + "path": binary_name, + "name": symbol.name, + "demangled_name": symbol.demangled_name, + # A bit of detailed explanation here to explain these values. + # A symbol may point to the SHN_UNDEF section which is a good it's + # an "imported symbol" -- meaning it needs to be linked in. + # If the section is != SH_UNDEF then it is "exported" as it's + # logic resides within this shared object file. + # refs: + # https://github.com/lief-project/LIEF/blob/0875ee2467d5ae6628d8bf3f4f0b82ca5854c401/src/ELF/Symbol.cpp#L90 + # https://stackoverflow.com/questions/12666253/elf-imports-and-exports + # https://www.m4b.io/elf/export/binary/analysis/2015/05/25/what-is-an-elf-export.html + "imported": symbol.imported, + "exported": symbol.exported, + "section": coerce_section_name(section_name), + "size": symbol.size, + # TODO(fzakaria): Better understand why is it auxiliary? + # this returns versions like GLIBC_2.2.5 + "version": symbol.symbol_version.symbol_version_auxiliary.name + if symbol.symbol_version + and symbol.symbol_version.symbol_version_auxiliary + else None, + "type": symbol.type.name, + "value": symbol.value, + } + + return Generator.make_generator(symbols_generator) + + +def symbols(binary: lief.Binary) -> Sequence[lief.ELF.Symbol]: + """Use heuristic to either get static symbols or dynamic symbol table + + The static symbol table is a superset of the dynamic symbol table. + However it is often stripped from binaries as it's not needed beyond + debugging. + + This method uses the simplest heuristic of checking for its existence + to return the static symbol table. + + A bad actor is free to strip arbitrarily from the static symbol table + and it would affect this method. + """ + static_symbols: Sequence[lief.ELF.Symbol] = binary.static_symbols # type: ignore + if len(static_symbols) > 0: + return static_symbols + return binary.dynamic_symbols # type: ignore + + +def register_virtual_tables( + connection: apsw.Connection, binaries: list[lief.Binary] +) -> None: + """Register the virtual table modules.""" + factory_and_names = [ + (make_dynamic_entries_generator, "elf_dynamic_entries"), + (make_headers_generator, "elf_headers"), + (make_instructions_generator, "raw_elf_instructions"), + (make_sections_generator, "elf_sections"), + (make_strings_generator, "elf_strings"), + (make_symbols_generator, "raw_elf_symbols"), + ] + for factory, name in factory_and_names: + generator = factory(binaries) + apsw.ext.make_virtual_module(connection, name, generator) + connection.execute( + """ + CREATE TEMP TABLE elf_instructions + AS SELECT * FROM raw_elf_instructions; + + CREATE TEMP TABLE elf_symbols + AS SELECT * FROM raw_elf_symbols; + CREATE INDEX elf_symbols_path_idx ON elf_symbols (path); + CREATE INDEX elf_symbols_name_idx ON elf_symbols (name); + """ + ) diff --git a/sqlelf/elf/__init__.py b/sqlelf/elf/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/sqlelf/elf/dynamic.py b/sqlelf/elf/dynamic.py deleted file mode 100644 index d933b59..0000000 --- a/sqlelf/elf/dynamic.py +++ /dev/null @@ -1,31 +0,0 @@ -# Without this Python was complaining -from __future__ import annotations - -from typing import Any, Iterator - -import apsw -import apsw.ext -import lief - - -# This is effectively the .dynamic section but it is elevated as a table here -# since it is widely used and can benefit from simpler table access. -def elf_dynamic_entries(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - # super important that these accessors are pulled out of the tight loop - # as they can be costly - binary_name = binary.name - for entry in binary.dynamic_entries: # pyright: ignore - yield {"path": binary_name, "tag": entry.tag.name, "value": entry.value} - - return generator - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_dynamic_entries(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "elf_dynamic_entries", generator) diff --git a/sqlelf/elf/header.py b/sqlelf/elf/header.py deleted file mode 100644 index e4e7b51..0000000 --- a/sqlelf/elf/header.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Any, Iterator - -import apsw -import apsw.ext -import lief - - -def elf_headers(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - yield { - "path": binary.name, - "type": binary.header.file_type.name, - "machine": binary.header.machine_type.name, - "version": binary.header.identity_version.name, - "entry": binary.header.entrypoint, - } - - return generator - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_headers(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "elf_headers", generator) diff --git a/sqlelf/elf/instruction.py b/sqlelf/elf/instruction.py deleted file mode 100644 index 7637e17..0000000 --- a/sqlelf/elf/instruction.py +++ /dev/null @@ -1,71 +0,0 @@ -# Without this Python was complaining -from __future__ import annotations - -from typing import Any, Iterator - -import apsw -import apsw.ext - -# TODO(fzkakaria): https://github.com/capstone-engine/capstone/issues/1993 -import capstone # pyright: ignore -import lief - - -def elf_instructions(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - # super important that these accessors are pulled out of the tight loop - # as they can be costly - binary_name = binary.name - - for section in binary.sections: - if section.has(lief.ELF.SECTION_FLAGS.EXECINSTR): - data = bytes(section.content) - md = capstone.Cs(arch(binary), mode(binary)) - # keep in mind that producing details costs more memory, - # complicates the internal operations and slows down - # the engine a bit, so only do that if needed. - md.detail = False - - # super important that these accessors are pulled out - # of the tight loop as they can be costly - section_name = section.name - for address, size, mnemonic, op_str in md.disasm_lite( - data, section.virtual_address - ): - yield { - "path": binary_name, - "section": section_name, - "mnemonic": mnemonic, - "address": address, - "operands": op_str, - } - - return generator - - -def mode(binary: lief.Binary) -> int: - if binary.header.identity_class == lief.ELF.ELF_CLASS.CLASS64: - return capstone.CS_MODE_64 - raise Exception(f"Unknown mode for {binary.name}") - - -def arch(binary: lief.Binary) -> int: - if binary.header.machine_type == lief.ELF.ARCH.x86_64: - return capstone.CS_ARCH_X86 - raise Exception(f"Unknown machine type for {binary.name}") - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_instructions(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "raw_elf_instructions", generator) - connection.execute( - """ - CREATE TEMP TABLE elf_instructions - AS SELECT * FROM raw_elf_instructions; - """ - ) diff --git a/sqlelf/elf/section.py b/sqlelf/elf/section.py deleted file mode 100644 index 969440c..0000000 --- a/sqlelf/elf/section.py +++ /dev/null @@ -1,42 +0,0 @@ -# Without this Python was complaining -from __future__ import annotations - -from typing import Any, Iterator - -import apsw -import apsw.ext -import lief - - -def elf_sections(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - # super important that these accessors are pulled out of the tight loop - # as they can be costly - binary_name = binary.name - for section in binary.sections: - yield { - "path": binary_name, - "name": section.name, - "offset": section.offset, - "size": section.size, - "type": section.type.name, - "content": bytes(section.content), - } - - return generator - - -def section_name(name: str | None) -> str | None: - if name == "": - return "undefined" - return name - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_sections(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "elf_sections", generator) diff --git a/sqlelf/elf/strings.py b/sqlelf/elf/strings.py deleted file mode 100644 index 456ead3..0000000 --- a/sqlelf/elf/strings.py +++ /dev/null @@ -1,39 +0,0 @@ -# Without this Python was complaining -from __future__ import annotations - -from typing import Any, Iterator - -import apsw -import apsw.ext -import lief - - -def elf_strings(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - strtabs = [ - section - for section in binary.sections - if section.type == lief.ELF.SECTION_TYPES.STRTAB - ] - # super important that these accessors are pulled out of the tight loop - # as they can be costly - binary_name = binary.name - for strtab in strtabs: - # The first byte is always the null byte in the STRTAB - # Python also treats the final null in the string by creating - # an empty item so we chop it off. - # https://stackoverflow.com/a/18970869 - for string in str(strtab.content[1:-1], "utf-8").split("\x00"): - yield {"path": binary_name, "section": strtab.name, "value": string} - - return generator - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_strings(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "elf_strings", generator) diff --git a/sqlelf/elf/symbol.py b/sqlelf/elf/symbol.py deleted file mode 100644 index 968b039..0000000 --- a/sqlelf/elf/symbol.py +++ /dev/null @@ -1,93 +0,0 @@ -# Without this Python was complaining -from __future__ import annotations - -from typing import Any, Iterator - -import apsw -import apsw.ext -import lief -from sqlelf.elf.section import section_name as elf_section_name - - -def elf_symbols(binaries: list[lief.Binary]): - def generator() -> Iterator[dict[str, Any]]: - for binary in binaries: - # super important that these accessors are pulled out of the tight loop - # as they can be costly - binary_name = binary.name - for symbol in symbols(binary): - # The section index can be special numbers like 65521 or 65522 - # that refer to special sections so they can't be indexed - section_name: str | None = next( - ( - section.name - for shndx, section in enumerate(binary.sections) - if shndx == symbol.shndx - ), - None, - ) - - yield { - "path": binary_name, - "name": symbol.name, - "demangled_name": symbol.demangled_name, - # A bit of detailed explanation here to explain these values. - # A symbol may point to the SHN_UNDEF section which is a good it's - # an "imported symbol" -- meaning it needs to be linked in. - # If the section is != SH_UNDEF then it is "exported" as it's - # logic resides within this shared object file. - # refs: - # https://github.com/lief-project/LIEF/blob/0875ee2467d5ae6628d8bf3f4f0b82ca5854c401/src/ELF/Symbol.cpp#L90 - # https://stackoverflow.com/questions/12666253/elf-imports-and-exports - # https://www.m4b.io/elf/export/binary/analysis/2015/05/25/what-is-an-elf-export.html - "imported": symbol.imported, - "exported": symbol.exported, - "section": elf_section_name(section_name), - "size": symbol.size, - # TODO(fzakaria): Better understand why is it auxiliary? - # this returns versions like GLIBC_2.2.5 - "version": symbol.symbol_version.symbol_version_auxiliary.name - if symbol.symbol_version - and symbol.symbol_version.symbol_version_auxiliary - else None, - "type": symbol.type.name, - "value": symbol.value, - } - - return generator - - -def symbols(binary: lief.Binary) -> Iterator[lief.ELF.Symbol]: - """Use heuristic to either get static symbols or dynamic symbol table - - The static symbol table is a superset of the dynamic symbol table. - However it is often stripped from binaries as it's not needed beyond - debugging. - - This method uses the simplest heuristic of checking for it's existence - to return the static symbol table. - - A bad actor is free to strip arbitrarily from the static symbol table - and it would affect this method. - """ - static_symbols = binary.static_symbols # pyright: ignore - missing from pyi - if len(static_symbols) > 0: - return static_symbols - return binary.dynamic_symbols # pyright: ignore - missing from pyi - - -def register(connection: apsw.Connection, binaries: list[lief.Binary]): - generator = elf_symbols(binaries) - # setup columns and access by providing an example of the first entry returned - generator.columns, generator.column_access = apsw.ext.get_column_names( - next(generator()) - ) - apsw.ext.make_virtual_module(connection, "raw_elf_symbols", generator) - connection.execute( - """ - CREATE TEMP TABLE elf_symbols - AS SELECT * FROM raw_elf_symbols; - CREATE INDEX elf_symbols_path_idx ON elf_symbols (path); - CREATE INDEX elf_symbols_name_idx ON elf_symbols (name); - """ - ) diff --git a/sqlelf/ldd.py b/sqlelf/ldd.py deleted file mode 100644 index 0db6832..0000000 --- a/sqlelf/ldd.py +++ /dev/null @@ -1,23 +0,0 @@ -import re -from collections import OrderedDict -from typing import Dict - -import lief -import sh # pyright: ignore - - -def libraries(binary: lief.Binary) -> Dict[str, str]: - """Use the interpreter in a binary to determine the path of each linked library""" - interpreter = sh.Command(binary.interpreter) # pyright: ignore - resolution = interpreter("--list", binary.name) - result = OrderedDict() - # TODO: Figure out why `--list` and `ldd` produce different outcomes - # specifically for the interpreter. - # https://gist.github.com/fzakaria/3dc42a039401598d8e0fdbc57f5e7eae - for line in resolution.splitlines(): # pyright: ignore - m = re.match(r"\s*([^ ]+) => ([^ ]+)", line) - if not m: - continue - soname, lib = m.group(1), m.group(2) - result[soname] = lib - return result diff --git a/sqlelf/sql.py b/sqlelf/sql.py index 9aa8979..2019e63 100644 --- a/sqlelf/sql.py +++ b/sqlelf/sql.py @@ -1,23 +1,25 @@ import os +import re import sys +from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Iterator +from typing import Any, Dict, Iterator, TextIO import apsw import apsw.shell import lief +import sh # type: ignore -from sqlelf import ldd -from sqlelf.elf import dynamic, header, instruction, section, strings, symbol +from sqlelf import elf @dataclass class SQLEngine: connection: apsw.Connection - def shell(self, stdin=sys.stdin) -> apsw.shell.Shell: + def shell(self, stdin: TextIO = sys.stdin) -> apsw.shell.Shell: shell = apsw.shell.Shell(db=self.connection, stdin=stdin) - shell.command_prompt(["sqlelf> "]) + shell.command_prompt(["sqlelf> "]) # type: ignore[no-untyped-call] return shell def execute_raw(self, sql: str) -> apsw.Cursor: @@ -31,29 +33,42 @@ def execute(self, sql: str) -> Iterator[dict[str, Any]]: yield dict(zip(column_names, row)) -def make_sql_engine(binaries: list[lief.Binary], recursive=False) -> SQLEngine: +def find_libraries(binary: lief.Binary) -> Dict[str, str]: + """Use the interpreter in a binary to determine the path of each linked library""" + interpreter = binary.interpreter # type: ignore + interpreter_cmd = sh.Command(interpreter) + resolution = interpreter_cmd("--list", binary.name) + result = OrderedDict() + # TODO: Figure out why `--list` and `ldd` produce different outcomes + # specifically for the interpreter. + # https://gist.github.com/fzakaria/3dc42a039401598d8e0fdbc57f5e7eae + for line in resolution.splitlines(): # type: ignore[unused-ignore] + m = re.match(r"\s*([^ ]+) => ([^ ]+)", line) + if not m: + continue + soname, lib = m.group(1), m.group(2) + result[soname] = lib + return result + + +def make_sql_engine(binaries: list[lief.Binary], recursive: bool = False) -> SQLEngine: connection = apsw.Connection(":memory:") if recursive: # We want to load all the shared libraries needed by each binary # so we can analyze them as well - shared_libraries = [ldd.libraries(binary).values() for binary in binaries] + shared_libraries = [find_libraries(binary).values() for binary in binaries] # We want to readlink on the libraries to resolve # symlinks such as libm -> libc # also make this is a set in the case that multiple binaries use the same - shared_libraries = set( + shared_libraries_set = set( [ os.path.realpath(library) for sub_list in shared_libraries for library in sub_list ] ) - binaries = binaries + [lief.parse(library) for library in shared_libraries] - - header.register(connection, binaries) - section.register(connection, binaries) - symbol.register(connection, binaries) - dynamic.register(connection, binaries) - strings.register(connection, binaries) - instruction.register(connection, binaries) + binaries = binaries + [lief.parse(library) for library in shared_libraries_set] + + elf.register_virtual_tables(connection, binaries) return SQLEngine(connection) diff --git a/tests/test_cli.py b/tests/test_cli.py index ceec5a1..85f1881 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,24 +2,29 @@ import pytest from io import StringIO -def test_cli_bad_arguments(): + +def test_cli_bad_arguments() -> None: with pytest.raises(SystemExit): cli.start(["--does-not-exist"]) -def test_cli_no_arguments(): + +def test_cli_no_arguments() -> None: with pytest.raises(SystemExit): cli.start([]) -def test_cli_single_file_arguments(): + +def test_cli_single_file_arguments() -> None: stdin = StringIO("") cli.start(["/bin/ls"], stdin) -def test_cli_single_non_existent_file_arguments(): + +def test_cli_single_non_existent_file_arguments() -> None: with pytest.raises(SystemExit) as err: cli.start(["does_not_exist"]) -def test_cli_prompt_single_file_arguments(): + +def test_cli_prompt_single_file_arguments() -> None: stdin = StringIO(".exit 56\n") with pytest.raises(SystemExit) as err: cli.start(["/bin/ls"], stdin) - assert err.value.code == 56 \ No newline at end of file + assert err.value.code == 56 diff --git a/tests/test_ldd.py b/tests/test_ldd.py deleted file mode 100644 index 6fa3242..0000000 --- a/tests/test_ldd.py +++ /dev/null @@ -1,39 +0,0 @@ -from sqlelf import ldd -import lief -from unittest.mock import patch - - -def test_simple_binary_real(): - binary = lief.parse("/bin/ls") - result = ldd.libraries(binary) - assert len(result) > 0 - - -@patch("sh.Command") -def test_simple_binary_mocked(Command): - binary = lief.parse("/bin/ls") - interpreter = binary.interpreter # pyright: ignore - Command( - interpreter - ).return_value = """ - linux-vdso.so.1 (0x00007ffc5d8ff000) - /lib/x86_64-linux-gnu/libnss_cache.so.2 (0x00007f6995d92000) - libselinux.so.1 => not found - fake.so.6 => /some-path/fake.so.6 - libc.so.6 => /nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib/libc.so.6 (0x00007f6995bac000) - /lib64/ld-linux-x86-64.so.2 => /nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib64/ld-linux-x86-64.so.2 (0x00007f6995dc1000) -""" - result = ldd.libraries(binary) - assert len(result) == 4 - assert result["fake.so.6"] == "/some-path/fake.so.6" - assert ( - result["/lib64/ld-linux-x86-64.so.2"] - == "/nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib64/ld-linux-x86-64.so.2" - ) - assert ( - result["libc.so.6"] - == "/nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib/libc.so.6" - ) - # TODO(fzakaria):better handling for not found - # kind of a weird one since this should never happen though - assert result["libselinux.so.1"] == "not" diff --git a/tests/test_sql.py b/tests/test_sql.py index bcfdba8..3adb4d3 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,8 +1,45 @@ from sqlelf import sql import lief +from unittest.mock import patch +import sh # type: ignore -def test_simple_select_header(): +def test_simple_binary_real() -> None: + binary = lief.parse("/bin/ls") + result = sql.find_libraries(binary) + assert len(result) > 0 + + +@patch("sh.Command") +def test_simple_binary_mocked(Command: sh.Command) -> None: + binary = lief.parse("/bin/ls") + interpreter = binary.interpreter # type: ignore + expected_return_value = """ + linux-vdso.so.1 (0x00007ffc5d8ff000) + /lib/x86_64-linux-gnu/libnss_cache.so.2 (0x00007f6995d92000) + libselinux.so.1 => not found + fake.so.6 => /some-path/fake.so.6 + libc.so.6 => /nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib/libc.so.6 (0x00007f6995bac000) + /lib64/ld-linux-x86-64.so.2 => /nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib64/ld-linux-x86-64.so.2 (0x00007f6995dc1000) + """ + Command(interpreter).return_value = expected_return_value # pyright: ignore + result = sql.find_libraries(binary) + assert len(result) == 4 + assert result["fake.so.6"] == "/some-path/fake.so.6" + assert ( + result["/lib64/ld-linux-x86-64.so.2"] + == "/nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib64/ld-linux-x86-64.so.2" + ) + assert ( + result["libc.so.6"] + == "/nix/store/46m4xx889wlhsdj72j38fnlyyvvvvbyb-glibc-2.37-8/lib/libc.so.6" + ) + # TODO(fzakaria):better handling for not found + # kind of a weird one since this should never happen though + assert result["libselinux.so.1"] == "not" + + +def test_simple_select_header() -> None: engine = sql.make_sql_engine([lief.parse("/bin/ls")]) result = list(engine.execute("SELECT * FROM elf_headers LIMIT 1")) assert len(result) == 1