From 7e6edff4b2b4df5f540d3ee70d41cf2187b30fdb Mon Sep 17 00:00:00 2001 From: Ashhar Hasan Date: Fri, 30 Sep 2022 23:00:02 +0530 Subject: [PATCH] Fix prepared statement handling The prepared statement handling code assumed that for each query we'll always receive some non-empty response even after the initial response, which is not a valid assumption. This assumption worked because Trino previously sent empty fake results even for queries which don't return results (like PREPARE and DEALLOCATE), but it is no longer valid as of trinodb/trino@bc794cd49a616fa95eecfcc384b761f67176240b. The other problem with the code was that it leaked HTTP protocol details into dbapi.py and worked around it by keeping a deep copy of the request object from the PREPARE execution and re-using it for the actual query execution. The new code fixes both issues by processing the prepared statement headers as they are received and storing the resulting set of active prepared statements on the ClientSession object. The ClientSession's set of prepared statements is then rendered into the prepared statement request header in TrinoRequest. Since the ClientSession is created and reused for the entire Connection, this also means that we can now actually implement re-use of prepared statements within a single Connection. 
--- tests/integration/test_dbapi_integration.py | 35 +++++++++ tests/unit/test_client.py | 15 ---- trino/client.py | 49 +++++++++--- trino/dbapi.py | 85 +++++---------------- trino/exceptions.py | 16 ---- 5 files changed, 90 insertions(+), 110 deletions(-) diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 99d3ae83..06a869dd 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -1067,3 +1067,38 @@ def test_set_role_trino_351(run_trino): cur.execute("SET ROLE ALL") cur.fetchall() assert cur._request._client_session.role == "tpch=ALL" + + +def test_prepared_statements(run_trino): + _, host, port = run_trino + + trino_connection = trino.dbapi.Connection( + host=host, port=port, user="test", catalog="tpch", + ) + cur = trino_connection.cursor() + + # Implicit prepared statements must work and deallocate statements on finish + assert cur._request._client_session.prepared_statements == {} + cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,)) + result = cur.fetchall() + assert result[0][0] == 1 + assert cur._request._client_session.prepared_statements == {} + + # Explicit prepared statements must also work + cur.execute('PREPARE test_prepared_statements FROM SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?') + cur.fetchall() + assert 'test_prepared_statements' in cur._request._client_session.prepared_statements + cur.execute('EXECUTE test_prepared_statements USING 1') + cur.fetchall() + assert result[0][0] == 1 + + # An implicit prepared statement must not deallocate explicit statements + cur.execute('SELECT count(1) FROM tpch.tiny.nation WHERE nationkey = ?', (1,)) + result = cur.fetchall() + assert result[0][0] == 1 + assert 'test_prepared_statements' in cur._request._client_session.prepared_statements + + assert 'test_prepared_statements' in cur._request._client_session.prepared_statements + cur.execute('DEALLOCATE 
PREPARE test_prepared_statements') + cur.fetchall() + assert cur._request._client_session.prepared_statements == {} diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d7da4983..4100c39e 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -881,21 +881,6 @@ def __call__(self, *args, **kwargs): return http_response -def test_trino_result_response_headers(): - """ - Validates that the `TrinoResult.response_headers` property returns the - headers associated to the TrinoQuery instance provided to the `TrinoResult` - class. - """ - mock_trino_query = mock.Mock(respone_headers={ - 'X-Trino-Fake-1': 'one', - 'X-Trino-Fake-2': 'two', - }) - - result = TrinoResult(query=mock_trino_query, rows=[]) - assert result.response_headers == mock_trino_query.response_headers - - def test_trino_query_response_headers(sample_get_response_data): """ Validates that the `TrinoQuery.execute` function can take addtional headers diff --git a/trino/client.py b/trino/client.py index b973ff6b..9b319cf2 100644 --- a/trino/client.py +++ b/trino/client.py @@ -125,6 +125,7 @@ def __init__( self._extra_credential = extra_credential self._client_tags = client_tags self._role = role + self._prepared_statements: Dict[str, str] = {} self._object_lock = threading.Lock() @property @@ -206,6 +207,15 @@ def role(self, role): with self._object_lock: self._role = role + @property + def prepared_statements(self): + return self._prepared_statements + + @prepared_statements.setter + def prepared_statements(self, prepared_statements): + with self._object_lock: + self._prepared_statements = prepared_statements + def get_header_values(headers, header): return [val.strip() for val in headers[header].split(",")] @@ -219,6 +229,14 @@ def get_session_property_values(headers, header): ] +def get_prepared_statement_values(headers, header): + kvs = get_header_values(headers, header) + return [ + (k.strip(), urllib.parse.unquote_plus(v.strip())) + for k, v in (kv.split("=", 1) for 
kv in kvs) + ] + + class TrinoStatus(object): def __init__(self, id, stats, warnings, info_uri, next_uri, update_type, rows, columns=None): self.id = id @@ -392,6 +410,13 @@ def http_headers(self) -> Dict[str, str]: for name, value in self._client_session.properties.items() ) + if len(self._client_session.prepared_statements) != 0: + # ``name`` must not contain ``=`` + headers[constants.HEADER_PREPARED_STATEMENT] = ",".join( + "{}={}".format(name, urllib.parse.quote_plus(statement)) + for name, statement in self._client_session.prepared_statements.items() + ) + # merge custom http headers for key in self._client_session.headers: if key in headers.keys(): @@ -556,6 +581,18 @@ def process(self, http_response) -> TrinoStatus: if constants.HEADER_SET_ROLE in http_response.headers: self._client_session.role = http_response.headers[constants.HEADER_SET_ROLE] + if constants.HEADER_ADDED_PREPARE in http_response.headers: + for name, statement in get_prepared_statement_values( + http_response.headers, constants.HEADER_ADDED_PREPARE + ): + self._client_session.prepared_statements[name] = statement + + if constants.HEADER_DEALLOCATED_PREPARE in http_response.headers: + for name in get_header_values( + http_response.headers, constants.HEADER_DEALLOCATED_PREPARE + ): + self._client_session.prepared_statements.pop(name) + self._next_uri = response.get("nextUri") return TrinoStatus( @@ -622,10 +659,6 @@ def __iter__(self): self._rows = next_rows - @property - def response_headers(self): - return self._query.response_headers - class TrinoQuery(object): """Represent the execution of a SQL statement by Trino.""" @@ -648,7 +681,6 @@ def __init__( self._update_type = None self._sql = sql self._result: Optional[TrinoResult] = None - self._response_headers = None self._experimental_python_types = experimental_python_types self._row_mapper: Optional[RowMapper] = None @@ -705,7 +737,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = 
self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) - # Execute should block until at least one row is received + # Execute should block until at least one row is received or query is finished or cancelled while not self.finished and not self.cancelled and len(self._result.rows) == 0: self._result.rows += self.fetch() return self._result @@ -725,7 +757,6 @@ def fetch(self) -> List[List[Any]]: status = self._request.process(response) self._update_state(status) logger.debug(status) - self._response_headers = response.headers if status.next_uri is None: self._finished = True @@ -763,10 +794,6 @@ def finished(self) -> bool: def cancelled(self) -> bool: return self._cancelled - @property - def response_headers(self): - return self._response_headers - def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts): def wrapper(func): diff --git a/trino/dbapi.py b/trino/dbapi.py index 70fb43bb..8adc4d18 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -20,7 +20,6 @@ from decimal import Decimal from typing import Any, List, Optional # NOQA for mypy types -import copy import uuid import datetime import math @@ -301,52 +300,25 @@ def setinputsizes(self, sizes): def setoutputsize(self, size, column): raise trino.exceptions.NotSupportedError - def _prepare_statement(self, operation, statement_name): + def _prepare_statement(self, statement: str, name: str) -> None: """ - Prepends the given `operation` with "PREPARE FROM" and - executes as a prepare statement. + Registers a prepared statement for the provided `operation` with the + `name` assigned to it. - :param operation: sql to be executed. - :param statement_name: name that will be assigned to the prepare - statement. - - :raises trino.exceptions.FailedToObtainAddedPrepareHeader: Error raised - when unable to find the 'X-Trino-Added-Prepare' for the PREPARE - statement request. 
- - :return: string representing the value of the 'X-Trino-Added-Prepare' - header. + :param statement: sql to be executed. + :param name: name that will be assigned to the prepared statement. """ - sql = 'PREPARE {statement_name} FROM {operation}'.format( - statement_name=statement_name, - operation=operation - ) - - # Send prepare statement. Copy the _request object to avoid polluting the - # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, + sql = f"PREPARE {name} FROM {statement}" + query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql, experimental_python_types=self._experimental_pyton_types) - result = query.execute() + query.execute() - # Iterate until the 'X-Trino-Added-Prepare' header is found or - # until there are no more results - for _ in result: - response_headers = result.response_headers - - if constants.HEADER_ADDED_PREPARE in response_headers: - return response_headers[constants.HEADER_ADDED_PREPARE] - - raise trino.exceptions.FailedToObtainAddedPrepareHeader - - def _get_added_prepare_statement_trino_query( + def _execute_prepared_statement( self, statement_name, params ): sql = 'EXECUTE ' + statement_name + ' USING ' + ','.join(map(self._format_prepared_param, params)) - - # No need to deepcopy _request here because this is the actual request - # operation return trino.client.TrinoQuery(self._request, sql=sql, experimental_python_types=self._experimental_pyton_types) def _format_prepared_param(self, param): @@ -422,28 +394,11 @@ def _format_prepared_param(self, param): raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param)) - def _deallocate_prepare_statement(self, added_prepare_header, statement_name): + def _deallocate_prepared_statement(self, statement_name: str) -> None: sql = 'DEALLOCATE PREPARE ' + statement_name - - # Send deallocate statement. 
Copy the _request object to avoid poluting the - # one that is going to be used to execute the actual operation. - query = trino.client.TrinoQuery(copy.deepcopy(self._request), sql=sql, + query = trino.client.TrinoQuery(self.connection._create_request(), sql=sql, experimental_python_types=self._experimental_pyton_types) - result = query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } - ) - - # Iterate until the 'X-Trino-Deallocated-Prepare' header is found or - # until there are no more results - for _ in result: - response_headers = result.response_headers - - if constants.HEADER_DEALLOCATED_PREPARE in response_headers: - return response_headers[constants.HEADER_DEALLOCATED_PREPARE] - - raise trino.exceptions.FailedToObtainDeallocatedPrepareHeader + query.execute() def _generate_unique_statement_name(self): return 'st_' + uuid.uuid4().hex.replace('-', '') @@ -456,27 +411,21 @@ def execute(self, operation, params=None): ) statement_name = self._generate_unique_statement_name() - # Send prepare statement - added_prepare_header = self._prepare_statement( - operation, statement_name - ) + self._prepare_statement(operation, statement_name) try: # Send execute statement and assign the return value to `results` # as it will be returned by the function - self._query = self._get_added_prepare_statement_trino_query( + self._query = self._execute_prepared_statement( statement_name, params ) - result = self._query.execute( - additional_http_headers={ - constants.HEADER_PREPARED_STATEMENT: added_prepare_header - } - ) + result = self._query.execute() finally: # Send deallocate statement # At this point the query can be deallocated since it has already # been executed - self._deallocate_prepare_statement(added_prepare_header, statement_name) + # TODO: Consider caching prepared statements if requested by caller + self._deallocate_prepared_statement(statement_name) else: self._query = trino.client.TrinoQuery(self._request, 
sql=operation, diff --git a/trino/exceptions.py b/trino/exceptions.py index 86708fd0..bfd4fef4 100644 --- a/trino/exceptions.py +++ b/trino/exceptions.py @@ -134,22 +134,6 @@ class TrinoUserError(TrinoQueryError, ProgrammingError): pass -class FailedToObtainAddedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Added-Prepare' - header in the response of a PREPARE statement request. - """ - pass - - -class FailedToObtainDeallocatedPrepareHeader(Error): - """ - Raise this exception when unable to find the 'X-Trino-Deallocated-Prepare' - header in the response of a DEALLOCATED statement request. - """ - pass - - # client module errors class HttpError(Exception): pass