Skip to content

Commit

Permalink
[AIRFLOW-4255] Make GCS Hook Backwards compatible (#5089)
Browse files Browse the repository at this point in the history
* [AIRFLOW-4255] Make GCS Hook Backwards compatible

* Update UPDATING.md

* Add option to stop warnings

* Update test_gcs_hook.py

* Add tests
  • Loading branch information
kaxil committed Apr 16, 2019
1 parent f35b507 commit f6a0faf
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 23 deletions.
19 changes: 3 additions & 16 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,11 @@ assists users migrating to a new version.
### Changes to GoogleCloudStorageHook

* the discovery-based api (`googleapiclient.discovery`) used in `GoogleCloudStorageHook` is now replaced by the recommended client based api (`google-cloud-storage`). To know the difference between both the libraries, read https://cloud.google.com/apis/docs/client-libraries-explained. PR: [#5054](https://github.com/apache/airflow/pull/5054)
* as a part of this replacement, the `multipart` & `num_retries` parameters for `GoogleCloudStorageHook.upload` method has been removed:
* as a part of this replacement, the `multipart` & `num_retries` parameters for `GoogleCloudStorageHook.upload` method have been deprecated.

**Old**:
```python
def upload(self, bucket, object, filename,
mime_type='application/octet-stream', gzip=False,
multipart=False, num_retries=0):
```

**New**:
```python
def upload(self, bucket, object, filename,
mime_type='application/octet-stream', gzip=False):
```

The client library uses multipart upload automatically if the object/blob size is more than 8 MB - [source code](https://github.com/googleapis/google-cloud-python/blob/11c543ce7dd1d804688163bc7895cf592feb445f/storage/google/cloud/storage/blob.py#L989-L997).
The client library uses multipart upload automatically if the object/blob size is more than 8 MB - [source code](https://github.com/googleapis/google-cloud-python/blob/11c543ce7dd1d804688163bc7895cf592feb445f/storage/google/cloud/storage/blob.py#L989-L997). The client also handles retries automatically

* the `generation` parameter is no longer supported in `GoogleCloudStorageHook.delete` and `GoogleCloudStorageHook.insert_object_acl`.
* the `generation` parameter is deprecated in `GoogleCloudStorageHook.delete` and `GoogleCloudStorageHook.insert_object_acl`.

## Airflow 1.10.3

Expand Down
30 changes: 25 additions & 5 deletions airflow/contrib/hooks/gcs_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import gzip as gz
import os
import shutil
import warnings

from google.cloud import storage

Expand Down Expand Up @@ -172,7 +173,8 @@ def download(self, bucket, object, filename=None):

# pylint:disable=redefined-builtin
def upload(self, bucket, object, filename,
mime_type='application/octet-stream', gzip=False):
mime_type='application/octet-stream', gzip=False,
multipart=None, num_retries=None):
"""
Uploads a local file to Google Cloud Storage.
Expand All @@ -188,6 +190,14 @@ def upload(self, bucket, object, filename,
:type gzip: bool
"""

if multipart is not None:
warnings.warn("'multipart' parameter is deprecated."
" It is handled automatically by the Storage client", DeprecationWarning)

if num_retries is not None:
warnings.warn("'num_retries' parameter is deprecated."
" It is handled automatically by the Storage client", DeprecationWarning)

if gzip:
filename_gz = filename + '.gz'

Expand Down Expand Up @@ -255,7 +265,7 @@ def is_updated_after(self, bucket, object, ts):

return False

def delete(self, bucket, object):
def delete(self, bucket, object, generation=None):
"""
Deletes an object from the bucket.
Expand All @@ -264,6 +274,10 @@ def delete(self, bucket, object):
:param object: name of the object to delete
:type object: str
"""

if generation is not None:
warnings.warn("'generation' parameter is no longer supported", DeprecationWarning)

client = self.get_conn()
bucket = client.get_bucket(bucket_name=bucket)
blob = bucket.blob(blob_name=object)
Expand Down Expand Up @@ -320,7 +334,7 @@ def list(self, bucket, versions=None, maxResults=None, prefix=None, delimiter=No

def get_size(self, bucket, object):
"""
Gets the size of a file in Google Cloud Storage.
Gets the size of a file in Google Cloud Storage in bytes.
:param bucket: The Google cloud storage bucket where the object is.
:type bucket: str
Expand Down Expand Up @@ -476,7 +490,8 @@ def insert_bucket_acl(self, bucket, entity, role, user_project=None):

self.log.info('A new ACL entry created in bucket: %s', bucket)

def insert_object_acl(self, bucket, object_name, entity, role, user_project=None):
def insert_object_acl(self, bucket, object_name, entity, role, generation=None,
user_project=None):
"""
Creates a new ACL entry on the specified object.
See: https://cloud.google.com/storage/docs/json_api/v1/objectAccessControls/insert
Expand All @@ -499,6 +514,8 @@ def insert_object_acl(self, bucket, object_name, entity, role, user_project=None
Required for Requester Pays buckets.
:type user_project: str
"""
if generation is not None:
warnings.warn("'generation' parameter is no longer supported", DeprecationWarning)
self.log.info('Creating a new ACL entry for object: %s in bucket: %s',
object_name, bucket)
client = self.get_conn()
Expand All @@ -514,7 +531,7 @@ def insert_object_acl(self, bucket, object_name, entity, role, user_project=None
self.log.info('A new ACL entry created for object: %s in bucket: %s',
object_name, bucket)

def compose(self, bucket, source_objects, destination_object):
def compose(self, bucket, source_objects, destination_object, num_retries=None):
"""
Composes a list of existing object into a new object in the same storage bucket
Expand All @@ -532,6 +549,9 @@ def compose(self, bucket, source_objects, destination_object):
:param destination_object: The path of the object if given.
:type destination_object: str
"""
if num_retries is not None:
warnings.warn("'num_retries' parameter is Deprecated. Retries are "
"now handled automatically", DeprecationWarning)

if not source_objects or not len(source_objects):
raise ValueError('source_objects cannot be empty.')
Expand Down
167 changes: 165 additions & 2 deletions tests/contrib/hooks/test_gcs_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
import io
import six
import tempfile
import os

Expand All @@ -28,6 +28,11 @@
from tests.contrib.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
from google.cloud import storage
from google.cloud import exceptions
if six.PY2:
# Need `assertWarns` back-ported from unittest2
import unittest2 as unittest
else:
import unittest

BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
GCS_STRING = 'airflow.contrib.hooks.gcs_hook.{}'
Expand Down Expand Up @@ -224,6 +229,40 @@ def test_rewrite(self, mock_service, mock_bucket):
rewrite_method.assert_called_once_with(
source=source_blob)

def test_rewrite_empty_source_bucket(self):
source_bucket = None
source_object = 'test-source-object'
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'

with self.assertRaises(ValueError) as e:
self.gcs_hook.rewrite(source_bucket=source_bucket,
source_object=source_object,
destination_bucket=destination_bucket,
destination_object=destination_object)

self.assertEqual(
str(e.exception),
'source_bucket and source_object cannot be empty.'
)

def test_rewrite_empty_source_object(self):
source_bucket = 'test-source-object'
source_object = None
destination_bucket = 'test-dest-bucket'
destination_object = 'test-dest-object'

with self.assertRaises(ValueError) as e:
self.gcs_hook.rewrite(source_bucket=source_bucket,
source_object=source_object,
destination_bucket=destination_bucket,
destination_object=destination_object)

self.assertEqual(
str(e.exception),
'source_bucket and source_object cannot be empty.'
)

@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_delete(self, mock_service, mock_bucket):
Expand Down Expand Up @@ -252,6 +291,55 @@ def test_delete_nonexisting_object(self, mock_service):
with self.assertRaises(exceptions.NotFound):
self.gcs_hook.delete(bucket=test_bucket, object=test_object)

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_object_get_size(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
returned_file_size = 1200

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
get_blob_method.return_value.size = returned_file_size

response = self.gcs_hook.get_size(bucket=test_bucket, object=test_object)

self.assertEquals(response, returned_file_size)
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_object_get_crc32c(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
returned_file_crc32c = "xgdNfQ=="

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
get_blob_method.return_value.crc32c = returned_file_crc32c

response = self.gcs_hook.get_crc32c(bucket=test_bucket, object=test_object)

self.assertEquals(response, returned_file_crc32c)

# Check that reload method is called
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_object_get_md5hash(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
returned_file_md5hash = "leYUJBUWrRtks1UeUFONJQ=="

get_bucket_method = mock_service.return_value.get_bucket
get_blob_method = get_bucket_method.return_value.get_blob
get_blob_method.return_value.md5_hash = returned_file_md5hash

response = self.gcs_hook.get_md5hash(bucket=test_bucket, object=test_object)

self.assertEquals(response, returned_file_md5hash)

# Check that reload method is called
get_blob_method.return_value.reload.assert_called_once_with()

@mock.patch('google.cloud.storage.Bucket')
@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_create_bucket(self, mock_service, mock_bucket):
Expand Down Expand Up @@ -398,6 +486,60 @@ def test_compose_without_destination_object(self, mock_service):
'bucket and destination_object cannot be empty.'
)

# Test Deprecation warnings for deprecated parameters
@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_compose_deprecated_params(self, mock_service):
test_bucket = 'test_bucket'
test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3']
test_destination_object = 'test_object_composed'

with self.assertWarns(DeprecationWarning):
self.gcs_hook.compose(
bucket=test_bucket,
source_objects=test_source_objects,
destination_object=test_destination_object,
num_retries=5
)

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_download_as_string(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
test_object_bytes = io.BytesIO(b"input")

download_method = mock_service.return_value.get_bucket.return_value \
.blob.return_value.download_as_string
download_method.return_value = test_object_bytes

response = self.gcs_hook.download(bucket=test_bucket,
object=test_object,
filename=None)

self.assertEquals(response, test_object_bytes)
download_method.assert_called_once_with()

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_download_to_file(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
test_object_bytes = io.BytesIO(b"input")
test_file = 'test_file'

download_filename_method = mock_service.return_value.get_bucket.return_value \
.blob.return_value.download_to_filename
download_filename_method.return_value = None

download_as_a_string_method = mock_service.return_value.get_bucket.return_value \
.blob.return_value.download_as_string
download_as_a_string_method.return_value = test_object_bytes

response = self.gcs_hook.download(bucket=test_bucket,
object=test_object,
filename=test_file)

self.assertEquals(response, test_object_bytes)
download_filename_method.assert_called_once_with(test_file)


class TestGoogleCloudStorageHookUpload(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -448,3 +590,24 @@ def test_upload_gzip(self, mock_service):
gzip=True)
self.assertFalse(os.path.exists(self.testfile.name + '.gz'))
self.assertIsNone(response)

@mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
def test_upload_deprecated_params(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'

upload_method = mock_service.return_value.get_bucket.return_value\
.blob.return_value.upload_from_filename
upload_method.return_value = None

with self.assertWarns(DeprecationWarning):
self.gcs_hook.upload(test_bucket,
test_object,
self.testfile.name,
multipart=True)

with self.assertWarns(DeprecationWarning):
self.gcs_hook.upload(test_bucket,
test_object,
self.testfile.name,
num_retries=2)

0 comments on commit f6a0faf

Please sign in to comment.