Skip to content

Commit

Permalink
Fixes to sql execution code
Browse files Browse the repository at this point in the history
* catch StopExecution errors before we call get_description for when
  there are 0 rows
* create generator with fixed column names for when the table is empty
* added unit tests for the two above
  • Loading branch information
fzakaria committed Sep 25, 2023
1 parent 2c97934 commit 5b58e6e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
58 changes: 44 additions & 14 deletions sqlelf/elf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def __call__(self) -> Iterator[dict[str, Any]]:
return self.callable()

@staticmethod
def make_generator(generator: Callable[[], Iterator[dict[str, Any]]]) -> Generator:
def make_generator(
columns: list[str], generator: Callable[[], Iterator[dict[str, Any]]]
) -> Generator:
"""Create a generator from a callable that returns
an iterator of dictionaries."""
columns, column_access = apsw.ext.get_column_names(next(generator()))
return Generator(columns, column_access, generator)
return Generator(columns, apsw.ext.VTColumnAccess.By_Name, generator)


def make_dynamic_entries_generator(binaries: list[lief.Binary]) -> Generator:
Expand All @@ -45,7 +46,10 @@ def dynamic_entries_generator() -> Iterator[dict[str, Any]]:
for entry in binary.dynamic_entries: # type: ignore
yield {"path": binary_name, "tag": entry.tag.name, "value": entry.value}

return Generator.make_generator(dynamic_entries_generator)
return Generator.make_generator(
["path", "tag", "value"],
dynamic_entries_generator,
)


def make_headers_generator(binaries: list[lief.Binary]) -> Generator:
Expand All @@ -61,7 +65,10 @@ def headers_generator() -> Iterator[dict[str, Any]]:
"entry": binary.header.entrypoint,
}

return Generator.make_generator(headers_generator)
return Generator.make_generator(
["path", "type", "machine", "version", "entry"],
headers_generator,
)


def make_instructions_generator(binaries: list[lief.Binary]) -> Generator:
Expand Down Expand Up @@ -98,7 +105,10 @@ def instructions_generator() -> Iterator[dict[str, Any]]:
"operands": op_str,
}

return Generator.make_generator(instructions_generator)
return Generator.make_generator(
["path", "section", "mnemonic", "address", "operands"],
instructions_generator,
)


def mode(binary: lief.Binary) -> int:
Expand Down Expand Up @@ -131,7 +141,10 @@ def sections_generator() -> Iterator[dict[str, Any]]:
"content": bytes(section.content),
}

return Generator.make_generator(sections_generator)
return Generator.make_generator(
["path", "name", "offset", "size", "type", "content"],
sections_generator,
)


def coerce_section_name(name: str | None) -> str | None:
Expand Down Expand Up @@ -165,7 +178,10 @@ def strings_generator() -> Iterator[dict[str, Any]]:
for string in str(strtab.content[1:-1], "utf-8").split("\x00"):
yield {"path": binary_name, "section": strtab.name, "value": string}

return Generator.make_generator(strings_generator)
return Generator.make_generator(
["path", "section", "value"],
strings_generator,
)


def make_symbols_generator(binaries: list[lief.Binary]) -> Generator:
Expand Down Expand Up @@ -215,7 +231,21 @@ def symbols_generator() -> Iterator[dict[str, Any]]:
"value": symbol.value,
}

return Generator.make_generator(symbols_generator)
return Generator.make_generator(
[
"path",
"name",
"demangled_name",
"imported",
"exported",
"section",
"size",
"version",
"type",
"value",
],
symbols_generator,
)


def make_version_requirements(binaries: list[lief.Binary]) -> Generator:
Expand All @@ -239,7 +269,9 @@ def version_requirements_generator() -> Iterator[dict[str, Any]]:
"name": aux_requirement.name,
}

return Generator.make_generator(version_requirements_generator)
return Generator.make_generator(
["path", "file", "name"], version_requirements_generator
)


def make_version_definitions(binaries: list[lief.Binary]) -> Generator:
Expand All @@ -263,10 +295,8 @@ def version_definitions_generator() -> Iterator[dict[str, Any]]:
"flags": flags,
}

return Generator(
["path", "name", "flags"],
apsw.ext.VTColumnAccess.By_Name,
version_definitions_generator,
return Generator.make_generator(
["path", "name", "flags"], version_definitions_generator
)


Expand Down
12 changes: 8 additions & 4 deletions sqlelf/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ def execute_raw(self, sql: str) -> apsw.Cursor:

def execute(self, sql: str) -> Iterator[dict[str, Any]]:
cursor = self.execute_raw(sql)
description = cursor.getdescription()
column_names = [n for n, _ in description]
for row in cursor:
yield dict(zip(column_names, row))
try:
description = cursor.getdescription()
column_names = [n for n, _ in description]
for row in cursor:
yield dict(zip(column_names, row))
except apsw.ExecutionCompleteError:
# This can happen if we LIMIT 0 or there are no results
pass


def find_libraries(binary: lief.Binary) -> Dict[str, str]:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ def test_simple_select_version_requirements() -> None:
assert "path" in result[0]
assert "file" in result[0]
assert "name" in result[0]


def test_select_zero_rows() -> None:
# TODO(fzakaria): Figure out a better binary to be doing that we control
engine = sql.make_sql_engine(["/bin/ls"])
result = list(engine.execute("SELECT * FROM elf_headers LIMIT 0"))
assert len(result) == 0

0 comments on commit 5b58e6e

Please sign in to comment.