Skip to content

Commit

Permalink
Support direct TLS connections (i.e. no STARTTLS) (#923)
Browse files Browse the repository at this point in the history
Adding direct_tls param that when equal to True alongside the ssl param being set to a ssl.SSLContext will result in a direct SSL connection being made, skipping STARTTLS implementation.

Closes #906
  • Loading branch information
jackwotherspoon authored Jun 13, 2022
1 parent bd19262 commit f2a937d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
21 changes: 15 additions & 6 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def parse(cls, sslmode):
'database',
'ssl',
'sslmode',
'direct_tls',
'connect_timeout',
'server_settings',
])
Expand Down Expand Up @@ -258,7 +259,7 @@ def _dot_postgresql_path(filename) -> pathlib.Path:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
connect_timeout, server_settings):
direct_tls, connect_timeout, server_settings):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -601,8 +602,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, connect_timeout=connect_timeout,
server_settings=server_settings)
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings)

return addrs, params

Expand All @@ -612,7 +613,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, server_settings):
ssl, direct_tls, server_settings):

local_vars = locals()
for var_name in {'max_cacheable_statement_size',
Expand Down Expand Up @@ -640,8 +641,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
addrs, params = _parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
database=database, connect_timeout=timeout,
server_settings=server_settings)
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down Expand Up @@ -812,6 +813,14 @@ async def __connect_addr(
if isinstance(addr, str):
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)

elif params.ssl and params.direct_tls:
# if ssl and direct_tls are given, skip STARTTLS and perform direct
# SSL connection
connector = loop.create_connection(
proto_factory, *addr, ssl=params.ssl
)

elif params.ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
Expand Down
6 changes: 6 additions & 0 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
direct_tls=False,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None):
Expand Down Expand Up @@ -1984,6 +1985,10 @@ async def connect(dsn=None, *,
... await con.close()
>>> asyncio.run(run())
:param bool direct_tls:
Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct
SSL connection. Must be used alongside ``ssl`` param.
:param dict server_settings:
An optional dict of server runtime parameters. Refer to
PostgreSQL documentation for
Expand Down Expand Up @@ -2094,6 +2099,7 @@ async def connect(dsn=None, *,
password=password,
passfile=passfile,
ssl=ssl,
direct_tls=direct_tls,
database=database,
server_settings=server_settings,
command_timeout=command_timeout,
Expand Down
7 changes: 6 additions & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,8 @@ def run_testcase(self, testcase):
addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=sslmode,
connect_timeout=None, server_settings=server_settings)
direct_tls=False, connect_timeout=None,
server_settings=server_settings)

params = {
k: v for k, v in params._asdict().items()
Expand All @@ -829,6 +830,10 @@ def run_testcase(self, testcase):
# unless explicitly tested for.
params.pop('ssl', None)
params.pop('sslmode', None)
if 'direct_tls' not in expected[1]:
# Avoid the hassle of specifying direct_tls
# unless explicitly tested for
params.pop('direct_tls', False)

self.assertEqual(expected, result, 'Testcase: {}'.format(testcase))

Expand Down

0 comments on commit f2a937d

Please sign in to comment.