diff --git a/moto/s3/models.py b/moto/s3/models.py index bbfd571b..343b3ae8 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -36,11 +36,10 @@ class FakeKey(object): r.name = new_name return r - def set_metadata(self, key, metadata): - self._metadata[key] = metadata - - def clear_metadata(self): - self._metadata = {} + def set_metadata(self, metadata, replace=False): + if replace: + self._metadata = {} + self._metadata.update(metadata) def set_storage_class(self, storage_class): self._storage_class = storage_class @@ -115,8 +114,9 @@ class FakeKey(object): class FakeMultipart(object): - def __init__(self, key_name): + def __init__(self, key_name, metadata): self.key_name = key_name + self.metadata = metadata self.parts = {} rand_b64 = base64.b64encode(os.urandom(UPLOAD_ID_BYTES)) self.id = rand_b64.decode('utf-8').replace('=', '').replace('+', '') @@ -267,9 +267,9 @@ class S3Backend(BaseBackend): if str(key._version_id) == str(version_id): return key - def initiate_multipart(self, bucket_name, key_name): + def initiate_multipart(self, bucket_name, key_name, metadata): bucket = self.get_bucket(bucket_name) - new_multipart = FakeMultipart(key_name) + new_multipart = FakeMultipart(key_name, metadata) bucket.multiparts[new_multipart.id] = new_multipart return new_multipart @@ -282,7 +282,9 @@ class S3Backend(BaseBackend): return del bucket.multiparts[multipart_id] - return self.set_key(bucket_name, multipart.key_name, value, etag=etag) + key = self.set_key(bucket_name, multipart.key_name, value, etag=etag) + key.set_metadata(multipart.metadata) + return key def cancel_multipart(self, bucket_name, multipart_id): bucket = self.get_bucket(bucket_name) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index bd0dbc88..8a0931ec 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals import re -from boto.s3.key import Key from jinja2 import Template import six from six.moves.urllib.parse import parse_qs, urlparse @@ -10,7 +9,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from .exceptions import BucketAlreadyExists, MissingBucket from .models import s3_backend -from .utils import bucket_name_from_url +from .utils import bucket_name_from_url, metadata_from_headers from xml.dom import minidom @@ -189,14 +188,9 @@ class ResponseObject(object): new_key = self.backend.set_key(bucket_name, key, f) # Metadata - meta_regex = re.compile('^x-amz-meta-([a-zA-Z0-9\-_]+)$', flags=re.IGNORECASE) + metadata = metadata_from_headers(form) + new_key.set_metadata(metadata) - for form_id in form: - result = meta_regex.match(form_id) - if result: - meta_key = result.group(0).lower() - metadata = form[form_id] - new_key.set_metadata(meta_key, metadata) return 200, headers, "" def _bucket_response_delete_keys(self, request, bucket_name, headers): @@ -228,24 +222,6 @@ class ResponseObject(object): status_code, headers, response_content = response return status_code, headers, response_content - def _key_set_metadata(self, request, key, replace=False): - meta_regex = re.compile('^x-amz-meta-([a-zA-Z0-9\-_]+)$', flags=re.IGNORECASE) - if replace is True: - key.clear_metadata() - for header, value in request.headers.items(): - if isinstance(header, six.string_types): - result = meta_regex.match(header) - meta_key = None - if result: - # Check for extra metadata - meta_key = result.group(0).lower() - elif header.lower() in Key.base_user_settable_fields: - # Check for special metadata that doesn't start with x-amz-meta - meta_key = header - if meta_key: - metadata = request.headers[header] - key.set_metadata(meta_key, metadata) - def _key_response(self, request, full_url, headers): parsed_url = urlparse(full_url) query = parse_qs(parsed_url.query) @@ -270,7 +246,7 @@ class ResponseObject(object): elif method == 'DELETE': return self._key_response_delete(bucket_name, query, key_name, headers) elif method == 'POST': - return self._key_response_post(body, parsed_url, bucket_name, query, key_name, headers) + return self._key_response_post(request, body, parsed_url, bucket_name, query, key_name, headers) else: raise NotImplementedError("Method {0} has not been impelemented in the S3 backend yet".format(method)) @@ -328,7 +304,8 @@ class ResponseObject(object): mdirective = request.headers.get('x-amz-metadata-directive') if mdirective is not None and mdirective == 'REPLACE': new_key = self.backend.get_key(bucket_name, key_name) - self._key_set_metadata(request, new_key, replace=True) + metadata = metadata_from_headers(request.headers) + new_key.set_metadata(metadata, replace=True) template = Template(S3_OBJECT_COPY_RESPONSE) return template.render(key=src_key) streaming_request = hasattr(request, 'streaming') and request.streaming @@ -344,7 +321,8 @@ class ResponseObject(object): new_key = self.backend.set_key(bucket_name, key_name, body, storage=storage_class) request.streaming = True - self._key_set_metadata(request, new_key) + metadata = metadata_from_headers(request.headers) + new_key.set_metadata(metadata) template = Template(S3_OBJECT_RESPONSE) headers.update(new_key.response_dict) @@ -368,9 +346,11 @@ class ResponseObject(object): template = Template(S3_DELETE_OBJECT_SUCCESS) return 204, headers, template.render(bucket=removed_key) - def _key_response_post(self, body, parsed_url, bucket_name, query, key_name, headers): + def _key_response_post(self, request, body, parsed_url, bucket_name, query, key_name, headers): if body == b'' and parsed_url.query == 'uploads': - multipart = self.backend.initiate_multipart(bucket_name, key_name) + metadata = metadata_from_headers(request.headers) + multipart = self.backend.initiate_multipart(bucket_name, key_name, metadata) + template = Template(S3_MULTIPART_INITIATE_RESPONSE) response = template.render( bucket_name=bucket_name, diff --git a/moto/s3/utils.py b/moto/s3/utils.py index c01aea67..3431bb3f 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -1,7 +1,10 @@ from __future__ import unicode_literals + +from boto.s3.key import Key import re -import sys +import six from six.moves.urllib.parse import urlparse, unquote +import sys bucket_name_regex = re.compile("(.+).s3.amazonaws.com") @@ -24,6 +27,24 @@ def bucket_name_from_url(url): return None +def metadata_from_headers(headers): + metadata = {} + meta_regex = re.compile('^x-amz-meta-([a-zA-Z0-9\-_]+)$', flags=re.IGNORECASE) + for header, value in headers.items(): + if isinstance(header, six.string_types): + result = meta_regex.match(header) + meta_key = None + if result: + # Check for extra metadata + meta_key = result.group(0).lower() + elif header.lower() in Key.base_user_settable_fields: + # Check for special metadata that doesn't start with x-amz-meta + meta_key = header + if meta_key: + metadata[meta_key] = headers[header] + return metadata + + def clean_key_name(key_name): return unquote(key_name) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 4d7df5e0..b9b90c8b 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -87,6 +87,20 @@ def test_multipart_upload(): bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) +@mock_s3 +def test_multipart_upload_with_headers(): + conn = boto.connect_s3('the_key', 'the_secret') + bucket = conn.create_bucket("foobar") + + multipart = bucket.initiate_multipart_upload("the-key", metadata={"foo": "bar"}) + part1 = b'0' * 10 + multipart.upload_part_from_file(BytesIO(part1), 1) + multipart.complete_upload() + + key = bucket.get_key("the-key") + key.metadata.should.equal({"foo": "bar"}) + + @mock_s3 def test_multipart_upload_with_copy_key(): conn = boto.connect_s3('the_key', 'the_secret')