Skip to content

Commit

Permalink
Merge branch 'fix-sse-py3' into develop
Browse files Browse the repository at this point in the history
* fix-sse-py3:
  Fix SSE key handling for S3 with py3 byte keys

Conflicts:
	tests/integration/test_s3.py

Merges boto#383.
  • Loading branch information
jamesls committed Nov 19, 2014
2 parents 5919d2a + b03183a commit 370176f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 29 deletions.
39 changes: 22 additions & 17 deletions botocore/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,21 @@ def sse_md5(params, **kwargs):
encryption key. This handler does both if the MD5 has not been set by
the caller.
"""
prefix = 'x-amz-server-side-encryption-customer-'
key = prefix + 'key'
key_md5 = prefix + 'key-MD5'
if key in params['headers'] and not key_md5 in params['headers']:
original = six.b(params['headers'][key])
md5 = hashlib.md5()
md5.update(original)
value = base64.b64encode(md5.digest()).decode('utf-8')
params['headers'][key] = base64.b64encode(original).decode('utf-8')
params['headers'][key_md5] = value
if not _needs_s3_sse_customization(params):
return
key_as_bytes = params['SSECustomerKey']
if isinstance(key_as_bytes, six.text_type):
key_as_bytes = key_as_bytes.encode('utf-8')
key_md5_str = base64.b64encode(
hashlib.md5(key_as_bytes).digest()).decode('utf-8')
key_b64_encoded = base64.b64encode(key_as_bytes).decode('utf-8')
params['SSECustomerKey'] = key_b64_encoded
params['SSECustomerKeyMD5'] = key_md5_str


def _needs_s3_sse_customization(params):
return (params.get('SSECustomerKey') is not None and
'SSECustomerKeyMD5' not in params)


def check_dns_name(bucket_name):
Expand Down Expand Up @@ -398,13 +403,13 @@ def base64_encode_user_data(params, **kwargs):
REGISTER_FIRST),
('service-data-loaded', register_retries_for_service),
('service-data-loaded', signature_overrides),
('before-call.s3.HeadObject', sse_md5),
('before-call.s3.GetObject', sse_md5),
('before-call.s3.PutObject', sse_md5),
('before-call.s3.CopyObject', sse_md5),
('before-call.s3.CreateMultipartUpload', sse_md5),
('before-call.s3.UploadPart', sse_md5),
('before-call.s3.UploadPartCopy', sse_md5),
('before-parameter-build.s3.HeadObject', sse_md5),
('before-parameter-build.s3.GetObject', sse_md5),
('before-parameter-build.s3.PutObject', sse_md5),
('before-parameter-build.s3.CopyObject', sse_md5),
('before-parameter-build.s3.CreateMultipartUpload', sse_md5),
('before-parameter-build.s3.UploadPart', sse_md5),
('before-parameter-build.s3.UploadPartCopy', sse_md5),
('before-parameter-build.ec2.RunInstances', base64_encode_user_data),
('before-parameter-build.autoscaling.CreateLaunchConfiguration',
base64_encode_user_data),
Expand Down
51 changes: 51 additions & 0 deletions tests/integration/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import shutil
import threading
import mock
import six
try:
from itertools import izip_longest as zip_longest
except ImportError:
Expand Down Expand Up @@ -597,5 +598,55 @@ def test_verify_can_switch_sigv4(self):
self.assertEqual(http.status_code, 200)


class TestSSEKeyParamValidation(unittest.TestCase):
def setUp(self):
self.session = botocore.session.get_session()
self.client = self.session.create_client('s3', 'us-west-2')
self.bucket_name = 'botocoretest%s-%s' % (
int(time.time()), random.randint(1, 1000))
self.client.create_bucket(
Bucket=self.bucket_name,
CreateBucketConfiguration={
'LocationConstraint': 'us-west-2',
}
)
self.addCleanup(self.client.delete_bucket, Bucket=self.bucket_name)

def test_make_request_with_sse(self):
key_bytes = os.urandom(32)
# Obviously a bad key here, but we just want to ensure we can use
# a str/unicode type as a key.
key_str = 'abcd' * 8

# Put two objects with an sse key, one with random bytes,
# one with str/unicode. Then verify we can GetObject() both
# objects.
self.client.put_object(
Bucket=self.bucket_name, Key='foo.txt',
Body=six.BytesIO(b'mycontents'), SSECustomerAlgorithm='AES256',
SSECustomerKey=key_bytes)
self.addCleanup(self.client.delete_object,
Bucket=self.bucket_name, Key='foo.txt')
self.client.put_object(
Bucket=self.bucket_name, Key='foo2.txt',
Body=six.BytesIO(b'mycontents2'), SSECustomerAlgorithm='AES256',
SSECustomerKey=key_str)
self.addCleanup(self.client.delete_object,
Bucket=self.bucket_name, Key='foo2.txt')

self.assertEqual(
self.client.get_object(Bucket=self.bucket_name,
Key='foo.txt',
SSECustomerAlgorithm='AES256',
SSECustomerKey=key_bytes)['Body'].read(),
b'mycontents')
self.assertEqual(
self.client.get_object(Bucket=self.bucket_name,
Key='foo2.txt',
SSECustomerAlgorithm='AES256',
SSECustomerKey=key_str)['Body'].read(),
b'mycontents2')


if __name__ == '__main__':
unittest.main()
29 changes: 17 additions & 12 deletions tests/unit/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,27 @@ def test_500_response_can_be_none(self):
# object is None. We need to handle this case.
handlers.check_for_200_error(None)

def test_sse_headers(self):
prefix = 'x-amz-server-side-encryption-customer-'
def test_sse_params(self):
for op in ('HeadObject', 'GetObject', 'PutObject', 'CopyObject',
'CreateMultipartUpload', 'UploadPart', 'UploadPartCopy'):
event = self.session.create_event(
'before-call', 's3', op)
params = {'headers': {
prefix + 'algorithm': 'foo',
prefix + 'key': 'bar'
}}
'before-parameter-build', 's3', op)
params = {'SSECustomerKey': b'bar',
'SSECustomerAlgorithm': 'AES256'}
self.session.emit(event, params=params, model=mock.Mock())
self.assertEqual(
params['headers'][prefix + 'key'], 'YmFy')
self.assertEqual(
params['headers'][prefix + 'key-MD5'],
'N7UdGUp1E+RbVvZSTy1R8g==')
self.assertEqual(params['SSECustomerKey'], 'YmFy')
self.assertEqual(params['SSECustomerKeyMD5'],
'N7UdGUp1E+RbVvZSTy1R8g==')

def test_sse_params_as_str(self):
event = self.session.create_event(
'before-parameter-build', 's3', 'PutObject')
params = {'SSECustomerKey': 'bar',
'SSECustomerAlgorithm': 'AES256'}
self.session.emit(event, params=params, model=mock.Mock())
self.assertEqual(params['SSECustomerKey'], 'YmFy')
self.assertEqual(params['SSECustomerKeyMD5'],
'N7UdGUp1E+RbVvZSTy1R8g==')

def test_fix_s3_host_initial(self):
endpoint = mock.Mock(region_name='us-west-2')
Expand Down

0 comments on commit 370176f

Please sign in to comment.