-
Notifications
You must be signed in to change notification settings - Fork 104
/
Copy pathsql_server_connection_manager.py
158 lines (122 loc) · 5.09 KB
/
sql_server_connection_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from typing import Callable, Mapping
import pyodbc
from azure.core.credentials import AccessToken
from azure.identity import ClientSecretCredential, ManagedIdentityCredential
from dbt.adapters.fabric import FabricConnectionManager
from dbt.adapters.fabric.fabric_connection_manager import (
AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC,
)
from dbt.adapters.fabric.fabric_connection_manager import (
AZURE_CREDENTIAL_SCOPE,
bool_to_connection_string_arg,
get_pyodbc_attrs_before,
)
from dbt.contracts.connection import Connection, ConnectionState
from dbt.events import AdapterLogger
from dbt.adapters.sqlserver import __version__
from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials
# Signature shared by every token provider registered in AZURE_AUTH_FUNCTIONS:
# it receives the adapter credentials and returns an Azure AccessToken.
AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials], AccessToken]
# Adapter logger used by the token helpers and the connection manager below.
logger = AdapterLogger("SQLServer")
def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken:
    """
    Fetch an Azure access token through the host's managed identity.

    Parameters
    ----------
    credentials : SQLServerCredentials
        Adapter credentials. Not read here; the parameter exists so this
        function matches ``AZURE_AUTH_FUNCTION_TYPE``, the signature shared
        by the entries of ``AZURE_AUTH_FUNCTIONS``.

    Returns
    -------
    out : AccessToken
        Token issued for ``AZURE_CREDENTIAL_SCOPE``.
    """
    return ManagedIdentityCredential().get_token(AZURE_CREDENTIAL_SCOPE)
def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken:
    """
    Fetch an Azure access token with service-principal (client secret) auth.

    Parameters
    ----------
    credentials : SQLServerCredentials
        Adapter credentials; ``tenant_id``, ``client_id`` and
        ``client_secret`` are read and passed to ``ClientSecretCredential``
        as strings.

    Returns
    -------
    out : AccessToken
        Token issued for ``AZURE_CREDENTIAL_SCOPE``.
    """
    # NOTE(review): str() turns a missing (None) field into the literal
    # string "None", so a misconfigured profile surfaces as an Azure
    # authentication failure rather than a configuration error.
    tenant, client, secret = (
        str(field)
        for field in (
            credentials.tenant_id,
            credentials.client_id,
            credentials.client_secret,
        )
    )
    sp_credential = ClientSecretCredential(tenant, client, secret)
    return sp_credential.get_token(AZURE_CREDENTIAL_SCOPE)
# Token providers by authentication name: start from the Fabric adapter's
# table, then register (or override) the SQL Server specific providers.
AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = dict(
    AZURE_AUTH_FUNCTIONS_FABRIC,
    serviceprincipal=get_sp_access_token,
    msi=get_msi_access_token,
)
class SQLServerConnectionManager(FabricConnectionManager):
    """Connection manager for the ``sqlserver`` dbt adapter.

    Every authentication mode except SQL login (``authentication == "sql"``)
    is delegated to ``FabricConnectionManager.open``; SQL login is handled
    here by building an ODBC connection string with ``UID``/``PWD``.
    """

    TYPE = "sqlserver"

    @staticmethod
    def _braced(value: object) -> str:
        """Return ``value`` as a brace-quoted ODBC connection-string value.

        Inside a braced value a literal ``}`` must be doubled; otherwise it
        terminates the value early, breaking the connection string or letting
        the remainder be parsed as additional keywords.
        """
        return "{" + str(value).replace("}", "}}") + "}"

    @classmethod
    def open(cls, connection: Connection) -> Connection:
        """Open ``connection`` and return it.

        An already-open connection is returned untouched. Authentication modes
        other than ``"sql"`` are delegated to the parent class. For SQL login
        an ODBC connection string is built from the credentials and
        ``pyodbc.connect`` is invoked through ``cls.retry_connection`` with
        ``credentials.retries`` as the retry limit.

        Parameters
        ----------
        connection : Connection
            The dbt connection to open.

        Returns
        -------
        Connection
            The result of ``retry_connection`` (or of the parent's ``open``
            for non-SQL authentication).
        """
        if connection.state == ConnectionState.OPEN:
            logger.debug("Connection is already open, skipping open.")
            return connection

        credentials = cls.get_credentials(connection.credentials)
        if credentials.authentication != "sql":
            return super().open(connection)

        # sql login authentication
        assert credentials.authentication is not None
        # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15
        assert credentials.encrypt is not None
        assert credentials.trust_cert is not None

        if "\\" in credentials.host:
            # A backslash in the host name means the host is a SQL Server
            # named instance; in that case the port number has to be omitted.
            server = credentials.host
        else:
            server = f"{credentials.host},{credentials.port}"

        plugin_version = __version__.version
        application_name = f"dbt-{credentials.type}/{plugin_version}"

        # The password entry is kept out of these two lists so the logged
        # form can mask exactly that entry, instead of searching entries for
        # "pwd=" (which could match a UID or database name and leave the real
        # password visible in the log).
        before_pwd = [
            f"DRIVER={cls._braced(credentials.driver)}",
            f"SERVER={server}",
            f"Database={credentials.database}",
            f"UID={cls._braced(credentials.UID)}",
        ]
        after_pwd = [
            bool_to_connection_string_arg("encrypt", credentials.encrypt),
            bool_to_connection_string_arg(
                "TrustServerCertificate", credentials.trust_cert
            ),
            f"Application Name={application_name}",
        ]
        con_str_concat = ";".join(
            [*before_pwd, f"PWD={cls._braced(credentials.PWD)}", *after_pwd]
        )
        con_str_display = ";".join([*before_pwd, "PWD=***", *after_pwd])

        retryable_exceptions = [  # https://github.com/mkleehammer/pyodbc/wiki/Exceptions
            pyodbc.InternalError,  # not used according to docs, but defined in PEP-249
            pyodbc.OperationalError,
        ]
        # NOTE(review): authentication is always "sql" on this path, so this
        # branch only fires if "sql" is registered in AZURE_AUTH_FUNCTIONS.
        if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS:
            # Temporary login/token errors fall into this category when using AAD
            retryable_exceptions.append(pyodbc.InterfaceError)

        def connect():
            # Single connection attempt; retries are driven by retry_connection.
            logger.debug(f"Using connection string: {con_str_display}")
            attrs_before = get_pyodbc_attrs_before(credentials)
            handle = pyodbc.connect(
                con_str_concat,
                attrs_before=attrs_before,
                autocommit=True,
                timeout=credentials.login_timeout,
            )
            handle.timeout = credentials.query_timeout
            logger.debug(f"Connected to db: {credentials.database}")
            return handle

        return cls.retry_connection(
            connection,
            connect=connect,
            logger=logger,
            retry_limit=credentials.retries,
            retryable_exceptions=retryable_exceptions,
        )