Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snowflake agent #1799

Merged
merged 17 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def convert_to_flyte_state(state: str) -> State:
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
elif state in ["done", "succeeded", "success"]:
return SUCCEEDED
elif state in ["running"]:
return RUNNING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
SnowflakeTask
"""

from .task import SnowflakeConfig, SnowflakeTask
from .agent import SnowflakeAgent
from .task import SnowflakeConfig, SnowflakeTask
157 changes: 157 additions & 0 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional

import grpc
import snowflake.connector
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)
from snowflake.connector import ProgrammingError

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType

TASK_TYPE = "snowflake"


@dataclass
class Metadata:
user: str
account: str
database: str
schema: str
warehouse: str
table: str
query_id: str


class SnowflakeAgent(AgentBase):
def __init__(self):
super().__init__(task_type=TASK_TYPE)

def get_private_key(self):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

import flytekit

pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "rsa_key.p8")
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved

with open(pk_path, "rb") as key:
p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend())

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

return pkb

def get_connection(self, metadata: Metadata) -> snowflake.connector:
return snowflake.connector.connect(
user=metadata.user,
account=metadata.account,
private_key=self.get_private_key(),
database=metadata.database,
schema=metadata.schema,
warehouse=metadata.warehouse,
)

async def async_create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: Optional[LiteralMap] = None,
) -> CreateTaskResponse:
params = None
if inputs:
ctx = FlyteContextManager.current_context()
python_interface_inputs = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
}
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
logger.info(f"Create Snowflake params with inputs: {native_inputs}")
params = native_inputs

custom = task_template.custom

conn = snowflake.connector.connect(
user=custom["user"],
account=custom["account"],
private_key=self.get_private_key(),
database=custom["database"],
schema=custom["schema"],
warehouse=custom["warehouse"],
)

cs = conn.cursor()
cs.execute_async(task_template.sql.statement, params=params)

metadata = Metadata(
user=custom["user"],
account=custom["account"],
database=custom["database"],
schema=custom["schema"],
warehouse=custom["warehouse"],
table=custom["table"],
query_id=str(cs.sfqid),
)

return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8"))

async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
conn = self.get_connection(metadata)
try:
query_status = conn.get_query_status_throw_if_error(metadata.query_id)
except ProgrammingError as err:
logger.error(err.msg)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(err.msg)
return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
cur_state = convert_to_flyte_state(str(query_status.name))
res = None

if cur_state == SUCCEEDED:
ctx = FlyteContextManager.current_context()
output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.database}/{metadata.schema}/{metadata.warehouse}/{metadata.table}"
res = literals.LiteralMap(
{
"results": TypeEngine.to_literal(
ctx,
StructuredDataset(uri=output_metadata),
StructuredDataset,
LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
)
}
).to_flyte_idl()

return GetTaskResponse(resource=Resource(state=cur_state, outputs=res))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
conn = self.get_connection(metadata)
cs = conn.cursor()
try:
cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
cs.fetchall()
finally:
cs.close()
conn.close()
return DeleteTaskResponse()


AgentRegistry.register(SnowflakeAgent())
19 changes: 14 additions & 5 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

from flytekit.configuration import SerializationSettings
from flytekit.extend import SQLTask
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.models import task as _task_model
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

_USER_FIELD = "user"
_ACCOUNT_FIELD = "account"
_DATABASE_FIELD = "database"
_SCHEMA_FIELD = "schema"
_WAREHOUSE_FIELD = "warehouse"
_TABLE_FIELD = "table"


@dataclass
Expand All @@ -18,17 +21,21 @@ class SnowflakeConfig(object):
SnowflakeConfig should be used to configure a Snowflake Task.
"""

# The account to query against
# The user to query against
user: Optional[str] = None
# The account to query againstk
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved
account: Optional[str] = None
# The database to query against
database: Optional[str] = None
# The optional schema to separate query execution.
schema: Optional[str] = None
# The optional warehouse to set for the given Snowflake query
warehouse: Optional[str] = None
# The optional table to set for the given Snowflake query
table: Optional[str] = None


