Skip to content

Commit

Permalink
Create duckdb connection during execution (#2684)
Browse files (browse the repository at this point in the history)
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored Aug 15, 2024
1 parent 03d2301 commit 556dad2
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def __init__(
inputs: The query parameters to be used while executing the query
"""
self._query = query
# create an in-memory database that's non-persistent
self._con = duckdb.connect(":memory:")

outputs = {"result": StructuredDataset}

super(DuckDBQuery, self).__init__(
Expand All @@ -47,7 +44,9 @@ def __init__(
**kwargs,
)

def _execute_query(self, params: list, query: str, counter: int, multiple_params: bool):
def _execute_query(
self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool
):
"""
This method runs the DuckDBQuery.
Expand All @@ -64,28 +63,32 @@ def _execute_query(self, params: list, query: str, counter: int, multiple_params
raise ValueError("Parameter doesn't exist.")
if "insert" in query.lower():
# run executemany disregarding the number of entries to store for an insert query
yield QueryOutput(output=self._con.executemany(query, params[counter]), counter=counter)
yield QueryOutput(output=con.executemany(query, params[counter]), counter=counter)
else:
yield QueryOutput(output=self._con.execute(query, params[counter]), counter=counter)
yield QueryOutput(output=con.execute(query, params[counter]), counter=counter)
else:
if params:
yield QueryOutput(output=self._con.execute(query, params), counter=counter)
yield QueryOutput(output=con.execute(query, params), counter=counter)
else:
raise ValueError("Parameter not specified.")
else:
yield QueryOutput(output=self._con.execute(query), counter=counter)
yield QueryOutput(output=con.execute(query), counter=counter)

def execute(self, **kwargs) -> StructuredDataset:
# TODO: Enable iterative download after adding the functionality to structured dataset code.

# create an in-memory database that's non-persistent
con = duckdb.connect(":memory:")

params = None
for key in self.python_interface.inputs.keys():
val = kwargs.get(key)
if isinstance(val, StructuredDataset):
# register structured dataset
self._con.register(key, val.open(pa.Table).all())
con.register(key, val.open(pa.Table).all())
elif isinstance(val, (pd.DataFrame, pa.Table)):
# register pandas dataframe/arrow table
self._con.register(key, val)
con.register(key, val)
elif isinstance(val, list):
# copy val into params
params = val
Expand All @@ -105,7 +108,11 @@ def execute(self, **kwargs) -> StructuredDataset:
for query in self._query[:-1]:
query_output = next(
self._execute_query(
params=params, query=query, counter=query_output.counter, multiple_params=multiple_params
con=con,
params=params,
query=query,
counter=query_output.counter,
multiple_params=multiple_params,
)
)
final_query = self._query[-1]
Expand All @@ -114,7 +121,7 @@ def execute(self, **kwargs) -> StructuredDataset:
# expecting a SELECT query
dataframe = next(
self._execute_query(
params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params
con=con, params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params
)
).output.arrow()

Expand Down

0 comments on commit 556dad2

Please sign in to comment.