Skip to content

Commit

Permalink
fix the unit-test
Browse files Browse the repository at this point in the history
Signed-off-by: HH <[email protected]>
  • Loading branch information
hhcs9527 committed Aug 30, 2023
1 parent 920263a commit 6d867fa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
25 changes: 12 additions & 13 deletions plugins/flytekit-snowflake/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import re
from dataclasses import asdict
from datetime import timedelta
from unittest import mock
Expand Down Expand Up @@ -57,15 +56,6 @@ async def test_snowflake_agent(mock_conn):
table="dummy_table",
)

task_config = {
"user" : "dummy_user",
"account" : "dummy_account",
"database" : "dummy_database",
"schema" : "dummy_schema",
"warehouse" : "dummy_warehouse",
"table" : "dummy_table",
}

int_type = types.LiteralType(types.SimpleType.INTEGER)
interfaces = interface_models.TypedInterface(
{
Expand All @@ -83,14 +73,23 @@ async def test_snowflake_agent(mock_conn):

dummy_template = TaskTemplate(
id=task_id,
custom=task_config,
custom=None,
config=task_config,
metadata=task_metadata,
interface=interfaces,
type="snowflake",
sql=Sql("SELECT 1"),
)

metadata = Metadata(user="dummy_user",account="dummy_account",table="dummy_table",database="dummy_database",schema="dummy_schema",warehouse="dummy_warehouse",query_id="dummy_query_id")
metadata = Metadata(
user="dummy_user",
account="dummy_account",
table="dummy_table",
database="dummy_database",
schema="dummy_schema",
warehouse="dummy_warehouse",
query_id="dummy_query_id",
)

res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs)
metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id
Expand Down Expand Up @@ -118,4 +117,4 @@ async def test_snowflake_agent(mock_conn):

# Verify that the connection was closed
mock_cursor.close.assert_called_once()
mock_conn_instance.close.assert_called_once()
mock_conn_instance.close.assert_called_once()
8 changes: 4 additions & 4 deletions plugins/flytekit-snowflake/tests/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def my_wf(ds: str) -> FlyteSchema:
assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
assert "insert overwrite directory" in task_spec.template.sql.statement
assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
assert "snowflake" == task_spec.template.custom["account"]
assert "my_warehouse" == task_spec.template.custom["warehouse"]
assert "my_schema" == task_spec.template.custom["schema"]
assert "my_database" == task_spec.template.custom["database"]
assert "snowflake" == task_spec.template.config["account"]
assert "my_warehouse" == task_spec.template.config["warehouse"]
assert "my_schema" == task_spec.template.config["schema"]
assert "my_database" == task_spec.template.config["database"]
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

Expand Down

0 comments on commit 6d867fa

Please sign in to comment.