Skip to content

Commit

Permalink
Modify RedshiftConnectionManager to extend from SQLConnectionManager, migrate from psycopg2 to redshift python connector (#251)
Browse files Browse the repository at this point in the history

* Change RedshiftConnectionManager to extend from SQLConnectionManager, define a _get_connect_method method to leverage Redshift python connector to retrieve the connect method

* Add/fix unit tests, create RedshiftConnectMethodFactory to vend connect_method

* Fix _connection_keys to mimic PostgresConnectionManager

* Remove unneeded functions for tmp_cluster_creds and env_var creds auth due to in-built support in Redshift Python Connector

* Resolve some TODOs

* Fix references to old exceptions, add changelog

* Fix errors with functional tests by overriding add_query & execute and modifying multi statement execution

* Attempt to fix integration tests by adding `valid_incremental_strategies` in impl.py

* Fix unit tests

* Attempt to fix integration tests

* add unit tests for execute

* add unit tests for add_query

* make get_connect_method work with serverless

* add unit tests for serverless iam connections

* add redshift connector version, remove sslmode, connection timeout, role, application_name

* change redshift_connector version

---------

Co-authored-by: jiezhec <[email protected]>
Co-authored-by: colin-rogers-dbt <[email protected]>
  • Loading branch information
3 people authored Feb 15, 2023
1 parent e0598b8 commit 2cc47bb
Show file tree
Hide file tree
Showing 5 changed files with 395 additions and 235 deletions.
8 changes: 8 additions & 0 deletions .changes/unreleased/Under the Hood-20230118-071542.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
kind: Under the Hood
body: Replace psycopg2 connector with Redshift python connector when connecting to
Redshift
time: 2023-01-18T07:15:42.183304-08:00
custom:
Author: sathiish-kumar
Issue: "219"
PR: "251"
272 changes: 191 additions & 81 deletions dbt/adapters/redshift/connections.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import re
from multiprocessing import Lock
from contextlib import contextmanager
from typing import NewType
from typing import NewType, Tuple

from dbt.adapters.postgres import PostgresConnectionManager
from dbt.adapters.postgres import PostgresCredentials
import agate
import sqlparse
from dbt.adapters.sql import SQLConnectionManager
from dbt.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.events import AdapterLogger
import dbt.exceptions
import dbt.flags

import boto3

import redshift_connector
from dbt.dataclass_schema import FieldEncoder, dbtClassMixin, StrEnum

from dataclasses import dataclass, field
from typing import Optional, List

from dbt.helper_types import Port
from redshift_connector import OperationalError, DatabaseError, DataError

logger = AdapterLogger("Redshift")

drop_lock: Lock = dbt.flags.MP_CONTEXT.Lock() # type: ignore
Expand All @@ -38,33 +42,154 @@ class RedshiftConnectionMethod(StrEnum):


@dataclass
class RedshiftCredentials(Credentials):
    """Profile credentials for connecting to Redshift via redshift_connector.

    Consumed by RedshiftConnectMethodFactory, which turns these fields into
    a zero-argument connect callable for the configured auth method.
    """

    host: str
    user: str
    port: Port
    # Missing 'method' is treated as database (username/password) auth for
    # backwards compatibility.
    method: str = RedshiftConnectionMethod.DATABASE  # type: ignore
    password: Optional[str] = None  # type: ignore
    cluster_id: Optional[str] = field(
        default=None,
        metadata={"description": "If using IAM auth, the name of the cluster"},
    )
    iam_profile: Optional[str] = None
    iam_duration_seconds: int = 900
    search_path: Optional[str] = None
    keepalives_idle: int = 4
    autocreate: bool = False
    db_groups: List[str] = field(default_factory=list)
    ra3_node: Optional[bool] = False
    connect_timeout: int = 30
    role: Optional[str] = None
    sslmode: Optional[str] = None
    retries: int = 1

    _ALIASES = {"dbname": "database", "pass": "password"}

    @property
    def type(self):
        return "redshift"

    def _connection_keys(self):
        # Keys surfaced by `dbt debug`; mirrors PostgresConnectionManager's
        # set plus the Redshift-specific method/cluster/IAM fields.
        return (
            "host",
            "port",
            "user",
            "database",
            "schema",
            "method",
            "cluster_id",
            "iam_profile",
        )

    @property
    def unique_field(self) -> str:
        # Used by dbt to (anonymously) identify the target server.
        return self.host

class RedshiftConnectMethodFactory:
    """Vends a zero-argument connect callable for the configured auth method."""

    credentials: RedshiftCredentials

    def __init__(self, credentials):
        self.credentials = credentials

    def get_connect_method(self):
        """Return a no-arg callable that opens a redshift_connector connection.

        Raises FailedToConnectError for invalid/incomplete credentials before
        any connection attempt is made.
        """
        method = self.credentials.method
        # Keyword arguments shared by every auth method.
        kwargs = {
            "host": self.credentials.host,
            "database": self.credentials.database,
            "port": self.credentials.port if self.credentials.port else 5439,
            "auto_create": self.credentials.autocreate,
            "db_groups": self.credentials.db_groups,
            # NOTE(review): assumes host has at least three dot-separated labels
            # (e.g. "<name>.<id>.<region>.redshift.amazonaws.com"); a bare or
            # custom hostname would raise IndexError here — confirm.
            "region": self.credentials.host.split(".")[2],
            "timeout": self.credentials.connect_timeout,
        }
        if self.credentials.sslmode:
            kwargs["sslmode"] = self.credentials.sslmode

        # Support missing 'method' for backwards compatibility
        if method == RedshiftConnectionMethod.DATABASE or method is None:
            # this requirement is really annoying to encode into json schema,
            # so validate it here
            if self.credentials.password is None:
                raise dbt.exceptions.FailedToConnectError(
                    "'password' field is required for 'database' credentials"
                )

            def connect():
                logger.debug("Connecting to redshift with username/password based auth...")
                c = redshift_connector.connect(
                    user=self.credentials.user,
                    password=self.credentials.password,
                    **kwargs,
                )
                if self.credentials.role:
                    c.cursor().execute("set role {}".format(self.credentials.role))
                return c

            return connect

        elif method == RedshiftConnectionMethod.IAM:
            # A provisioned cluster needs cluster_id; a serverless endpoint is
            # recognized by its hostname instead.
            if not self.credentials.cluster_id and "serverless" not in self.credentials.host:
                raise dbt.exceptions.FailedToConnectError(
                    "Failed to use IAM method. 'cluster_id' must be provided for provisioned cluster. "
                    "'host' must be provided for serverless endpoint."
                )

            def connect():
                logger.debug("Connecting to redshift with IAM based auth...")
                c = redshift_connector.connect(
                    iam=True,
                    db_user=self.credentials.user,
                    password="",
                    user="",
                    cluster_identifier=self.credentials.cluster_id,
                    profile=self.credentials.iam_profile,
                    **kwargs,
                )
                if self.credentials.role:
                    c.cursor().execute("set role {}".format(self.credentials.role))
                return c

            return connect
        else:
            raise dbt.exceptions.FailedToConnectError(
                "Invalid 'method' in profile: '{}'".format(method)
            )


class RedshiftConnectionManager(SQLConnectionManager):
TYPE = "redshift"

def _get_backend_pid(self):
    """Return the integer backend pid of the current session.

    Used by cancel() to build `pg_terminate_backend(<pid>)`.
    """
    sql = "select pg_backend_pid()"
    _, cursor = self.add_query(sql)
    res = cursor.fetchone()
    # fetchone() returns a row (sequence); unwrap it so callers interpolate a
    # bare integer — interpolating the row itself would produce invalid SQL
    # like "pg_terminate_backend((123,))".
    return res[0]

def cancel(self, connection: Connection):
    """Ask the server to terminate the backend running this connection's query.

    A driver InterfaceError caused by an already-closed connection is logged
    and swallowed; any other error propagates.
    """
    connection_name = connection.name
    try:
        pid = self._get_backend_pid()
        _, cursor = self.add_query("select pg_terminate_backend({})".format(pid))
        res = cursor.fetchone()
        logger.debug("Cancel query '{}': {}".format(connection_name, res))
    except redshift_connector.error.InterfaceError as e:
        if "is closed" not in str(e):
            raise
        logger.debug(f"Connection {connection_name} was already closed")

@classmethod
def get_response(cls, cursor: redshift_connector.Cursor) -> AdapterResponse:
    """Summarize a completed cursor as an AdapterResponse (row count only)."""
    return AdapterResponse(
        _message=f"cursor.rowcount = {cursor.rowcount}",
        rows_affected=cursor.rowcount,
    )

@contextmanager
def exception_handler(self, sql):
    """Context manager translating driver errors into dbt exceptions.

    Driver DatabaseErrors roll back any open transaction and re-raise as
    DbtDatabaseError; all other exceptions roll back and re-raise as
    DbtRuntimeError, except dbt-native exceptions which pass through.
    """
    try:
        yield
    except redshift_connector.error.DatabaseError as e:
        logger.debug(f"Redshift error: {str(e)}")
        # Leave the connection usable by rolling back the failed transaction.
        self.rollback_if_open()
        raise dbt.exceptions.DbtDatabaseError(str(e))
    except Exception as e:
        logger.debug("Error running SQL: {}", sql)
        logger.debug("Rolling back transaction.")
        self.rollback_if_open()
        # Raise DBT native exceptions as is.
        # NOTE(review): `dbt.exceptions.Exception` looks suspicious — confirm
        # this attribute exists on the module; a dbt base exception class
        # (e.g. DbtRuntimeError) may have been intended here.
        if isinstance(e, dbt.exceptions.Exception):
            raise
        raise dbt.exceptions.DbtRuntimeError(str(e)) from e

@contextmanager
def fresh_transaction(self, name=None):
"""On entrance to this context manager, hold an exclusive lock and
Expand All @@ -89,83 +214,68 @@ def fresh_transaction(self, name=None):
self.begin()

@classmethod
def open(cls, connection):
    """Open `connection`, retrying transient driver failures with backoff.

    Delegates the actual connect callable to RedshiftConnectMethodFactory and
    lets SQLConnectionManager.retry_connection drive the retry loop.
    """
    if connection.state == "open":
        logger.debug("Connection is already open, skipping open.")
        return connection

    credentials = connection.credentials
    connect_method_factory = RedshiftConnectMethodFactory(credentials)

    def exponential_backoff(attempt: int):
        # Quadratic wait between attempts: 1s, 4s, 9s, ...
        return attempt * attempt

    # Driver errors considered transient enough to retry.
    retryable_exceptions = [OperationalError, DatabaseError, DataError]

    return cls.retry_connection(
        connection,
        connect=connect_method_factory.get_connect_method(),
        logger=logger,
        retry_limit=credentials.retries,
        retry_timeout=exponential_backoff,
        retryable_exceptions=retryable_exceptions,
    )

def execute(
    self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[AdapterResponse, agate.Table]:
    """Run `sql` and return (AdapterResponse, result table).

    When `fetch` is False the cursor is not read; an empty agate table is
    returned instead.
    """
    _, cursor = self.add_query(sql, auto_begin)
    response = self.get_response(cursor)
    if fetch:
        table = self.get_result_from_cursor(cursor)
    else:
        table = dbt.clients.agate_helper.empty_table()
    return response, table

try:
return boto_client.get_cluster_credentials(
DbUser=db_user,
DbName=db_name,
ClusterIdentifier=cluster_id,
DurationSeconds=duration_s,
AutoCreate=autocreate,
DbGroups=db_groups,
)
def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
    """Split `sql` into individual statements and run each one.

    redshift_connector executes one statement per call, so multi-statement
    SQL is split with sqlparse; comment-only statements are skipped. Returns
    the (connection, cursor) of the last statement executed.

    Raises DbtRuntimeError if `sql` contained no executable statements.
    """
    connection = None
    cursor = None

    queries = sqlparse.split(sql)

    for query in queries:
        # Strip string literals first so that '--' or '/*' inside a quoted
        # value is not mistaken for a comment, then drop real comments.
        without_comments = re.sub(
            re.compile(r"(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)", re.MULTILINE),
            "",
            query,
        ).strip()

        if without_comments == "":
            continue

        connection, cursor = super().add_query(
            query, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log
        )

    if cursor is None:
        conn = self.get_thread_connection()
        conn_name = conn.name if conn and conn.name else "<None>"
        raise dbt.exceptions.DbtRuntimeError(f"Tried to run invalid SQL: {sql} on {conn_name}")

    return connection, cursor

@classmethod
def get_credentials(cls, credentials):
method = credentials.method

# Support missing 'method' for backwards compatibility
if method == "database" or method is None:
logger.debug("Connecting to Redshift using 'database' credentials")
# this requirement is really annoying to encode into json schema,
# so validate it here
if credentials.password is None:
raise dbt.exceptions.FailedToConnectError(
"'password' field is required for 'database' credentials"
)
return credentials

elif method == "iam":
logger.debug("Connecting to Redshift using 'IAM' credentials")
return cls.get_tmp_iam_cluster_credentials(credentials)

else:
raise dbt.exceptions.FailedToConnectError(
"Invalid 'method' in profile: '{}'".format(method)
)
return credentials
12 changes: 10 additions & 2 deletions dbt/adapters/redshift/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dbt.adapters.base.impl import AdapterConfig
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.base.meta import available
from dbt.adapters.postgres import PostgresAdapter
from dbt.adapters.redshift import RedshiftConnectionManager
from dbt.adapters.redshift.column import RedshiftColumn
from dbt.adapters.redshift import RedshiftRelation
Expand All @@ -22,7 +21,7 @@ class RedshiftConfig(AdapterConfig):
backup: Optional[bool] = True


class RedshiftAdapter(PostgresAdapter, SQLAdapter):
class RedshiftAdapter(SQLAdapter):
Relation = RedshiftRelation
ConnectionManager = RedshiftConnectionManager
Column = RedshiftColumn # type: ignore
Expand Down Expand Up @@ -91,3 +90,12 @@ def _get_catalog_schemas(self, manifest):
self.type(), exc.msg
)
)

def valid_incremental_strategies(self):
    """Builtin incremental strategies this adapter supports out of the box.

    Custom strategies defined by end users are not validated against this.
    """
    return ["append", "delete+insert"]

def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
    """Render SQL that adds `number` `interval` units to the expression `add_to`."""
    return "{} + interval '{} {}'".format(add_to, number, interval)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _core_version(plugin_version: str = _plugin_version()) -> str:
f"dbt-core~={_core_version()}",
f"dbt-postgres~={_core_version()}",
"boto3~=1.26.26",
"redshift-connector~=2.0.910",
],
zip_safe=False,
classifiers=[
Expand Down
Loading

0 comments on commit 2cc47bb

Please sign in to comment.