Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
commit 96e5b1993d
507 changed files with 52541 additions and 47814 deletions

View file

@ -12,7 +12,12 @@ import boto.sqs
from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis
from moto.core.utils import (
camelcase_to_underscores,
get_random_message_id,
unix_time,
unix_time_millis,
)
from .utils import generate_receipt_handle
from .exceptions import (
MessageAttributesInvalid,
@ -24,7 +29,7 @@ from .exceptions import (
BatchRequestTooLong,
BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest,
InvalidAttributeName
InvalidAttributeName,
)
DEFAULT_ACCOUNT_ID = 123456789012
@ -32,11 +37,10 @@ DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'}
TRANSPORT_TYPE_ENCODINGS = {"String": b"\x01", "Binary": b"\x02", "Number": b"\x01"}
class Message(BaseModel):
def __init__(self, message_id, body):
self.id = message_id
self._body = body
@ -54,7 +58,7 @@ class Message(BaseModel):
@property
def body_md5(self):
md5 = hashlib.md5()
md5.update(self._body.encode('utf-8'))
md5.update(self._body.encode("utf-8"))
return md5.hexdigest()
@property
@ -68,17 +72,19 @@ class Message(BaseModel):
Not yet implemented:
List types (https://github.com/aws/aws-sdk-java/blob/7844c64cf248aed889811bf2e871ad6b276a89ca/aws-java-sdk-sqs/src/main/java/com/amazonaws/services/sqs/MessageMD5ChecksumHandler.java#L58k)
"""
def utf8(str):
if isinstance(str, six.string_types):
return str.encode('utf-8')
return str.encode("utf-8")
return str
md5 = hashlib.md5()
struct_format = "!I".encode('ascii') # ensure it's a bytestring
struct_format = "!I".encode("ascii") # ensure it's a bytestring
for name in sorted(self.message_attributes.keys()):
attr = self.message_attributes[name]
data_type = attr['data_type']
data_type = attr["data_type"]
encoded = utf8('')
encoded = utf8("")
# Each part of each attribute is encoded right after it's
# own length is packed into a 4-byte integer
# 'timestamp' -> b'\x00\x00\x00\t'
@ -88,18 +94,22 @@ class Message(BaseModel):
encoded += struct.pack(struct_format, len(data_type)) + utf8(data_type)
encoded += TRANSPORT_TYPE_ENCODINGS[data_type]
if data_type == 'String' or data_type == 'Number':
value = attr['string_value']
elif data_type == 'Binary':
print(data_type, attr['binary_value'], type(attr['binary_value']))
value = base64.b64decode(attr['binary_value'])
if data_type == "String" or data_type == "Number":
value = attr["string_value"]
elif data_type == "Binary":
print(data_type, attr["binary_value"], type(attr["binary_value"]))
value = base64.b64decode(attr["binary_value"])
else:
print("Moto hasn't implemented MD5 hashing for {} attributes".format(data_type))
print(
"Moto hasn't implemented MD5 hashing for {} attributes".format(
data_type
)
)
# The following should be enough of a clue to users that
# they are not, in fact, looking at a correct MD5 while
# also following the character and length constraints of
# MD5 so as not to break client softwre
return('deadbeefdeadbeefdeadbeefdeadbeef')
return "deadbeefdeadbeefdeadbeefdeadbeef"
encoded += struct.pack(struct_format, len(utf8(value))) + utf8(value)
@ -162,24 +172,30 @@ class Message(BaseModel):
class Queue(BaseModel):
BASE_ATTRIBUTES = ['ApproximateNumberOfMessages',
'ApproximateNumberOfMessagesDelayed',
'ApproximateNumberOfMessagesNotVisible',
'CreatedTimestamp',
'DelaySeconds',
'LastModifiedTimestamp',
'MaximumMessageSize',
'MessageRetentionPeriod',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'VisibilityTimeout']
FIFO_ATTRIBUTES = ['FifoQueue',
'ContentBasedDeduplication']
KMS_ATTRIBUTES = ['KmsDataKeyReusePeriodSeconds',
'KmsMasterKeyId']
ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage',
'GetQueueAttributes', 'GetQueueUrl',
'ReceiveMessage', 'SendMessage')
BASE_ATTRIBUTES = [
"ApproximateNumberOfMessages",
"ApproximateNumberOfMessagesDelayed",
"ApproximateNumberOfMessagesNotVisible",
"CreatedTimestamp",
"DelaySeconds",
"LastModifiedTimestamp",
"MaximumMessageSize",
"MessageRetentionPeriod",
"QueueArn",
"ReceiveMessageWaitTimeSeconds",
"VisibilityTimeout",
]
FIFO_ATTRIBUTES = ["FifoQueue", "ContentBasedDeduplication"]
KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"]
ALLOWED_PERMISSIONS = (
"*",
"ChangeMessageVisibility",
"DeleteMessage",
"GetQueueAttributes",
"GetQueueUrl",
"ReceiveMessage",
"SendMessage",
)
def __init__(self, name, region, **kwargs):
self.name = name
@ -192,34 +208,36 @@ class Queue(BaseModel):
now = unix_time()
self.created_timestamp = now
self.queue_arn = 'arn:aws:sqs:{0}:{1}:{2}'.format(self.region,
DEFAULT_ACCOUNT_ID,
self.name)
self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format(
self.region, DEFAULT_ACCOUNT_ID, self.name
)
self.dead_letter_queue = None
self.lambda_event_source_mappings = {}
# default settings for a non fifo queue
defaults = {
'ContentBasedDeduplication': 'false',
'DelaySeconds': 0,
'FifoQueue': 'false',
'KmsDataKeyReusePeriodSeconds': 300, # five minutes
'KmsMasterKeyId': None,
'MaximumMessageSize': int(64 << 10),
'MessageRetentionPeriod': 86400 * 4, # four days
'Policy': None,
'ReceiveMessageWaitTimeSeconds': 0,
'RedrivePolicy': None,
'VisibilityTimeout': 30,
"ContentBasedDeduplication": "false",
"DelaySeconds": 0,
"FifoQueue": "false",
"KmsDataKeyReusePeriodSeconds": 300, # five minutes
"KmsMasterKeyId": None,
"MaximumMessageSize": int(64 << 10),
"MessageRetentionPeriod": 86400 * 4, # four days
"Policy": None,
"ReceiveMessageWaitTimeSeconds": 0,
"RedrivePolicy": None,
"VisibilityTimeout": 30,
}
defaults.update(kwargs)
self._set_attributes(defaults, now)
# Check some conditions
if self.fifo_queue and not self.name.endswith('.fifo'):
raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues')
if self.fifo_queue and not self.name.endswith(".fifo"):
raise MessageAttributesInvalid(
"Queue name must end in .fifo for FIFO queues"
)
@property
def pending_messages(self):
@ -227,18 +245,25 @@ class Queue(BaseModel):
@property
def pending_message_groups(self):
return set(message.group_id
for message in self._pending_messages
if message.group_id is not None)
return set(
message.group_id
for message in self._pending_messages
if message.group_id is not None
)
def _set_attributes(self, attributes, now=None):
if not now:
now = unix_time()
integer_fields = ('DelaySeconds', 'KmsDataKeyreusePeriodSeconds',
'MaximumMessageSize', 'MessageRetentionPeriod',
'ReceiveMessageWaitTime', 'VisibilityTimeout')
bool_fields = ('ContentBasedDeduplication', 'FifoQueue')
integer_fields = (
"DelaySeconds",
"KmsDataKeyreusePeriodSeconds",
"MaximumMessageSize",
"MessageRetentionPeriod",
"ReceiveMessageWaitTime",
"VisibilityTimeout",
)
bool_fields = ("ContentBasedDeduplication", "FifoQueue")
for key, value in six.iteritems(attributes):
if key in integer_fields:
@ -246,13 +271,13 @@ class Queue(BaseModel):
if key in bool_fields:
value = value == "true"
if key == 'RedrivePolicy' and value is not None:
if key == "RedrivePolicy" and value is not None:
continue
setattr(self, camelcase_to_underscores(key), value)
if attributes.get('RedrivePolicy', None):
self._setup_dlq(attributes['RedrivePolicy'])
if attributes.get("RedrivePolicy", None):
self._setup_dlq(attributes["RedrivePolicy"])
self.last_modified_timestamp = now
@ -262,59 +287,86 @@ class Queue(BaseModel):
try:
self.redrive_policy = json.loads(policy)
except ValueError:
raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json')
raise RESTError(
"InvalidParameterValue",
"Redrive policy is not a dict or valid json",
)
elif isinstance(policy, dict):
self.redrive_policy = policy
else:
raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json')
raise RESTError(
"InvalidParameterValue", "Redrive policy is not a dict or valid json"
)
if 'deadLetterTargetArn' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain deadLetterTargetArn')
if 'maxReceiveCount' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount')
if "deadLetterTargetArn" not in self.redrive_policy:
raise RESTError(
"InvalidParameterValue",
"Redrive policy does not contain deadLetterTargetArn",
)
if "maxReceiveCount" not in self.redrive_policy:
raise RESTError(
"InvalidParameterValue",
"Redrive policy does not contain maxReceiveCount",
)
# 'maxReceiveCount' is stored as int
self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount'])
self.redrive_policy["maxReceiveCount"] = int(
self.redrive_policy["maxReceiveCount"]
)
for queue in sqs_backends[self.region].queues.values():
if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']:
if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]:
self.dead_letter_queue = queue
if self.fifo_queue and not queue.fifo_queue:
raise RESTError('InvalidParameterCombination', 'Fifo queues cannot use non fifo dead letter queues')
raise RESTError(
"InvalidParameterCombination",
"Fifo queues cannot use non fifo dead letter queues",
)
break
else:
raise RESTError('AWS.SimpleQueueService.NonExistentQueue', 'Could not find DLQ for {0}'.format(self.redrive_policy['deadLetterTargetArn']))
raise RESTError(
"AWS.SimpleQueueService.NonExistentQueue",
"Could not find DLQ for {0}".format(
self.redrive_policy["deadLetterTargetArn"]
),
)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
sqs_backend = sqs_backends[region_name]
return sqs_backend.create_queue(
name=properties['QueueName'],
region=region_name,
**properties
name=properties["QueueName"], region=region_name, **properties
)
@classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
queue_name = properties['QueueName']
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
queue_name = properties["QueueName"]
sqs_backend = sqs_backends[region_name]
queue = sqs_backend.get_queue(queue_name)
if 'VisibilityTimeout' in properties:
queue.visibility_timeout = int(properties['VisibilityTimeout'])
if "VisibilityTimeout" in properties:
queue.visibility_timeout = int(properties["VisibilityTimeout"])
if 'ReceiveMessageWaitTimeSeconds' in properties:
queue.receive_message_wait_time_seconds = int(properties['ReceiveMessageWaitTimeSeconds'])
if "ReceiveMessageWaitTimeSeconds" in properties:
queue.receive_message_wait_time_seconds = int(
properties["ReceiveMessageWaitTimeSeconds"]
)
return queue
@classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
queue_name = properties['QueueName']
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
queue_name = properties["QueueName"]
sqs_backend = sqs_backends[region_name]
sqs_backend.delete_queue(queue_name)
@ -353,10 +405,10 @@ class Queue(BaseModel):
result[attribute] = attr
if self.policy:
result['Policy'] = self.policy
result["Policy"] = self.policy
if self.redrive_policy:
result['RedrivePolicy'] = json.dumps(self.redrive_policy)
result["RedrivePolicy"] = json.dumps(self.redrive_policy)
for key in result:
if isinstance(result[key], bool):
@ -365,15 +417,22 @@ class Queue(BaseModel):
return result
def url(self, request_url):
return "{0}://{1}/123456789012/{2}".format(request_url.scheme, request_url.netloc, self.name)
return "{0}://{1}/123456789012/{2}".format(
request_url.scheme, request_url.netloc, self.name
)
@property
def messages(self):
return [message for message in self._messages if message.visible and not message.delayed]
return [
message
for message in self._messages
if message.visible and not message.delayed
]
def add_message(self, message):
self._messages.append(message)
from moto.awslambda import lambda_backends
for arn, esm in self.lambda_event_source_mappings.items():
backend = sqs_backends[self.region]
@ -391,27 +450,28 @@ class Queue(BaseModel):
)
result = lambda_backends[self.region].send_sqs_batch(
arn,
messages,
self.queue_arn,
arn, messages, self.queue_arn
)
if result:
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
else:
[backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages]
[
backend.change_message_visibility(self.name, m.receipt_handle, 0)
for m in messages
]
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn':
if attribute_name == "Arn":
return self.queue_arn
elif attribute_name == 'QueueName':
elif attribute_name == "QueueName":
return self.name
raise UnformattedGetAttTemplateException()
class SQSBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
self.queues = {}
@ -427,7 +487,7 @@ class SQSBackend(BaseBackend):
queue = self.queues.get(name)
if queue:
try:
kwargs.pop('region')
kwargs.pop("region")
except KeyError:
pass
@ -436,28 +496,26 @@ class SQSBackend(BaseBackend):
queue_attributes = queue.attributes
new_queue_attributes = new_queue.attributes
static_attributes = (
'DelaySeconds',
'MaximumMessageSize',
'MessageRetentionPeriod',
'Policy',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'RedrivePolicy',
'VisibilityTimeout',
'KmsMasterKeyId',
'KmsDataKeyReusePeriodSeconds',
'FifoQueue',
'ContentBasedDeduplication',
"DelaySeconds",
"MaximumMessageSize",
"MessageRetentionPeriod",
"Policy",
"QueueArn",
"ReceiveMessageWaitTimeSeconds",
"RedrivePolicy",
"VisibilityTimeout",
"KmsMasterKeyId",
"KmsDataKeyReusePeriodSeconds",
"FifoQueue",
"ContentBasedDeduplication",
)
for key in static_attributes:
if queue_attributes.get(key) != new_queue_attributes.get(key):
raise QueueAlreadyExists(
"The specified queue already exists.",
)
raise QueueAlreadyExists("The specified queue already exists.")
else:
try:
kwargs.pop('region')
kwargs.pop("region")
except KeyError:
pass
queue = Queue(name, region=self.region_name, **kwargs)
@ -472,9 +530,9 @@ class SQSBackend(BaseBackend):
return self.get_queue(queue_name)
def list_queues(self, queue_name_prefix):
re_str = '.*'
re_str = ".*"
if queue_name_prefix:
re_str = '^{0}.*'.format(queue_name_prefix)
re_str = "^{0}.*".format(queue_name_prefix)
prefix_re = re.compile(re_str)
qs = []
for name, q in self.queues.items():
@ -497,17 +555,24 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(attribute_names):
attribute_names.append('All')
attribute_names.append("All")
valid_names = ['All'] + queue.BASE_ATTRIBUTES + queue.FIFO_ATTRIBUTES + queue.KMS_ATTRIBUTES
invalid_name = next((name for name in attribute_names if name not in valid_names), None)
valid_names = (
["All"]
+ queue.BASE_ATTRIBUTES
+ queue.FIFO_ATTRIBUTES
+ queue.KMS_ATTRIBUTES
)
invalid_name = next(
(name for name in attribute_names if name not in valid_names), None
)
if invalid_name or invalid_name == '':
if invalid_name or invalid_name == "":
raise InvalidAttributeName(invalid_name)
attributes = {}
if 'All' in attribute_names:
if "All" in attribute_names:
attributes = queue.attributes
else:
for name in (name for name in attribute_names if name in queue.attributes):
@ -520,7 +585,15 @@ class SQSBackend(BaseBackend):
queue._set_attributes(attributes)
return queue
def send_message(self, queue_name, message_body, message_attributes=None, delay_seconds=None, deduplication_id=None, group_id=None):
def send_message(
self,
queue_name,
message_body,
message_attributes=None,
delay_seconds=None,
deduplication_id=None,
group_id=None,
):
queue = self.get_queue(queue_name)
@ -541,9 +614,7 @@ class SQSBackend(BaseBackend):
if message_attributes:
message.message_attributes = message_attributes
message.mark_sent(
delay_seconds=delay_seconds
)
message.mark_sent(delay_seconds=delay_seconds)
queue.add_message(message)
@ -552,17 +623,25 @@ class SQSBackend(BaseBackend):
def send_message_batch(self, queue_name, entries):
self.get_queue(queue_name)
if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()):
if any(
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
):
raise InvalidBatchEntryId()
body_length = next(
(len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH),
False
(
len(entry["MessageBody"])
for entry in entries.values()
if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH
),
False,
)
if body_length:
raise BatchRequestTooLong(body_length)
duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()])
duplicate_id = self._get_first_duplicate_id(
[entry["Id"] for entry in entries.values()]
)
if duplicate_id:
raise BatchEntryIdsNotDistinct(duplicate_id)
@ -574,11 +653,11 @@ class SQSBackend(BaseBackend):
# Loop through looking for messages
message = self.send_message(
queue_name,
entry['MessageBody'],
message_attributes=entry['MessageAttributes'],
delay_seconds=entry['DelaySeconds']
entry["MessageBody"],
message_attributes=entry["MessageAttributes"],
delay_seconds=entry["DelaySeconds"],
)
message.user_id = entry['Id']
message.user_id = entry["Id"]
messages.append(message)
@ -592,7 +671,9 @@ class SQSBackend(BaseBackend):
unique_ids.add(id)
return None
def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout):
def receive_messages(
self, queue_name, count, wait_seconds_timeout, visibility_timeout
):
"""
Attempt to retrieve visible messages from a queue.
@ -638,13 +719,15 @@ class SQSBackend(BaseBackend):
queue.pending_messages.add(message)
if queue.dead_letter_queue is not None and message.approximate_receive_count >= queue.redrive_policy['maxReceiveCount']:
if (
queue.dead_letter_queue is not None
and message.approximate_receive_count
>= queue.redrive_policy["maxReceiveCount"]
):
messages_to_dlq.append(message)
continue
message.mark_received(
visibility_timeout=visibility_timeout
)
message.mark_received(visibility_timeout=visibility_timeout)
result.append(message)
if len(result) >= count:
break
@ -660,6 +743,7 @@ class SQSBackend(BaseBackend):
break
import time
time.sleep(0.01)
continue
@ -670,7 +754,9 @@ class SQSBackend(BaseBackend):
def delete_message(self, queue_name, receipt_handle):
queue = self.get_queue(queue_name)
if not any(message.receipt_handle == receipt_handle for message in queue._messages):
if not any(
message.receipt_handle == receipt_handle for message in queue._messages
):
raise ReceiptHandleIsInvalid()
new_messages = []
@ -715,12 +801,12 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if actions is None or len(actions) == 0:
raise RESTError('InvalidParameterValue', 'Need at least one Action')
raise RESTError("InvalidParameterValue", "Need at least one Action")
if account_ids is None or len(account_ids) == 0:
raise RESTError('InvalidParameterValue', 'Need at least one Account ID')
raise RESTError("InvalidParameterValue", "Need at least one Account ID")
if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]):
raise RESTError('InvalidParameterValue', 'Invalid permissions')
raise RESTError("InvalidParameterValue", "Invalid permissions")
queue.permissions[label] = (account_ids, actions)
@ -728,7 +814,9 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if label not in queue.permissions:
raise RESTError('InvalidParameterValue', 'Permission doesnt exist for the given label')
raise RESTError(
"InvalidParameterValue", "Permission doesnt exist for the given label"
)
del queue.permissions[label]
@ -736,12 +824,15 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(tags):
raise RESTError('MissingParameter',
'The request must contain the parameter Tags.')
raise RESTError(
"MissingParameter", "The request must contain the parameter Tags."
)
if len(tags) > 50:
raise RESTError('InvalidParameterValue',
'Too many tags added for queue {}.'.format(queue_name))
raise RESTError(
"InvalidParameterValue",
"Too many tags added for queue {}.".format(queue_name),
)
queue.tags.update(tags)
@ -749,7 +840,10 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(tag_keys):
raise RESTError('InvalidParameterValue', 'Tag keys must be between 1 and 128 characters in length.')
raise RESTError(
"InvalidParameterValue",
"Tag keys must be between 1 and 128 characters in length.",
)
for key in tag_keys:
try: