TDL-12486: Added support for compressed files #32

Merged Jun 1, 2021 · 15 commits
147 changes: 120 additions & 27 deletions tap_s3_csv/s3.py
@@ -1,6 +1,8 @@
import itertools
import re
import io
import json
import gzip
import backoff
import boto3
import singer
@@ -13,8 +15,14 @@
)
from botocore.exceptions import ClientError
from botocore.session import Session
from singer_encodings import csv
from tap_s3_csv import conversion
from singer_encodings import (
compression,
csv
)
from tap_s3_csv import (
utils,
conversion
)

LOGGER = singer.get_logger()

@@ -23,7 +31,6 @@
SDC_SOURCE_LINENO_COLUMN = "_sdc_source_lineno"
SDC_EXTRA_COLUMN = "_sdc_extra"


def retry_pattern():
return backoff.on_exception(backoff.expo,
ClientError,
@@ -33,8 +40,7 @@ def retry_pattern():


def log_backoff_attempt(details):
LOGGER.info(
"Error detected communicating with Amazon, triggering backoff: %d try", details.get("tries"))
LOGGER.info("Error detected communicating with Amazon, triggering backoff: %d try", details.get("tries"))


class AssumeRoleProvider():
@@ -82,8 +88,7 @@ def get_sampled_schema_for_table(config, table_spec):

s3_files_gen = get_input_files_for_table(config, table_spec)

samples = [sample for sample in sample_files(
config, table_spec, s3_files_gen)]
samples = [sample for sample in sample_files(config, table_spec, s3_files_gen)]

if not samples:
return {}
@@ -103,7 +108,6 @@ def get_sampled_schema_for_table(config, table_spec):
'properties': merge_dicts(data_schema, metadata_schema)
}


def merge_dicts(first, second):
to_return = first.copy()

Expand Down Expand Up @@ -186,21 +190,45 @@ def check_key_properties_and_date_overrides_for_jsonl_file(table_spec, jsonl_sam
.format(s3_path, date_overrides - all_keys))


def sample_file(config, table_spec, s3_path, sample_rate):
def sampling_gz_file(table_spec, s3_path, file_handle, sample_rate):
if s3_path.endswith(".tar.gz"):
LOGGER.warning('Skipping "%s" file as .tar.gz extension is not supported',s3_path)
return []

file_bytes = file_handle.read()
gz_file_obj = gzip.GzipFile(fileobj=io.BytesIO(file_bytes))

gz_file_name = utils.get_file_name_from_gzfile(fileobj=io.BytesIO(file_bytes))

if gz_file_name:
if gz_file_name.endswith(".gz"):
LOGGER.warning('Skipping "%s" file as it contains nested compression.',s3_path)
return []

gz_file_extension = gz_file_name.split(".")[-1].lower()
return sample_file(table_spec, s3_path + "/" + gz_file_name, io.BytesIO(gz_file_obj.read()), sample_rate, gz_file_extension)

file_handle = get_file_handle(config, s3_path)._raw_stream
raise Exception('"{}" file has some error(s)'.format(s3_path))

extension = s3_path.split(".")[-1].lower()

records = []
def sample_file(table_spec, s3_path, file_handle, sample_rate, extension):

if extension in ("csv","txt"):
iterator = csv.get_row_iterator(file_handle, table_spec, None, True) #pylint:disable=protected-access
records = get_records_for_csv(s3_path, sample_rate, iterator)
elif extension == "jsonl":
iterator = file_handle
# Skip the file if it has no extension
if not extension or s3_path.lower() == extension:
LOGGER.warning('"%s" without extension will not be sampled.',s3_path)
return []
if extension in ["csv", "txt"]:
# Use the raw stream when the file object was read from S3; extracted zip/gz members are used as-is
file_handle = file_handle._raw_stream if hasattr(file_handle, "_raw_stream") else file_handle #pylint:disable=protected-access
iterator = csv.get_row_iterator(file_handle, table_spec, None, True)
return get_records_for_csv(s3_path, sample_rate, iterator)
if extension == "gz":
return sampling_gz_file(table_spec, s3_path, file_handle, sample_rate)
if extension == "jsonl":
# Use the raw stream when the file object was read from S3; extracted zip/gz members are used as-is
file_handle = file_handle._raw_stream if hasattr(file_handle, "_raw_stream") else file_handle
records = get_records_for_jsonl(
s3_path, sample_rate, iterator)
s3_path, sample_rate, file_handle)
check_jsonl_sample_records, records = itertools.tee(
records)
jsonl_sample_records = list(check_jsonl_sample_records)
@@ -211,22 +239,88 @@ def sample_file(config, table_spec, s3_path, sample_rate):
'No row sampled, Please check your JSONL file {}'.format(s3_path))
check_key_properties_and_date_overrides_for_jsonl_file(
table_spec, jsonl_sample_records, s3_path)
else:
LOGGER.warning(
"'%s' having the '.%s' extension will not be sampled.", s3_path, extension)
return records

return records
if extension == "zip":
LOGGER.warning('Skipping "%s" file as it contains nested compression.',s3_path)
return []
LOGGER.warning('"%s" having the ".%s" extension will not be sampled.',s3_path,extension)
return []
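
The gz branch above relies on utils.get_file_name_from_gzfile, which is not part of this diff. A minimal sketch of such a helper, assuming it simply parses the gzip header laid out in RFC 1952 (the tap's real implementation in tap_s3_csv.utils may differ):

import struct

FEXTRA, FNAME = 4, 8  # FLG bits defined by RFC 1952


def get_file_name_from_gzfile(fileobj):
    """Return the original file name stored in a gzip header, or None.

    Illustrative sketch only; the real helper lives in tap_s3_csv.utils.
    """
    header = fileobj.read(10)  # ID1 ID2 CM FLG MTIME(4) XFL OS
    if len(header) < 10 or header[:2] != b"\x1f\x8b":
        return None  # not a gzip stream
    flags = header[3]
    if flags & FEXTRA:
        # Skip the optional "extra" field: 2-byte little-endian length plus payload
        extra_len, = struct.unpack("<H", fileobj.read(2))
        fileobj.read(extra_len)
    if not flags & FNAME:
        return None  # no original file name was recorded
    name = bytearray()
    while True:
        byte = fileobj.read(1)
        if not byte or byte == b"\x00":  # zero-terminated, latin-1 encoded
            break
        name += byte
    return name.decode("latin-1")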


def get_files_to_sample(config, s3_files, max_files):
"""
Returns the list of files for sampling, it checks the s3_files whether any zip or gz file exists or not
if exists then extract if and append in the list of files

Args:
config dict(): Configuration
s3_files list(): List of S3 Bucket files
Returns:
list(dict()) : List of Files for sampling
|_ s3_path str(): S3 Bucket File path
|_ file_handle StreamingBody(): file object
|_ type str(): Type of file which is used for extracted file
|_ extension str(): extension of file (for normal files only)
"""
sampled_files = []

OTHER_FILES = ["csv","gz","jsonl","txt"]

for s3_file in s3_files:
file_key = s3_file.get('key')

if len(sampled_files) >= max_files:
break

if file_key:
file_name = file_key.split("/").pop()
extension = file_name.split(".").pop().lower()
file_handle = get_file_handle(config, file_key)

# Skip the file if it has no extension
if not extension or file_name.lower() == extension:
LOGGER.warning('"%s" without extension will not be sampled.',file_key)
elif file_key.endswith(".tar.gz"):
LOGGER.warning('Skipping "%s" file as .tar.gz extension is not supported', file_key)
elif extension == "zip":
files = compression.infer(io.BytesIO(file_handle.read()), file_name)

# Keep only the extracted members whose extensions the tap supports; nested .tar.gz members are skipped
# Each entry records the zip file's S3 key, the type 'unzipped', and the extracted file object
sampled_files.extend([{ "type" : "unzipped", "s3_path" : file_key, "file_handle" : de_file } for de_file in files if de_file.name.split(".")[-1].lower() in OTHER_FILES and not de_file.name.endswith(".tar.gz") ])
elif extension in OTHER_FILES:
# Each entry records the S3 file path, the file's extension, and the file object
sampled_files.append({ "s3_path" : file_key , "file_handle" : file_handle, "extension" : extension })
else:
LOGGER.warning('"%s" having the ".%s" extension will not be sampled.',file_key,extension)

return sampled_files
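
For illustration, the list returned by get_files_to_sample mixes the two dictionary shapes described in the docstring. A hypothetical result for a bucket holding data.csv and archive.zip, with None standing in for the real boto3 streaming body and extracted file objects:

# Hypothetical output of get_files_to_sample (placeholders instead of real file objects)
sampled_files = [
    # Plain file: sampled directly from S3, extension recorded up front
    {"s3_path": "exports/data.csv", "file_handle": None, "extension": "csv"},
    # Member extracted from archive.zip: marked "unzipped"; the member's own
    # name and extension are resolved later in sample_files()
    {"type": "unzipped", "s3_path": "exports/archive.zip", "file_handle": None},
]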


# pylint: disable=too-many-arguments
def sample_files(config, table_spec, s3_files,
sample_rate=5, max_records=1000, max_files=5):
LOGGER.info("Sampling files (max files: %s)", max_files)
for s3_file in itertools.islice(s3_files, max_files):

for s3_file in itertools.islice(get_files_to_sample(config, s3_files, max_files), max_files):

s3_path = s3_file.get("s3_path","")
file_handle = s3_file.get("file_handle")
file_type = s3_file.get("type")
extension = s3_file.get("extension")

# Check whether the file was extracted from a zip archive.
if file_type and file_type == "unzipped":
# Append the extracted file's name to the zip file's S3 path.
s3_path += "/" + file_handle.name
extension = file_handle.name.split(".")[-1].lower()

LOGGER.info('Sampling %s (max records: %s, sample rate: %s)',
s3_file['key'],
s3_path,
max_records,
sample_rate)
yield from itertools.islice(sample_file(config, table_spec, s3_file['key'], sample_rate), max_records)
yield from itertools.islice(sample_file(table_spec, s3_path, file_handle, sample_rate, extension), max_records)
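
A minimal usage sketch mirroring how get_sampled_schema_for_table drives sampling earlier in this file; the bucket name, table settings, and record limits below are made-up examples rather than values from this PR:

# Hypothetical driver: stream up to 100 sample records from at most 2 matching files.
from tap_s3_csv import s3

config = {"bucket": "example-bucket", "start_date": "2021-01-01T00:00:00Z"}
table_spec = {"table_name": "orders", "search_pattern": "exports/.*"}

files_gen = s3.get_input_files_for_table(config, table_spec)
for record in s3.sample_files(config, table_spec, files_gen,
                              sample_rate=5, max_records=100, max_files=2):
    print(record)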


def get_input_files_for_table(config, table_spec, modified_since=None):
@@ -310,8 +404,7 @@ def list_files_in_bucket(bucket, search_prefix=None):
if s3_object_count > 0:
LOGGER.info("Found %s files.", s3_object_count)
else:
LOGGER.warning(
'Found no files for bucket "%s" that match prefix "%s"', bucket, search_prefix)
LOGGER.warning('Found no files for bucket "%s" that match prefix "%s"', bucket, search_prefix)


@retry_pattern()