From 5f7ece673f9f7361e916fda474cc5e73c03f0f99 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 14 Aug 2024 14:23:43 -0400 Subject: [PATCH] Create duckdb connection during execution Signed-off-by: Thomas J. Fan --- .../flytekitplugins/duckdb/task.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py b/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py index 71c15481f4..eda750fd33 100644 --- a/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py +++ b/plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py @@ -34,9 +34,6 @@ def __init__( inputs: The query parameters to be used while executing the query """ self._query = query - # create an in-memory database that's non-persistent - self._con = duckdb.connect(":memory:") - outputs = {"result": StructuredDataset} super(DuckDBQuery, self).__init__( @@ -47,7 +44,9 @@ def __init__( **kwargs, ) - def _execute_query(self, params: list, query: str, counter: int, multiple_params: bool): + def _execute_query( + self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool + ): """ This method runs the DuckDBQuery. @@ -64,28 +63,32 @@ def _execute_query(self, params: list, query: str, counter: int, multiple_params raise ValueError("Parameter doesn't exist.") if "insert" in query.lower(): # run executemany disregarding the number of entries to store for an insert query - yield QueryOutput(output=self._con.executemany(query, params[counter]), counter=counter) + yield QueryOutput(output=con.executemany(query, params[counter]), counter=counter) else: - yield QueryOutput(output=self._con.execute(query, params[counter]), counter=counter) + yield QueryOutput(output=con.execute(query, params[counter]), counter=counter) else: if params: - yield QueryOutput(output=self._con.execute(query, params), counter=counter) + yield QueryOutput(output=con.execute(query, params), counter=counter) else: raise ValueError("Parameter not specified.") else: - yield QueryOutput(output=self._con.execute(query), counter=counter) + yield QueryOutput(output=con.execute(query), counter=counter) def execute(self, **kwargs) -> StructuredDataset: # TODO: Enable iterative download after adding the functionality to structured dataset code. + + # create an in-memory database that's non-persistent + con = duckdb.connect(":memory:") + params = None for key in self.python_interface.inputs.keys(): val = kwargs.get(key) if isinstance(val, StructuredDataset): # register structured dataset - self._con.register(key, val.open(pa.Table).all()) + con.register(key, val.open(pa.Table).all()) elif isinstance(val, (pd.DataFrame, pa.Table)): # register pandas dataframe/arrow table - self._con.register(key, val) + con.register(key, val) elif isinstance(val, list): # copy val into params params = val @@ -105,7 +108,11 @@ def execute(self, **kwargs) -> StructuredDataset: for query in self._query[:-1]: query_output = next( self._execute_query( - params=params, query=query, counter=query_output.counter, multiple_params=multiple_params + con=con, + params=params, + query=query, + counter=query_output.counter, + multiple_params=multiple_params, ) ) final_query = self._query[-1] @@ -114,7 +121,7 @@ def execute(self, **kwargs) -> StructuredDataset: # expecting a SELECT query dataframe = next( self._execute_query( - params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params + con=con, params=params, query=final_query, counter=query_output.counter, multiple_params=multiple_params ) ).output.arrow()