S3 - Add test case to showcase bug when downloading large files

This commit is contained in:
Bert Blommers 2020-03-12 12:25:31 +00:00
commit 3802767817
4 changed files with 111 additions and 3 deletions

View file

@ -12,6 +12,7 @@ import codecs
import random
import string
import tempfile
import threading
import sys
import time
import uuid
@ -110,6 +111,7 @@ class FakeKey(BaseModel):
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size
self.value = value
self.lock = threading.Lock()
@property
def version_id(self):
@ -117,8 +119,14 @@ class FakeKey(BaseModel):
@property
def value(self):
self.lock.acquire()
print("===>value")
self._value_buffer.seek(0)
return self._value_buffer.read()
print("===>seek")
r = self._value_buffer.read()
print("===>read")
self.lock.release()
return r
@value.setter
def value(self, new_value):
@ -1319,6 +1327,7 @@ class S3Backend(BaseBackend):
return key
def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
print("get_key("+str(bucket_name)+","+str(key_name)+","+str(version_id)+","+str(part_number)+")")
key_name = clean_key_name(key_name)
bucket = self.get_bucket(bucket_name)
key = None

View file

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import re
import sys
import threading
import six
from botocore.awsrequest import AWSPreparedRequest
@ -150,6 +151,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.path = ""
self.data = {}
self.headers = {}
self.lock = threading.Lock()
@property
def should_autoescape(self):
@ -857,6 +859,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _handle_range_header(self, request, headers, response_content):
response_headers = {}
length = len(response_content)
print("Length: " + str(length) + " Range: " + str(request.headers.get("range")))
last = length - 1
_, rspec = request.headers.get("range").split("=")
if "," in rspec:
@ -874,6 +877,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
return 400, response_headers, ""
if begin < 0 or end > last or begin > min(end, last):
print(str(begin)+ " < 0 or " + str(end) + " > " + str(last) + " or " + str(begin) + " > min("+str(end)+","+str(last)+")")
return 416, response_headers, ""
response_headers["content-range"] = "bytes {0}-{1}/{2}".format(
begin, end, length
@ -903,14 +907,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
response_content = response
else:
status_code, response_headers, response_content = response
print("response received: " + str(len(response_content)))
print(request.headers)
if status_code == 200 and "range" in request.headers:
return self._handle_range_header(
self.lock.acquire()
r = self._handle_range_header(
request, response_headers, response_content
)
self.lock.release()
return r
return status_code, response_headers, response_content
def _control_response(self, request, full_url, headers):
print("_control_response")
parsed_url = urlparse(full_url)
query = parse_qs(parsed_url.query, keep_blank_values=True)
method = request.method
@ -1058,12 +1068,14 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
def _key_response_get(self, bucket_name, query, key_name, headers):
print("_key_response_get("+str(key_name)+","+str(headers)+")")
self._set_action("KEY", "GET", query)
self._authenticate_and_authorize_s3_action()
response_headers = {}
if query.get("uploadId"):
upload_id = query["uploadId"][0]
print("UploadID: " + str(upload_id))
parts = self.backend.list_multipart(bucket_name, upload_id)
template = self.response_template(S3_MULTIPART_LIST_RESPONSE)
return (
@ -1095,6 +1107,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
response_headers.update(key.metadata)
response_headers.update(key.response_dict)
print("returning 200, " + str(headers) + ", " + str(len(key.value)) + " ( " + str(key_name) + ")")
return 200, response_headers, key.value
def _key_response_put(self, request, body, bucket_name, query, key_name, headers):

View file

@ -104,7 +104,9 @@ class _VersionedKeyStore(dict):
def get(self, key, default=None):
try:
return self[key]
except (KeyError, IndexError):
except (KeyError, IndexError) as e:
print("Error retrieving " + str(key))
print(e)
pass
return default