From b2ff5ffa2a18d0485c13a799af406d8da352819c Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Fri, 30 Sep 2022 23:00:02 +0530 Subject: [PATCH] Fix prepared statement handling The prepared statement handling code assumed that for each query we'll always receive some non-empty response even after the initial response which is not a valid assumption. This assumption worked because earlier Trino used to send empty fake results even for queries which don't return results (like PREPARE and DEALLOCATE) but is now invalid with trinodb/trino@bc794cd49a616fa95eecfcc384b761f67176240b. The other problem with the code was that it leaked HTTP protocol details into dbapi.py and worked around it by keeping a deep copy of the request object from the PREPARE execution and re-using it for the actual query execution. The new code fixes both issues by processing the prepared statement headers as they are received and storing the resulting set of active prepared statements on the ClientSession object. The ClientSession's set of prepared statements is then rendered into the prepared statement request header in TrinoRequest. Since the ClientSession is created and reused for the entire Connection this also means that we can now actually implement re-use of prepared statements within a single Connection. --- tests/unit/test_client.py | 3 +++ trino/client.py | 45 ++++++++++++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d7da4983..cf58164d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -881,6 +881,9 @@ def __call__(self, *args, **kwargs): return http_response +# TODO: What was this test added to verify? +# test_trino_query_response_headers already verifies that custom headers can be passed. +# Possibly we can remove this. 
def test_trino_result_response_headers(): """ Validates that the `TrinoResult.response_headers` property returns the diff --git a/trino/client.py b/trino/client.py index b973ff6b..58d17c4e 100644 --- a/trino/client.py +++ b/trino/client.py @@ -125,6 +125,7 @@ def __init__( self._extra_credential = extra_credential self._client_tags = client_tags self._role = role + self._prepared_statements = {} self._object_lock = threading.Lock() @property @@ -206,6 +207,15 @@ def role(self, role): with self._object_lock: self._role = role + @property + def prepared_statements(self): + return self._prepared_statements + + @prepared_statements.setter + def prepared_statements(self, prepared_statements): + with self._object_lock: + self._prepared_statements = prepared_statements + def get_header_values(headers, header): return [val.strip() for val in headers[header].split(",")] @@ -218,6 +228,12 @@ def get_session_property_values(headers, header): for k, v in (kv.split("=", 1) for kv in kvs) ] +def get_prepared_statement_values(headers, header): + kvs = get_header_values(headers, header) + return [ + (k.strip(), urllib.parse.unquote(v.strip())) + for k, v in (kv.split("=", 1) for kv in kvs) + ] class TrinoStatus(object): def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): @@ -392,6 +408,13 @@ def http_headers(self) -> Dict[str, str]: for name, value in self._client_session.properties.items() ) + if len(self._client_session.prepared_statements) != 0: + # ``name`` must not contain ``=`` + headers[constants.HEADER_PREPARED_STATEMENT] = ",".join( + "{}={}".format(name, urllib.parse.quote(statement)) + for name, statement in self._client_session.prepared_statements.items() + ) + # merge custom http headers for key in self._client_session.headers: if key in headers.keys(): @@ -556,6 +579,18 @@ def process(self, http_response) -> TrinoStatus: if constants.HEADER_SET_ROLE in http_response.headers: self._client_session.role = 
http_response.headers[constants.HEADER_SET_ROLE] + if constants.HEADER_ADDED_PREPARE in http_response.headers: + for name, statement in get_prepared_statement_values( + http_response.headers, constants.HEADER_ADDED_PREPARE + ): + self._client_session.prepared_statements[name] = statement + + if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers: + for name in get_header_values( + http_response.headers, constants.HEADER_DEALLOCATED_PREPARE + ): + self._client_session.prepared_statements.pop(name) + self._next_uri = response.get("nextUri") return TrinoStatus( @@ -622,10 +657,6 @@ def __iter__(self): self._rows = next_rows - @property - def response_headers(self): - return self._query.response_headers - class TrinoQuery(object): """Represent the execution of a SQL statement by Trino.""" @@ -648,7 +679,6 @@ def __init__( self._update_type = None self._sql = sql self._result: Optional[TrinoResult] = None - self._response_headers = None self._experimental_python_types = experimental_python_types self._row_mapper: Optional[RowMapper] = None @@ -725,7 +755,6 @@ def fetch(self) -> List[List[Any]]: status = self._request.process(response) self._update_state(status) logger.debug(status) - self._response_headers = response.headers if status.next_uri is None: self._finished = True @@ -763,10 +792,6 @@ def finished(self) -> bool: def cancelled(self) -> bool: return self._cancelled - @property - def response_headers(self): - return self._response_headers - def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): def wrapper(func):