Optimize Snowflake load_file using native COPY INTO (#544)
Fix: #430 

Reduce the time to load 5 GB datasets into Snowflake by roughly 78% (from 24.46 min to 5.49 min). Further details are in the PR results file.
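
For DAG authors the optimisation is opt-in through `load_file`. A minimal sketch, assuming the astro-sdk-python 1.0 import paths; the bucket, connection IDs and storage integration name below are placeholders:

```python
# Hedged sketch: enabling the native COPY INTO path from a DAG.
# Bucket, connection IDs and the storage integration name are placeholders.
from astro import sql as aql
from astro.files import File
from astro.sql.table import Metadata, Table

adoptions = aql.load_file(
    input_file=File(
        path="s3://my-bucket/ADOPTION_CENTER_1_unquoted.csv", conn_id="aws_default"
    ),
    output_table=Table(
        name="adoption_center_1",
        conn_id="snowflake_conn",
        metadata=Metadata(schema="PUBLIC", database="SANDBOX"),
    ),
    use_native_support=True,  # opt into the COPY INTO path
    # Optional: reuse an existing Snowflake storage integration for the stage
    native_support_kwargs={"storage_integration": "my_s3_integration"},
)
```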

Co-authored-by: Ankit Chaurasia <[email protected]>
tatiana and Ankit Chaurasia authored Jul 26, 2022
1 parent a0d1912 commit 461e63e
Showing 9 changed files with 199 additions and 46 deletions.
19 changes: 11 additions & 8 deletions example_dags/example_amazon_s3_snowflake_transform.py
@@ -21,17 +21,17 @@ def combine_data(center_1: Table, center_2: Table):
@aql.transform()
def clean_data(input_table: Table):
return """SELECT *
FROM {{input_table}} WHERE TYPE NOT LIKE 'Guinea Pig'
FROM {{input_table}} WHERE type NOT LIKE 'Guinea Pig'
"""


@aql.dataframe(identifiers_as_lower=False)
@aql.dataframe()
def aggregate_data(df: pd.DataFrame):
adoption_reporting_dataframe = df.pivot_table(
new_df = df.pivot_table(
index="date", values="name", columns=["type"], aggfunc="count"
).reset_index()

return adoption_reporting_dataframe
new_df.columns = new_df.columns.str.lower()
return new_df


@dag(
@@ -67,11 +67,11 @@ def example_amazon_s3_snowflake_transform():
)

temp_table_1 = aql.load_file(
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1.csv"),
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_1_unquoted.csv"),
output_table=input_table_1,
)
temp_table_2 = aql.load_file(
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2.csv"),
input_file=File(path=f"{s3_bucket}/ADOPTION_CENTER_2_unquoted.csv"),
output_table=input_table_2,
)

@@ -85,7 +85,10 @@ def example_amazon_s3_snowflake_transform():
cleaned_data,
output_table=Table(
name="aggregated_adoptions_" + str(int(time.time())),
metadata=Metadata(schema=os.environ["SNOWFLAKE_SCHEMA"]),
metadata=Metadata(
schema=os.environ["SNOWFLAKE_SCHEMA"],
database=os.environ["SNOWFLAKE_DATABASE"],
),
conn_id="snowflake_conn",
),
)
131 changes: 108 additions & 23 deletions src/astro/databases/snowflake.py
@@ -1,16 +1,17 @@
"""Snowflake database implementation."""
import logging
import os
import random
import string
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import pandas as pd
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from pandas.io.sql import SQLDatabase
from snowflake.connector import pandas_tools
from snowflake.connector.errors import ProgrammingError

from astro import settings
from astro.constants import (
DEFAULT_CHUNK_SIZE,
FileLocation,
@@ -36,6 +37,14 @@
FileType.PARQUET: "MATCH_BY_COLUMN_NAME=CASE_INSENSITIVE",
}

DEFAULT_STORAGE_INTEGRATION = {
FileLocation.S3: settings.SNOWFLAKE_STORAGE_INTEGRATION_AMAZON,
FileLocation.GS: settings.SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE,
}

NATIVE_LOAD_SUPPORTED_FILE_TYPES = (FileType.CSV, FileType.NDJSON, FileType.PARQUET)
NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS = (FileLocation.GS, FileLocation.S3)


@dataclass
class SnowflakeStage:
@@ -138,7 +147,6 @@ class SnowflakeDatabase(BaseDatabase):
"""

def __init__(self, conn_id: str = DEFAULT_CONN_ID):
self.storage_integration: Optional[str] = None
super().__init__(conn_id)

@property
@@ -192,7 +200,9 @@ def _create_stage_auth_sub_statement(
:param storage_integration: Previously created Snowflake storage integration
:return: String containing line to be used for authentication on the remote storage
"""

storage_integration = storage_integration or DEFAULT_STORAGE_INTEGRATION.get(
file.location.location_type
)
if storage_integration is not None:
auth = f"storage_integration = {storage_integration};"
else:
@@ -291,6 +301,93 @@ def drop_stage(self, stage: SnowflakeStage) -> None:
# Table load methods
# ---------------------------------------------------------

def create_table_using_schema_autodetection(
self,
table: Table,
file: Optional[File] = None,
dataframe: Optional[pd.DataFrame] = None,
) -> None:
"""
Create a SQL table, automatically inferring the schema using the given file.
:param table: The table to be created.
:param file: File used to infer the new table columns.
:param dataframe: Dataframe used to infer the new table columns if there is no file
"""
if file:
dataframe = file.export_to_dataframe(
nrows=settings.LOAD_TABLE_AUTODETECT_ROWS_COUNT
)

# Snowflake doesn't handle mixed capitalisation of column names well;
# this is handled more gracefully in a separate PR.
if dataframe is not None:
    dataframe.columns = dataframe.columns.str.upper()

super().create_table_using_schema_autodetection(table, dataframe=dataframe)

def is_native_load_file_available(
self, source_file: File, target_table: Table
) -> bool:
"""
Check if there is an optimised path for source to destination.
:param source_file: File from which we need to transfer data
:param target_table: Table that needs to be populated with file data
"""
is_file_type_supported = (
source_file.type.name in NATIVE_LOAD_SUPPORTED_FILE_TYPES
)
is_file_location_supported = (
source_file.location.location_type in NATIVE_LOAD_SUPPORTED_FILE_LOCATIONS
)
return is_file_type_supported and is_file_location_supported

def load_file_to_table_natively(
self,
source_file: File,
target_table: Table,
if_exists: LoadExistStrategy = "replace",
native_support_kwargs: Optional[Dict] = None,
**kwargs,
) -> None:
"""
Load the content of a file to an existing Snowflake table natively by:
- Creating a Snowflake external stage
- Using Snowflake COPY INTO statement
Requirements:
- The user must have permissions to create a STAGE in Snowflake.
- If loading from GCP Cloud Storage, `native_support_kwargs` must define `storage_integration`
- If loading from AWS S3, the credentials for creating the stage may be
retrieved from the Airflow connection or from the `storage_integration`
attribute within `native_support_kwargs`.
:param source_file: File from which we need to transfer data
:param target_table: Table to which the content of the file will be loaded to
:param if_exists: Strategy used to load (currently supported: "append" or "replace")
:param native_support_kwargs: may be used for the stage creation, as described above.
.. seealso::
`Snowflake official documentation on COPY INTO
<https://docs.snowflake.com/en/sql-reference/sql/copy-into-table.html>`_
`Snowflake official documentation on CREATE STAGE
<https://docs.snowflake.com/en/sql-reference/sql/create-stage.html>`_
"""
native_support_kwargs = native_support_kwargs or {}
storage_integration = native_support_kwargs.get("storage_integration")
stage = self.create_stage(
file=source_file, storage_integration=storage_integration
)
table_name = self.get_table_qualified_name(target_table)
file_path = os.path.basename(source_file.path) or ""
sql_statement = (
f"COPY INTO {table_name} FROM @{stage.qualified_name}/{file_path}"
)
self.hook.run(sql_statement)
self.drop_stage(stage)

def load_pandas_dataframe_to_table(
self,
source_dataframe: pd.DataFrame,
@@ -307,27 +404,15 @@ def load_pandas_dataframe_to_table(
:param if_exists: Strategy to be used in case the target table already exists.
:param chunk_size: Specify the number of rows in each batch to be written at a time.
"""
db = SQLDatabase(engine=self.sqlalchemy_engine)
# Make columns uppercase to prevent weird errors in snowflake
source_dataframe.columns = source_dataframe.columns.str.upper()
schema = None
if target_table.metadata:
schema = getattr(target_table.metadata, "schema", None)

# within prep_table() we use pandas drop() function which is used when we pass 'if_exists=replace'.
# There is an issue where has_table() works with uppercase table names but the function meta.reflect() don't.
# To prevent the issue we are passing table name in lowercase.
db.prep_table(
source_dataframe,
target_table.name.lower(),
schema=schema,
if_exists=if_exists,
index=False,
)
self.create_table(target_table, dataframe=source_dataframe)

self.table_exists(target_table)
pandas_tools.write_pandas(
self.hook.get_conn(),
source_dataframe,
target_table.name,
conn=self.hook.get_conn(),
df=source_dataframe,
table_name=target_table.name,
schema=target_table.metadata.schema,
database=target_table.metadata.database,
chunk_size=chunk_size,
quote_identifiers=False,
)
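
At the database layer the new path mirrors what the integration test further below exercises: check `is_native_load_file_available`, create an external stage, run `COPY INTO`, and drop the stage. A rough sketch with illustrative table and bucket names; the approximate SQL is shown as comments, while the exact statements are built by `create_stage` and `load_file_to_table_natively` above:

```python
# Hedged sketch of the database-level flow; names are illustrative.
from astro.databases.snowflake import SnowflakeDatabase
from astro.files import File
from astro.sql.table import Metadata, Table

database = SnowflakeDatabase(conn_id="snowflake_conn")
source_file = File(
    path="s3://my-bucket/ADOPTION_CENTER_1_unquoted.csv", conn_id="aws_default"
)
target_table = Table(
    name="ADOPTIONS",
    conn_id="snowflake_conn",
    metadata=Metadata(schema="PUBLIC", database="SANDBOX"),
)

if database.is_native_load_file_available(source_file, target_table):
    # Approximate SQL issued through the Snowflake hook:
    #   CREATE STAGE ... STORAGE_INTEGRATION = <integration>;
    #   COPY INTO SANDBOX.PUBLIC.ADOPTIONS FROM @<stage>/ADOPTION_CENTER_1_unquoted.csv
    #   DROP STAGE <stage>;
    database.load_file_to_table_natively(
        source_file=source_file,
        target_table=target_table,
        native_support_kwargs={"storage_integration": "my_s3_integration"},
    )
```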
1 change: 1 addition & 0 deletions tests/benchmark/Dockerfile
Expand Up @@ -11,6 +11,7 @@ ENV AIRFLOW_HOME=/opt/app/
ENV PYTHONPATH=/opt/app/
ENV ASTRO_PUBLISH_BENCHMARK_DATA=True
ENV GCP_BUCKET=dag-authoring
ENV AIRFLOW__ASTRO_SDK__SNOWFLAKE_STORAGE_INTEGRATION_GOOGLE=gcs_int_python_sdk

# Debian Bullseye is shipped with Python 3.9
# Upgrade built-in pip
7 changes: 5 additions & 2 deletions tests/benchmark/Makefile
@@ -19,6 +19,10 @@ clean:
@rm -f unittests.cfg
@rm -f unittests.db
@rm -f webserver_config.py
@rm -f ../../unittests.cfg
@rm -f ../../unittests.db
@rm -f ../../airflow.cfg
@rm -f ../../airflow.db

# Takes approximately 7min
setup_gke:
@@ -45,8 +49,7 @@ local: check_google_credentials
benchmark
@rm -rf astro-sdk


run_job: check_google_credentials
run_job:
@gcloud container clusters get-credentials astro-sdk --zone us-central1-a --project ${GCP_PROJECT}
@kubectl apply -f infrastructure/kubernetes/namespace.yaml
@kubectl apply -f infrastructure/kubernetes/postgres.yaml
12 changes: 6 additions & 6 deletions tests/benchmark/config.json
@@ -1,5 +1,11 @@
{
"databases": [
{
"name": "snowflake",
"params": {
"conn_id": "snowflake_conn"
}
},
{
"name": "postgres",
"params": {
@@ -17,12 +23,6 @@
"database": "bigquery"
}
}
},
{
"name": "snowflake",
"params": {
"conn_id": "snowflake_conn"
}
}
],
"datasets": [
12 changes: 12 additions & 0 deletions tests/benchmark/debug.yaml
@@ -0,0 +1,12 @@
apiVersion: v1
kind: Pod
metadata:
name: troubleshoot
namespace: benchmark
spec:
containers:
- name: troubleshoot-benchmark
image: gcr.io/astronomer-dag-authoring/benchmark
# Just spin & wait forever
command: [ "/bin/bash", "-c", "--" ]
args: [ "while true; do sleep 30; done;" ]
20 changes: 19 additions & 1 deletion tests/benchmark/results.md
@@ -85,7 +85,6 @@ The benchmark was run as a Kubernetes job in GKE:
* Container resource limit:
* Memory: 10 Gi


| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
|:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
| snowflake | ten_kb | 4.75s | 59.3MB | 1.45s | 100.0ms |
@@ -95,6 +94,25 @@
| snowflake | five_gb | 24.46min | 97.85MB | 1.43min | 5.94s |
| snowflake | ten_gb | 50.85min | 104.53MB | 2.7min | 12.11s |

### With native support

The benchmark was run as a Kubernetes job in GKE:

* Version: `astro-sdk-python` 1.0.0a1 (`bc58830`)
* Machine type: `n2-standard-4`
* vCPU: 4
* Memory: 16 GB RAM
* Container resource limit:
* Memory: 10 Gi

| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
|:-----------|:-----------|:-------------|:-------------|:----------------|:------------------|
| snowflake | ten_kb | 9.1s | 56.45MB | 2.56s | 110.0ms |
| snowflake | hundred_kb | 9.19s | 45.4MB | 2.55s | 120.0ms |
| snowflake | ten_mb | 10.9s | 47.51MB | 2.58s | 160.0ms |
| snowflake | one_gb | 1.07min | 47.94MB | 8.7s | 5.67s |
| snowflake | five_gb | 5.49min | 53.69MB | 18.76s | 1.6s |

### Database: postgres

| database | dataset | total_time | memory_rss | cpu_time_user | cpu_time_system |
6 changes: 0 additions & 6 deletions tests/data/README.md

This file was deleted.

37 changes: 37 additions & 0 deletions tests/databases/test_snowflake.py
@@ -451,3 +451,40 @@ def test_create_stage_amazon_fails_due_to_no_credentials(get_credentials):
"In order to create an stage for S3, one of the following is required"
)
assert exc_info.match(expected_msg)


@pytest.mark.integration
@pytest.mark.parametrize(
"database_table_fixture",
[
{"database": Database.SNOWFLAKE},
],
indirect=True,
ids=["snowflake"],
)
@pytest.mark.parametrize(
"remote_files_fixture",
[
{"provider": "amazon", "filetype": FileType.CSV},
],
indirect=True,
ids=["amazon_csv"],
)
def test_load_file_to_table_natively(remote_files_fixture, database_table_fixture):
"""Load a file to a Snowflake table using the native optimisation."""
filepath = remote_files_fixture[0]
database, target_table = database_table_fixture
database.load_file_to_table(
File(filepath), target_table, {}, use_native_support=True
)

df = database.hook.get_pandas_df(f"SELECT * FROM {target_table.name}")
assert len(df) == 3
expected = pd.DataFrame(
[
{"id": 1, "name": "First"},
{"id": 2, "name": "Second"},
{"id": 3, "name": "Third with unicode पांचाल"},
]
)
test_utils.assert_dataframes_are_equal(df, expected)
