Skip to content

Commit

Permalink
Merge branch 'master' into sklearn-extra
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicBboy authored Jan 25, 2023
2 parents 728b9ab + 7b48fff commit 9475341
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
task_config=task_config,
**kwargs,
)
self._query_template = query_template.replace("\n", "\\n").replace("\t", "\\t")
self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip()

@property
def query_template(self) -> str:
Expand Down
5 changes: 2 additions & 3 deletions flytekit/extras/sqlite3/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ def __init__(
container_image=container_image or DefaultImages.default_image(),
executor_type=SQLite3TaskExecutor,
task_type=self._SQLITE_TASK_TYPE,
# Sanitize query by removing the newlines at the end of the query. Keep in mind
# that the query can be a multiline string.
query_template=query_template,
inputs=inputs,
outputs=outputs,
**kwargs,
)
# Sanitize query by removing the newlines at the end of the query. Keep in mind
# that the query can be a multiline string.
self._query_template = query_template.replace("\n", " ")

@property
def output_columns(self) -> typing.Optional[typing.List[str]]:
Expand Down
4 changes: 2 additions & 2 deletions plugins/flytekit-snowflake/tests/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_local_exec():
)

assert len(snowflake_task.interface.inputs) == 1
assert snowflake_task.query_template == "select 1\\n"
assert snowflake_task.query_template == "select 1"
assert len(snowflake_task.interface.outputs) == 1

# will not run locally
Expand All @@ -86,4 +86,4 @@ def test_sql_template():
custom where column = 1""",
output_schema_type=FlyteSchema,
)
assert snowflake_task.query_template == "select 1 from\\t\\n custom where column = 1"
assert snowflake_task.query_template == "select 1 from custom where column = 1"
20 changes: 18 additions & 2 deletions plugins/flytekit-sqlalchemy/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,23 @@ def test_task_schema(sql_server):
assert df is not None


def test_workflow(sql_server):
@pytest.mark.parametrize(
"query_template",
[
"select * from tracks limit {{.inputs.limit}}",
"""
select * from tracks
limit {{.inputs.limit}}
""",
"""select * from tracks
limit {{.inputs.limit}}
""",
"""
select * from tracks
limit {{.inputs.limit}}""",
],
)
def test_workflow(sql_server, query_template):
@task
def my_task(df: pandas.DataFrame) -> int:
return len(df[df.columns[0]])
Expand All @@ -84,7 +100,7 @@ def my_task(df: pandas.DataFrame) -> int:

sql_task = SQLAlchemyTask(
"test",
query_template="select * from tracks limit {{.inputs.limit}}",
query_template=query_template,
inputs=kwtypes(limit=int),
task_config=SQLAlchemyConfig(uri=sql_server),
)
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/extras/sqlite3/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def test_task_serialization():
select *
from tracks
limit {{.inputs.limit}}""",
" select * from tracks limit {{.inputs.limit}}",
"select * from tracks limit {{.inputs.limit}}",
),
(
""" \
select * \
from tracks \
limit {{.inputs.limit}}""",
" select * from tracks limit {{.inputs.limit}}",
"select * from tracks limit {{.inputs.limit}}",
),
("select * from abc", "select * from abc"),
],
Expand Down

0 comments on commit 9475341

Please sign in to comment.