class SnowflakeTask(SQLTask[SnowflakeConfig]):
class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]):
"""
This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output.
"""
Expand All @@ -42,7 +49,7 @@ def __init__(
query_template: str,
task_config: Optional[SnowflakeConfig] = None,
inputs: Optional[Dict[str, Type]] = None,
output_schema_type: Optional[Type[FlyteSchema]] = None,
output_schema_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -74,12 +81,14 @@ def __init__(
)
self._output_schema_type = output_schema_type

def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
def get_custom(self, settings: SerializationSettings) -> Dict[str, str]:
hhcs9527 marked this conversation as resolved.
Show resolved Hide resolved
return {
_USER_FIELD: self.task_config.user,
_ACCOUNT_FIELD: self.task_config.account,
_DATABASE_FIELD: self.task_config.database,
_SCHEMA_FIELD: self.task_config.schema,
_WAREHOUSE_FIELD: self.task_config.warehouse,
_TABLE_FIELD: self.task_config.table,
}

def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-snowflake/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"]

__version__ = "0.0.0+develop"

Expand Down
121 changes: 121 additions & 0 deletions plugins/flytekit-snowflake/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json
import re
from dataclasses import asdict
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest
from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse
from flytekitplugins.snowflake.agent import Metadata
from flytekitplugins.snowflake.task import SnowflakeConfig

import flytekit.models.interface as interface_models
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import Sql, TaskTemplate


@mock.patch("snowflake.connector.connect")
@pytest.mark.asyncio
async def test_snowflake_agent(mock_conn):
query_status_mock = MagicMock()
query_status_mock.name = "SUCCEEDED"

# Configure the mock connection to return the mock status object
mock_conn_instance = mock_conn.return_value
mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock

ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "snowflake")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
)

task_metadata = task.TaskMetadata(
True,
task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
timedelta(days=1),
literals.RetryStrategy(3),
True,
"0.1.1b0",
"This is deprecated!",
True,
"A",
)

task_config = SnowflakeConfig(
user="dummy_user",
account="dummy_account",
database="dummy_database",
schema="dummy_schema",
warehouse="dummy_warehouse",
table="dummy_table",
)

task_config = {
"user" : "dummy_user",
"account" : "dummy_account",
"database" : "dummy_database",
"schema" : "dummy_schema",
"warehouse" : "dummy_warehouse",
"table" : "dummy_table",
}

int_type = types.LiteralType(types.SimpleType.INTEGER)
interfaces = interface_models.TypedInterface(
{
"a": interface_models.Variable(int_type, "description1"),
"b": interface_models.Variable(int_type, "description2"),
},
{},
)
task_inputs = literals.LiteralMap(
{
"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
},
)

dummy_template = TaskTemplate(
id=task_id,
custom=task_config,
metadata=task_metadata,
interface=interfaces,
type="snowflake",
sql=Sql("SELECT 1"),
)

metadata = Metadata(user="dummy_user",account="dummy_account",table="dummy_table",database="dummy_database",schema="dummy_schema",warehouse="dummy_warehouse",query_id="dummy_query_id")

res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs)
metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id
metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8")
assert res.resource_meta == metadata_bytes

res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
assert (
res.resource.outputs.literals["results"].scalar.structured_dataset.uri
== "snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table"
)

delete_response = await agent.async_delete(ctx, metadata_bytes)

# Assert the response
assert isinstance(delete_response, DeleteTaskResponse)

# Verify that the expected methods were called on the mock cursor
mock_cursor = mock_conn_instance.cursor.return_value
mock_cursor.fetchall.assert_called_once()

mock_cursor.execute.assert_called_once_with(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
mock_cursor.fetchall.assert_called_once()

# Verify that the connection was closed
mock_cursor.close.assert_called_once()
mock_conn_instance.close.assert_called_once()
8 changes: 4 additions & 4 deletions plugins/flytekit-snowflake/tests/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def my_wf(ds: str) -> FlyteSchema:
assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
assert "insert overwrite directory" in task_spec.template.sql.statement
assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
assert "snowflake" == task_spec.template.config["account"]
assert "my_warehouse" == task_spec.template.config["warehouse"]
assert "my_schema" == task_spec.template.config["schema"]
assert "my_database" == task_spec.template.config["database"]
assert "snowflake" == task_spec.template.custom["account"]
assert "my_warehouse" == task_spec.template.custom["warehouse"]
assert "my_schema" == task_spec.template.custom["schema"]
assert "my_database" == task_spec.template.custom["database"]
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

Expand Down