From 96c1f8acc10c2696d0fdd72f5c7b7607282501a1 Mon Sep 17 00:00:00 2001 From: Farid Zakaria Date: Mon, 25 Sep 2023 18:40:07 +0000 Subject: [PATCH] add binding support to SQLEngine execute * added unit test --- sqlelf/sql.py | 14 +++++++++----- tests/test_sql.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sqlelf/sql.py b/sqlelf/sql.py index 81bf926..6d4d134 100644 --- a/sqlelf/sql.py +++ b/sqlelf/sql.py @@ -3,7 +3,7 @@ import sys from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Dict, Iterator, TextIO +from typing import Any, Dict, Iterator, Optional, TextIO import apsw import apsw.shell @@ -22,11 +22,15 @@ def shell(self, stdin: TextIO = sys.stdin) -> apsw.shell.Shell: shell.command_prompt(["sqlelf> "]) # type: ignore[no-untyped-call] return shell - def execute_raw(self, sql: str) -> apsw.Cursor: - return self.connection.execute(sql) + def execute_raw( + self, sql: str, bindings: Optional["apsw.Bindings"] = None + ) -> apsw.Cursor: + return self.connection.execute(sql, bindings=bindings) - def execute(self, sql: str) -> Iterator[dict[str, Any]]: - cursor = self.execute_raw(sql) + def execute( + self, sql: str, bindings: Optional["apsw.Bindings"] = None + ) -> Iterator[dict[str, Any]]: + cursor = self.execute_raw(sql, bindings=bindings) try: description = cursor.getdescription() column_names = [n for n, _ in description] diff --git a/tests/test_sql.py b/tests/test_sql.py index 7aac94b..eefd168 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -72,3 +72,22 @@ def test_non_existent_file() -> None: engine = sql.make_sql_engine(["/doesnotexist"]) result = list(engine.execute("SELECT * FROM elf_headers LIMIT 1")) assert len(result) == 0 + + +def test_select_with_bindings() -> None: + engine = sql.make_sql_engine(["/bin/ls", "/bin/cat"]) + result = list( + engine.execute( + """ + SELECT * FROM elf_version_requirements + WHERE path = :path + LIMIT 1 + """, + {"path": "/bin/ls"}, + ) + ) + assert len(result) == 1 + assert "path" in result[0] + assert result[0]["path"] == "/bin/ls" + assert "file" in result[0] + assert "name" in result[0]