From f2a937d2f25d1f997a066e6ba02acc3c4de676a4 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 13 Jun 2022 16:39:43 -0400 Subject: [PATCH] Support direct TLS connections (i.e. no STARTTLS) (#923) Adding direct_tls param that when equal to True alongside the ssl param being set to a ssl.SSLContext will result in a direct SSL connection being made, skipping STARTTLS implementation. Closes #906 --- asyncpg/connect_utils.py | 21 +++++++++++++++------ asyncpg/connection.py | 6 ++++++ tests/test_connect.py | 7 ++++++- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c09bf5e0..90a61503 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -53,6 +53,7 @@ def parse(cls, sslmode): 'database', 'ssl', 'sslmode', + 'direct_tls', 'connect_timeout', 'server_settings', ]) @@ -258,7 +259,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, - connect_timeout, server_settings): + direct_tls, connect_timeout, server_settings): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -601,8 +602,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, - sslmode=sslmode, connect_timeout=connect_timeout, - server_settings=server_settings) + sslmode=sslmode, direct_tls=direct_tls, + connect_timeout=connect_timeout, server_settings=server_settings) return addrs, params @@ -612,7 +613,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, - ssl, server_settings): + ssl, direct_tls, server_settings): local_vars = locals() for var_name in {'max_cacheable_statement_size', @@ -640,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, addrs, params = _parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, ssl=ssl, - database=database, connect_timeout=timeout, - server_settings=server_settings) + direct_tls=direct_tls, database=database, + connect_timeout=timeout, server_settings=server_settings) config = _ClientConfiguration( command_timeout=command_timeout, @@ -812,6 +813,14 @@ async def __connect_addr( if isinstance(addr, str): # UNIX socket connector = loop.create_unix_connection(proto_factory, addr) + + elif params.ssl and params.direct_tls: + # if ssl and direct_tls are given, skip STARTTLS and perform direct + # SSL connection + connector = loop.create_connection( + proto_factory, *addr, ssl=params.ssl + ) + elif params.ssl: connector = _create_ssl_connection( proto_factory, *addr, loop=loop, ssl_context=params.ssl, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 3914826a..3327360b 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1789,6 +1789,7 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, + direct_tls=False, connection_class=Connection, record_class=protocol.Record, server_settings=None): @@ -1984,6 +1985,10 @@ async def connect(dsn=None, *, ... await con.close() >>> asyncio.run(run()) + :param bool direct_tls: + Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct + SSL connection. Must be used alongside ``ssl`` param. + :param dict server_settings: An optional dict of server runtime parameters. Refer to PostgreSQL documentation for @@ -2094,6 +2099,7 @@ async def connect(dsn=None, *, password=password, passfile=passfile, ssl=ssl, + direct_tls=direct_tls, database=database, server_settings=server_settings, command_timeout=command_timeout, diff --git a/tests/test_connect.py b/tests/test_connect.py index d90ad8a4..db7817f6 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -811,7 +811,8 @@ def run_testcase(self, testcase): addrs, params = connect_utils._parse_connect_dsn_and_args( dsn=dsn, host=host, port=port, user=user, password=password, passfile=passfile, database=database, ssl=sslmode, - connect_timeout=None, server_settings=server_settings) + direct_tls=False, connect_timeout=None, + server_settings=server_settings) params = { k: v for k, v in params._asdict().items() @@ -829,6 +830,10 @@ def run_testcase(self, testcase): # unless explicitly tested for. params.pop('ssl', None) params.pop('sslmode', None) + if 'direct_tls' not in expected[1]: + # Avoid the hassle of specifying direct_tls + # unless explicitly tested for + params.pop('direct_tls', False) self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))