diff --git a/sqlelf/cli.py b/sqlelf/cli.py index 5b2bd66..984cebb 100644 --- a/sqlelf/cli.py +++ b/sqlelf/cli.py @@ -6,8 +6,6 @@ from functools import reduce from typing import TextIO -import lief - from sqlelf import sql as api_sql @@ -60,16 +58,14 @@ def start(args: list[str] = sys.argv[1:], stdin: TextIO = sys.stdin) -> None: program_args.filenames, ), ) - # Filter the list of filenames to those that are ELF files only - filenames = [f for f in filenames if os.path.isfile(f) and lief.is_elf(f)] + # Filter the list of filenames to those that are files only + filenames = [f for f in filenames if os.path.isfile(f)] # If none of the inputs are valid files, simply return if len(filenames) == 0: sys.exit("No valid ELF files were provided") - binaries: list[lief.Binary] = [lief.parse(filename) for filename in filenames] - - sql_engine = api_sql.make_sql_engine(binaries, recursive=program_args.recursive) + sql_engine = api_sql.make_sql_engine(filenames, recursive=program_args.recursive) shell = sql_engine.shell(stdin=stdin) if program_args.sql and len(program_args.filenames) > 0: diff --git a/sqlelf/sql.py b/sqlelf/sql.py index 2019e63..eb76abf 100644 --- a/sqlelf/sql.py +++ b/sqlelf/sql.py @@ -51,7 +51,15 @@ def find_libraries(binary: lief.Binary) -> Dict[str, str]: return result -def make_sql_engine(binaries: list[lief.Binary], recursive: bool = False) -> SQLEngine: +def make_sql_engine(filenames: list[str], recursive: bool = False) -> SQLEngine: + """Create a SQL engine from a list of binaries + + Args: + filenames: the list of binaries to analyze -- should be absolute path + recursive: whether to recursively load all shared + libraries needed by each binary + """ + binaries: list[lief.Binary] = [lief.parse(filename) for filename in filenames] connection = apsw.Connection(":memory:") if recursive: diff --git a/tests/test_sql.py b/tests/test_sql.py index 9c745fd..256a89d 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -41,7 +41,7 @@ def test_simple_binary_mocked(Command: sh.Command) -> None: def test_simple_select_header() -> None: # TODO(fzakaria): Figure out a better binary to be doing that we control - engine = sql.make_sql_engine([lief.parse("/bin/ls")]) + engine = sql.make_sql_engine(["/bin/ls"]) result = list(engine.execute("SELECT * FROM elf_headers LIMIT 1")) assert len(result) == 1 assert "path" in result[0] @@ -53,7 +53,7 @@ def test_simple_select_header() -> None: def test_simple_select_version_requirements() -> None: # TODO(fzakaria): Figure out a better binary to be doing that we control - engine = sql.make_sql_engine([lief.parse("/bin/ls")]) + engine = sql.make_sql_engine(["/bin/ls"]) result = list(engine.execute("SELECT * FROM elf_version_requirements LIMIT 1")) assert len(result) == 1 assert "path" in result[0]