Skip to content

Commit

Permalink
Adding: Snowflake Role in snowflake provider hook (#16735)
Browse files Browse the repository at this point in the history
* Adding:
1. 'extra__snowflake__role' to get_connection_form_widgets() to enable snowflake role capture.
2. 'extra__snowflake__role' to get_ui_field_behaviour() to placeholders to return snowflake role to the UI.
3. Updated _get_conn_params() to capture snowflake role from 'extra__snowflake__role'.


Co-authored-by: saurasingh <[email protected]>
  • Loading branch information
saurasingh and saurasingh authored Jul 7, 2021
1 parent b0f7f91 commit 5999cb9
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 2 deletions.
5 changes: 3 additions & 2 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def get_connection_form_widgets() -> Dict[str, Any]:
"extra__snowflake__aws_secret_access_key": PasswordField(
lazy_gettext('AWS Secret Key'), widget=BS3PasswordFieldWidget()
),
"extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()),
}

@staticmethod
Expand All @@ -112,7 +113,6 @@ def get_ui_field_behaviour() -> Dict:
"placeholders": {
'extra': json.dumps(
{
"role": "snowflake role",
"authenticator": "snowflake oauth",
"private_key_file": "private key",
"session_parameters": "session parameters",
Expand All @@ -129,6 +129,7 @@ def get_ui_field_behaviour() -> Dict:
'extra__snowflake__region': 'snowflake hosted region',
'extra__snowflake__aws_access_key_id': 'aws access key id (S3ToSnowflakeOperator)',
'extra__snowflake__aws_secret_access_key': 'aws secret access key (S3ToSnowflakeOperator)',
'extra__snowflake__role': 'snowflake role',
},
}

Expand Down Expand Up @@ -160,7 +161,7 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
'database', ''
)
region = conn.extra_dejson.get('extra__snowflake__region', '') or conn.extra_dejson.get('region', '')
role = conn.extra_dejson.get('role', '')
role = conn.extra_dejson.get('extra__snowflake__role', '') or conn.extra_dejson.get('role', '')
schema = conn.schema or ''
authenticator = conn.extra_dejson.get('authenticator', 'snowflake')
session_parameters = conn.extra_dejson.get('session_parameters')
Expand Down
157 changes: 157 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,160 @@ def test_key_pair_auth_not_encrypted(self):
self.conn.password = None
params = self.db_hook._get_conn_params()
assert 'private_key' in params


"""
Testing hooks with assigning`extra_` parameters
"""


class TestSnowflakeHookExtra(unittest.TestCase):
def setUp(self):
super().setUp()

self.conn = conn = mock.MagicMock()

self.conn.login = 'user'
self.conn.password = 'pw'
self.conn.schema = 'public'
self.conn.extra_dejson = {
'extra__snowflake__database': 'db',
'extra__snowflake__account': 'airflow',
'extra__snowflake__warehouse': 'af_wh',
'extra__snowflake__region': 'af_region',
'extra__snowflake__role': 'af_role',
}

class UnitTestSnowflakeHookExtra(SnowflakeHook):
conn_name_attr = 'snowflake_conn_id'

def get_conn(self):
return conn

def get_connection(self, _):
return conn

self.db_hook_extra = UnitTestSnowflakeHookExtra(
session_parameters={"QUERY_TAG": "This is a test hook"}
)

self.non_encrypted_private_key = "/tmp/test_key.pem"
self.encrypted_private_key = "/tmp/test_key.p8"

# Write some temporary private keys. First is not encrypted, second is with a passphrase.
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()
)

with open(self.non_encrypted_private_key, "wb") as file:
file.write(private_key)

key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.BestAvailableEncryption(self.conn.password.encode()),
)

with open(self.encrypted_private_key, "wb") as file:
file.write(private_key)

def tearDownExtra(self):
os.remove(self.encrypted_private_key)
os.remove(self.non_encrypted_private_key)

def test_get_uri_extra(self):
uri_shouldbe = (
'snowflake://user:pw@airflow/db/public?warehouse=af_wh&role=af_role&authenticator=snowflake'
)
assert uri_shouldbe == self.db_hook_extra.get_uri()

@parameterized.expand(
[
('select * from table', ['uuid', 'uuid']),
('select * from table;select * from table2', ['uuid', 'uuid', 'uuid2', 'uuid2']),
(['select * from table;'], ['uuid', 'uuid']),
(['select * from table;', 'select * from table2;'], ['uuid', 'uuid', 'uuid2', 'uuid2']),
],
)
def test_run_storing_query_ids_extra(self, sql, query_ids):
cur = mock.MagicMock(rowcount=0)
self.conn.cursor.return_value = cur
type(cur).sfqid = mock.PropertyMock(side_effect=query_ids)
mock_params = {"mock_param": "mock_param"}
self.db_hook_extra.run(sql, parameters=mock_params)

sql_list = sql if isinstance(sql, list) else re.findall(".*?[;]", sql)
cur.execute.assert_has_calls([mock.call(query, mock_params) for query in sql_list])
assert self.db_hook_extra.query_ids == query_ids[::2]
cur.close.assert_called()

def test_get_conn_params_extra(self):
conn_params_shouldbe = {
'user': 'user',
'password': 'pw',
'schema': 'public',
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
'authenticator': 'snowflake',
'session_parameters': {"QUERY_TAG": "This is a test hook"},
"application": "AIRFLOW",
}
assert self.db_hook_extra.snowflake_conn_id == 'snowflake_default'
assert conn_params_shouldbe == self.db_hook_extra._get_conn_params()

def test_get_conn_params_env_variable_extra(self):
conn_params_shouldbe = {
'user': 'user',
'password': 'pw',
'schema': 'public',
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
'authenticator': 'snowflake',
'session_parameters': {"QUERY_TAG": "This is a test hook"},
"application": "AIRFLOW_TEST",
}
with patch_environ({"AIRFLOW_SNOWFLAKE_PARTNER": 'AIRFLOW_TEST'}):
assert self.db_hook_extra.snowflake_conn_id == 'snowflake_default'
assert conn_params_shouldbe == self.db_hook_extra._get_conn_params()

def test_get_conn_extra(self):
assert self.db_hook_extra.get_conn() == self.conn

def test_key_pair_auth_encrypted_extra(self):
self.conn.extra_dejson = {
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
'private_key_file': self.encrypted_private_key,
}

params = self.db_hook_extra._get_conn_params()
assert 'private_key' in params

def test_key_pair_auth_not_encrypted_extra(self):
self.conn.extra_dejson = {
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
'private_key_file': self.non_encrypted_private_key,
}

self.conn.password = ''
params = self.db_hook_extra._get_conn_params()
assert 'private_key' in params

self.conn.password = None
params = self.db_hook_extra._get_conn_params()
assert 'private_key' in params

0 comments on commit 5999cb9

Please sign in to comment.