Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
commit 96e5b1993d
507 changed files with 52541 additions and 47814 deletions

View file

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import sqs_backends
from ..core.models import base_decorator, deprecated_base_decorator
sqs_backend = sqs_backends['us-east-1']
sqs_backend = sqs_backends["us-east-1"]
mock_sqs = base_decorator(sqs_backends)
mock_sqs_deprecated = deprecated_base_decorator(sqs_backends)

View file

@ -12,8 +12,7 @@ class ReceiptHandleIsInvalid(RESTError):
def __init__(self):
super(ReceiptHandleIsInvalid, self).__init__(
'ReceiptHandleIsInvalid',
'The input receipt handle is invalid.'
"ReceiptHandleIsInvalid", "The input receipt handle is invalid."
)
@ -29,15 +28,16 @@ class QueueDoesNotExist(RESTError):
def __init__(self):
super(QueueDoesNotExist, self).__init__(
"QueueDoesNotExist", "The specified queue does not exist for this wsdl version.")
"QueueDoesNotExist",
"The specified queue does not exist for this wsdl version.",
)
class QueueAlreadyExists(RESTError):
code = 400
def __init__(self, message):
super(QueueAlreadyExists, self).__init__(
"QueueAlreadyExists", message)
super(QueueAlreadyExists, self).__init__("QueueAlreadyExists", message)
class EmptyBatchRequest(RESTError):
@ -45,8 +45,8 @@ class EmptyBatchRequest(RESTError):
def __init__(self):
super(EmptyBatchRequest, self).__init__(
'EmptyBatchRequest',
'There should be at least one SendMessageBatchRequestEntry in the request.'
"EmptyBatchRequest",
"There should be at least one SendMessageBatchRequestEntry in the request.",
)
@ -55,9 +55,9 @@ class InvalidBatchEntryId(RESTError):
def __init__(self):
super(InvalidBatchEntryId, self).__init__(
'InvalidBatchEntryId',
'A batch entry id can only contain alphanumeric characters, '
'hyphens and underscores. It can be at most 80 letters long.'
"InvalidBatchEntryId",
"A batch entry id can only contain alphanumeric characters, "
"hyphens and underscores. It can be at most 80 letters long.",
)
@ -66,9 +66,9 @@ class BatchRequestTooLong(RESTError):
def __init__(self, length):
super(BatchRequestTooLong, self).__init__(
'BatchRequestTooLong',
'Batch requests cannot be longer than 262144 bytes. '
'You have sent {} bytes.'.format(length)
"BatchRequestTooLong",
"Batch requests cannot be longer than 262144 bytes. "
"You have sent {} bytes.".format(length),
)
@ -77,8 +77,7 @@ class BatchEntryIdsNotDistinct(RESTError):
def __init__(self, entry_id):
super(BatchEntryIdsNotDistinct, self).__init__(
'BatchEntryIdsNotDistinct',
'Id {} repeated.'.format(entry_id)
"BatchEntryIdsNotDistinct", "Id {} repeated.".format(entry_id)
)
@ -87,9 +86,9 @@ class TooManyEntriesInBatchRequest(RESTError):
def __init__(self, number):
super(TooManyEntriesInBatchRequest, self).__init__(
'TooManyEntriesInBatchRequest',
'Maximum number of entries per request are 10. '
'You have sent {}.'.format(number)
"TooManyEntriesInBatchRequest",
"Maximum number of entries per request are 10. "
"You have sent {}.".format(number),
)
@ -98,6 +97,5 @@ class InvalidAttributeName(RESTError):
def __init__(self, attribute_name):
super(InvalidAttributeName, self).__init__(
'InvalidAttributeName',
'Unknown Attribute {}.'.format(attribute_name)
"InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name)
)

View file

@ -12,7 +12,12 @@ import boto.sqs
from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis
from moto.core.utils import (
camelcase_to_underscores,
get_random_message_id,
unix_time,
unix_time_millis,
)
from .utils import generate_receipt_handle
from .exceptions import (
MessageAttributesInvalid,
@ -24,7 +29,7 @@ from .exceptions import (
BatchRequestTooLong,
BatchEntryIdsNotDistinct,
TooManyEntriesInBatchRequest,
InvalidAttributeName
InvalidAttributeName,
)
DEFAULT_ACCOUNT_ID = 123456789012
@ -32,11 +37,10 @@ DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"
MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB
TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'}
TRANSPORT_TYPE_ENCODINGS = {"String": b"\x01", "Binary": b"\x02", "Number": b"\x01"}
class Message(BaseModel):
def __init__(self, message_id, body):
self.id = message_id
self._body = body
@ -54,7 +58,7 @@ class Message(BaseModel):
@property
def body_md5(self):
md5 = hashlib.md5()
md5.update(self._body.encode('utf-8'))
md5.update(self._body.encode("utf-8"))
return md5.hexdigest()
@property
@ -68,17 +72,19 @@ class Message(BaseModel):
Not yet implemented:
List types (https://github.com/aws/aws-sdk-java/blob/7844c64cf248aed889811bf2e871ad6b276a89ca/aws-java-sdk-sqs/src/main/java/com/amazonaws/services/sqs/MessageMD5ChecksumHandler.java#L58k)
"""
def utf8(str):
if isinstance(str, six.string_types):
return str.encode('utf-8')
return str.encode("utf-8")
return str
md5 = hashlib.md5()
struct_format = "!I".encode('ascii') # ensure it's a bytestring
struct_format = "!I".encode("ascii") # ensure it's a bytestring
for name in sorted(self.message_attributes.keys()):
attr = self.message_attributes[name]
data_type = attr['data_type']
data_type = attr["data_type"]
encoded = utf8('')
encoded = utf8("")
# Each part of each attribute is encoded right after it's
# own length is packed into a 4-byte integer
# 'timestamp' -> b'\x00\x00\x00\t'
@ -88,18 +94,22 @@ class Message(BaseModel):
encoded += struct.pack(struct_format, len(data_type)) + utf8(data_type)
encoded += TRANSPORT_TYPE_ENCODINGS[data_type]
if data_type == 'String' or data_type == 'Number':
value = attr['string_value']
elif data_type == 'Binary':
print(data_type, attr['binary_value'], type(attr['binary_value']))
value = base64.b64decode(attr['binary_value'])
if data_type == "String" or data_type == "Number":
value = attr["string_value"]
elif data_type == "Binary":
print(data_type, attr["binary_value"], type(attr["binary_value"]))
value = base64.b64decode(attr["binary_value"])
else:
print("Moto hasn't implemented MD5 hashing for {} attributes".format(data_type))
print(
"Moto hasn't implemented MD5 hashing for {} attributes".format(
data_type
)
)
# The following should be enough of a clue to users that
# they are not, in fact, looking at a correct MD5 while
# also following the character and length constraints of
# MD5 so as not to break client softwre
return('deadbeefdeadbeefdeadbeefdeadbeef')
return "deadbeefdeadbeefdeadbeefdeadbeef"
encoded += struct.pack(struct_format, len(utf8(value))) + utf8(value)
@ -162,24 +172,30 @@ class Message(BaseModel):
class Queue(BaseModel):
BASE_ATTRIBUTES = ['ApproximateNumberOfMessages',
'ApproximateNumberOfMessagesDelayed',
'ApproximateNumberOfMessagesNotVisible',
'CreatedTimestamp',
'DelaySeconds',
'LastModifiedTimestamp',
'MaximumMessageSize',
'MessageRetentionPeriod',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'VisibilityTimeout']
FIFO_ATTRIBUTES = ['FifoQueue',
'ContentBasedDeduplication']
KMS_ATTRIBUTES = ['KmsDataKeyReusePeriodSeconds',
'KmsMasterKeyId']
ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage',
'GetQueueAttributes', 'GetQueueUrl',
'ReceiveMessage', 'SendMessage')
BASE_ATTRIBUTES = [
"ApproximateNumberOfMessages",
"ApproximateNumberOfMessagesDelayed",
"ApproximateNumberOfMessagesNotVisible",
"CreatedTimestamp",
"DelaySeconds",
"LastModifiedTimestamp",
"MaximumMessageSize",
"MessageRetentionPeriod",
"QueueArn",
"ReceiveMessageWaitTimeSeconds",
"VisibilityTimeout",
]
FIFO_ATTRIBUTES = ["FifoQueue", "ContentBasedDeduplication"]
KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"]
ALLOWED_PERMISSIONS = (
"*",
"ChangeMessageVisibility",
"DeleteMessage",
"GetQueueAttributes",
"GetQueueUrl",
"ReceiveMessage",
"SendMessage",
)
def __init__(self, name, region, **kwargs):
self.name = name
@ -192,34 +208,36 @@ class Queue(BaseModel):
now = unix_time()
self.created_timestamp = now
self.queue_arn = 'arn:aws:sqs:{0}:{1}:{2}'.format(self.region,
DEFAULT_ACCOUNT_ID,
self.name)
self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format(
self.region, DEFAULT_ACCOUNT_ID, self.name
)
self.dead_letter_queue = None
self.lambda_event_source_mappings = {}
# default settings for a non fifo queue
defaults = {
'ContentBasedDeduplication': 'false',
'DelaySeconds': 0,
'FifoQueue': 'false',
'KmsDataKeyReusePeriodSeconds': 300, # five minutes
'KmsMasterKeyId': None,
'MaximumMessageSize': int(64 << 10),
'MessageRetentionPeriod': 86400 * 4, # four days
'Policy': None,
'ReceiveMessageWaitTimeSeconds': 0,
'RedrivePolicy': None,
'VisibilityTimeout': 30,
"ContentBasedDeduplication": "false",
"DelaySeconds": 0,
"FifoQueue": "false",
"KmsDataKeyReusePeriodSeconds": 300, # five minutes
"KmsMasterKeyId": None,
"MaximumMessageSize": int(64 << 10),
"MessageRetentionPeriod": 86400 * 4, # four days
"Policy": None,
"ReceiveMessageWaitTimeSeconds": 0,
"RedrivePolicy": None,
"VisibilityTimeout": 30,
}
defaults.update(kwargs)
self._set_attributes(defaults, now)
# Check some conditions
if self.fifo_queue and not self.name.endswith('.fifo'):
raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues')
if self.fifo_queue and not self.name.endswith(".fifo"):
raise MessageAttributesInvalid(
"Queue name must end in .fifo for FIFO queues"
)
@property
def pending_messages(self):
@ -227,18 +245,25 @@ class Queue(BaseModel):
@property
def pending_message_groups(self):
return set(message.group_id
for message in self._pending_messages
if message.group_id is not None)
return set(
message.group_id
for message in self._pending_messages
if message.group_id is not None
)
def _set_attributes(self, attributes, now=None):
if not now:
now = unix_time()
integer_fields = ('DelaySeconds', 'KmsDataKeyreusePeriodSeconds',
'MaximumMessageSize', 'MessageRetentionPeriod',
'ReceiveMessageWaitTime', 'VisibilityTimeout')
bool_fields = ('ContentBasedDeduplication', 'FifoQueue')
integer_fields = (
"DelaySeconds",
"KmsDataKeyreusePeriodSeconds",
"MaximumMessageSize",
"MessageRetentionPeriod",
"ReceiveMessageWaitTime",
"VisibilityTimeout",
)
bool_fields = ("ContentBasedDeduplication", "FifoQueue")
for key, value in six.iteritems(attributes):
if key in integer_fields:
@ -246,13 +271,13 @@ class Queue(BaseModel):
if key in bool_fields:
value = value == "true"
if key == 'RedrivePolicy' and value is not None:
if key == "RedrivePolicy" and value is not None:
continue
setattr(self, camelcase_to_underscores(key), value)
if attributes.get('RedrivePolicy', None):
self._setup_dlq(attributes['RedrivePolicy'])
if attributes.get("RedrivePolicy", None):
self._setup_dlq(attributes["RedrivePolicy"])
self.last_modified_timestamp = now
@ -262,59 +287,86 @@ class Queue(BaseModel):
try:
self.redrive_policy = json.loads(policy)
except ValueError:
raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json')
raise RESTError(
"InvalidParameterValue",
"Redrive policy is not a dict or valid json",
)
elif isinstance(policy, dict):
self.redrive_policy = policy
else:
raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json')
raise RESTError(
"InvalidParameterValue", "Redrive policy is not a dict or valid json"
)
if 'deadLetterTargetArn' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain deadLetterTargetArn')
if 'maxReceiveCount' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount')
if "deadLetterTargetArn" not in self.redrive_policy:
raise RESTError(
"InvalidParameterValue",
"Redrive policy does not contain deadLetterTargetArn",
)
if "maxReceiveCount" not in self.redrive_policy:
raise RESTError(
"InvalidParameterValue",
"Redrive policy does not contain maxReceiveCount",
)
# 'maxReceiveCount' is stored as int
self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount'])
self.redrive_policy["maxReceiveCount"] = int(
self.redrive_policy["maxReceiveCount"]
)
for queue in sqs_backends[self.region].queues.values():
if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']:
if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]:
self.dead_letter_queue = queue
if self.fifo_queue and not queue.fifo_queue:
raise RESTError('InvalidParameterCombination', 'Fifo queues cannot use non fifo dead letter queues')
raise RESTError(
"InvalidParameterCombination",
"Fifo queues cannot use non fifo dead letter queues",
)
break
else:
raise RESTError('AWS.SimpleQueueService.NonExistentQueue', 'Could not find DLQ for {0}'.format(self.redrive_policy['deadLetterTargetArn']))
raise RESTError(
"AWS.SimpleQueueService.NonExistentQueue",
"Could not find DLQ for {0}".format(
self.redrive_policy["deadLetterTargetArn"]
),
)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
sqs_backend = sqs_backends[region_name]
return sqs_backend.create_queue(
name=properties['QueueName'],
region=region_name,
**properties
name=properties["QueueName"], region=region_name, **properties
)
@classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
queue_name = properties['QueueName']
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
queue_name = properties["QueueName"]
sqs_backend = sqs_backends[region_name]
queue = sqs_backend.get_queue(queue_name)
if 'VisibilityTimeout' in properties:
queue.visibility_timeout = int(properties['VisibilityTimeout'])
if "VisibilityTimeout" in properties:
queue.visibility_timeout = int(properties["VisibilityTimeout"])
if 'ReceiveMessageWaitTimeSeconds' in properties:
queue.receive_message_wait_time_seconds = int(properties['ReceiveMessageWaitTimeSeconds'])
if "ReceiveMessageWaitTimeSeconds" in properties:
queue.receive_message_wait_time_seconds = int(
properties["ReceiveMessageWaitTimeSeconds"]
)
return queue
@classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
queue_name = properties['QueueName']
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
queue_name = properties["QueueName"]
sqs_backend = sqs_backends[region_name]
sqs_backend.delete_queue(queue_name)
@ -353,10 +405,10 @@ class Queue(BaseModel):
result[attribute] = attr
if self.policy:
result['Policy'] = self.policy
result["Policy"] = self.policy
if self.redrive_policy:
result['RedrivePolicy'] = json.dumps(self.redrive_policy)
result["RedrivePolicy"] = json.dumps(self.redrive_policy)
for key in result:
if isinstance(result[key], bool):
@ -365,15 +417,22 @@ class Queue(BaseModel):
return result
def url(self, request_url):
return "{0}://{1}/123456789012/{2}".format(request_url.scheme, request_url.netloc, self.name)
return "{0}://{1}/123456789012/{2}".format(
request_url.scheme, request_url.netloc, self.name
)
@property
def messages(self):
return [message for message in self._messages if message.visible and not message.delayed]
return [
message
for message in self._messages
if message.visible and not message.delayed
]
def add_message(self, message):
self._messages.append(message)
from moto.awslambda import lambda_backends
for arn, esm in self.lambda_event_source_mappings.items():
backend = sqs_backends[self.region]
@ -391,27 +450,28 @@ class Queue(BaseModel):
)
result = lambda_backends[self.region].send_sqs_batch(
arn,
messages,
self.queue_arn,
arn, messages, self.queue_arn
)
if result:
[backend.delete_message(self.name, m.receipt_handle) for m in messages]
else:
[backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages]
[
backend.change_message_visibility(self.name, m.receipt_handle, 0)
for m in messages
]
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn':
if attribute_name == "Arn":
return self.queue_arn
elif attribute_name == 'QueueName':
elif attribute_name == "QueueName":
return self.name
raise UnformattedGetAttTemplateException()
class SQSBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
self.queues = {}
@ -427,7 +487,7 @@ class SQSBackend(BaseBackend):
queue = self.queues.get(name)
if queue:
try:
kwargs.pop('region')
kwargs.pop("region")
except KeyError:
pass
@ -436,28 +496,26 @@ class SQSBackend(BaseBackend):
queue_attributes = queue.attributes
new_queue_attributes = new_queue.attributes
static_attributes = (
'DelaySeconds',
'MaximumMessageSize',
'MessageRetentionPeriod',
'Policy',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'RedrivePolicy',
'VisibilityTimeout',
'KmsMasterKeyId',
'KmsDataKeyReusePeriodSeconds',
'FifoQueue',
'ContentBasedDeduplication',
"DelaySeconds",
"MaximumMessageSize",
"MessageRetentionPeriod",
"Policy",
"QueueArn",
"ReceiveMessageWaitTimeSeconds",
"RedrivePolicy",
"VisibilityTimeout",
"KmsMasterKeyId",
"KmsDataKeyReusePeriodSeconds",
"FifoQueue",
"ContentBasedDeduplication",
)
for key in static_attributes:
if queue_attributes.get(key) != new_queue_attributes.get(key):
raise QueueAlreadyExists(
"The specified queue already exists.",
)
raise QueueAlreadyExists("The specified queue already exists.")
else:
try:
kwargs.pop('region')
kwargs.pop("region")
except KeyError:
pass
queue = Queue(name, region=self.region_name, **kwargs)
@ -472,9 +530,9 @@ class SQSBackend(BaseBackend):
return self.get_queue(queue_name)
def list_queues(self, queue_name_prefix):
re_str = '.*'
re_str = ".*"
if queue_name_prefix:
re_str = '^{0}.*'.format(queue_name_prefix)
re_str = "^{0}.*".format(queue_name_prefix)
prefix_re = re.compile(re_str)
qs = []
for name, q in self.queues.items():
@ -497,17 +555,24 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(attribute_names):
attribute_names.append('All')
attribute_names.append("All")
valid_names = ['All'] + queue.BASE_ATTRIBUTES + queue.FIFO_ATTRIBUTES + queue.KMS_ATTRIBUTES
invalid_name = next((name for name in attribute_names if name not in valid_names), None)
valid_names = (
["All"]
+ queue.BASE_ATTRIBUTES
+ queue.FIFO_ATTRIBUTES
+ queue.KMS_ATTRIBUTES
)
invalid_name = next(
(name for name in attribute_names if name not in valid_names), None
)
if invalid_name or invalid_name == '':
if invalid_name or invalid_name == "":
raise InvalidAttributeName(invalid_name)
attributes = {}
if 'All' in attribute_names:
if "All" in attribute_names:
attributes = queue.attributes
else:
for name in (name for name in attribute_names if name in queue.attributes):
@ -520,7 +585,15 @@ class SQSBackend(BaseBackend):
queue._set_attributes(attributes)
return queue
def send_message(self, queue_name, message_body, message_attributes=None, delay_seconds=None, deduplication_id=None, group_id=None):
def send_message(
self,
queue_name,
message_body,
message_attributes=None,
delay_seconds=None,
deduplication_id=None,
group_id=None,
):
queue = self.get_queue(queue_name)
@ -541,9 +614,7 @@ class SQSBackend(BaseBackend):
if message_attributes:
message.message_attributes = message_attributes
message.mark_sent(
delay_seconds=delay_seconds
)
message.mark_sent(delay_seconds=delay_seconds)
queue.add_message(message)
@ -552,17 +623,25 @@ class SQSBackend(BaseBackend):
def send_message_batch(self, queue_name, entries):
self.get_queue(queue_name)
if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()):
if any(
not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
):
raise InvalidBatchEntryId()
body_length = next(
(len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH),
False
(
len(entry["MessageBody"])
for entry in entries.values()
if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH
),
False,
)
if body_length:
raise BatchRequestTooLong(body_length)
duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()])
duplicate_id = self._get_first_duplicate_id(
[entry["Id"] for entry in entries.values()]
)
if duplicate_id:
raise BatchEntryIdsNotDistinct(duplicate_id)
@ -574,11 +653,11 @@ class SQSBackend(BaseBackend):
# Loop through looking for messages
message = self.send_message(
queue_name,
entry['MessageBody'],
message_attributes=entry['MessageAttributes'],
delay_seconds=entry['DelaySeconds']
entry["MessageBody"],
message_attributes=entry["MessageAttributes"],
delay_seconds=entry["DelaySeconds"],
)
message.user_id = entry['Id']
message.user_id = entry["Id"]
messages.append(message)
@ -592,7 +671,9 @@ class SQSBackend(BaseBackend):
unique_ids.add(id)
return None
def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout):
def receive_messages(
self, queue_name, count, wait_seconds_timeout, visibility_timeout
):
"""
Attempt to retrieve visible messages from a queue.
@ -638,13 +719,15 @@ class SQSBackend(BaseBackend):
queue.pending_messages.add(message)
if queue.dead_letter_queue is not None and message.approximate_receive_count >= queue.redrive_policy['maxReceiveCount']:
if (
queue.dead_letter_queue is not None
and message.approximate_receive_count
>= queue.redrive_policy["maxReceiveCount"]
):
messages_to_dlq.append(message)
continue
message.mark_received(
visibility_timeout=visibility_timeout
)
message.mark_received(visibility_timeout=visibility_timeout)
result.append(message)
if len(result) >= count:
break
@ -660,6 +743,7 @@ class SQSBackend(BaseBackend):
break
import time
time.sleep(0.01)
continue
@ -670,7 +754,9 @@ class SQSBackend(BaseBackend):
def delete_message(self, queue_name, receipt_handle):
queue = self.get_queue(queue_name)
if not any(message.receipt_handle == receipt_handle for message in queue._messages):
if not any(
message.receipt_handle == receipt_handle for message in queue._messages
):
raise ReceiptHandleIsInvalid()
new_messages = []
@ -715,12 +801,12 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if actions is None or len(actions) == 0:
raise RESTError('InvalidParameterValue', 'Need at least one Action')
raise RESTError("InvalidParameterValue", "Need at least one Action")
if account_ids is None or len(account_ids) == 0:
raise RESTError('InvalidParameterValue', 'Need at least one Account ID')
raise RESTError("InvalidParameterValue", "Need at least one Account ID")
if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]):
raise RESTError('InvalidParameterValue', 'Invalid permissions')
raise RESTError("InvalidParameterValue", "Invalid permissions")
queue.permissions[label] = (account_ids, actions)
@ -728,7 +814,9 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if label not in queue.permissions:
raise RESTError('InvalidParameterValue', 'Permission doesnt exist for the given label')
raise RESTError(
"InvalidParameterValue", "Permission doesnt exist for the given label"
)
del queue.permissions[label]
@ -736,12 +824,15 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(tags):
raise RESTError('MissingParameter',
'The request must contain the parameter Tags.')
raise RESTError(
"MissingParameter", "The request must contain the parameter Tags."
)
if len(tags) > 50:
raise RESTError('InvalidParameterValue',
'Too many tags added for queue {}.'.format(queue_name))
raise RESTError(
"InvalidParameterValue",
"Too many tags added for queue {}.".format(queue_name),
)
queue.tags.update(tags)
@ -749,7 +840,10 @@ class SQSBackend(BaseBackend):
queue = self.get_queue(queue_name)
if not len(tag_keys):
raise RESTError('InvalidParameterValue', 'Tag keys must be between 1 and 128 characters in length.')
raise RESTError(
"InvalidParameterValue",
"Tag keys must be between 1 and 128 characters in length.",
)
for key in tag_keys:
try:

View file

@ -12,7 +12,7 @@ from .exceptions import (
MessageNotInflight,
ReceiptHandleIsInvalid,
EmptyBatchRequest,
InvalidAttributeName
InvalidAttributeName,
)
MAXIMUM_VISIBILTY_TIMEOUT = 43200
@ -22,7 +22,7 @@ DEFAULT_RECEIVED_MESSAGES = 1
class SQSResponse(BaseResponse):
region_regex = re.compile(r'://(.+?)\.queue\.amazonaws\.com')
region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com")
@property
def sqs_backend(self):
@ -30,19 +30,21 @@ class SQSResponse(BaseResponse):
@property
def attribute(self):
if not hasattr(self, '_attribute'):
self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value')
if not hasattr(self, "_attribute"):
self._attribute = self._get_map_prefix(
"Attribute", key_end=".Name", value_end=".Value"
)
return self._attribute
@property
def tags(self):
if not hasattr(self, '_tags'):
self._tags = self._get_map_prefix('Tag', key_end='.Key', value_end='.Value')
if not hasattr(self, "_tags"):
self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
return self._tags
def _get_queue_name(self):
try:
queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1]
queue_name = self.querystring.get("QueueUrl")[0].split("/")[-1]
except TypeError:
# Fallback to reading from the URL
queue_name = self.path.split("/")[-1]
@ -80,9 +82,11 @@ class SQSResponse(BaseResponse):
queue_name = self._get_param("QueueName")
try:
queue = self.sqs_backend.create_queue(queue_name, self.tags, **self.attribute)
queue = self.sqs_backend.create_queue(
queue_name, self.tags, **self.attribute
)
except MessageAttributesInvalid as e:
return self._error('InvalidParameterValue', e.description)
return self._error("InvalidParameterValue", e.description)
template = self.response_template(CREATE_QUEUE_RESPONSE)
return template.render(queue_url=queue.url(request_url))
@ -98,14 +102,14 @@ class SQSResponse(BaseResponse):
def list_queues(self):
request_url = urlparse(self.uri)
queue_name_prefix = self._get_param('QueueNamePrefix')
queue_name_prefix = self._get_param("QueueNamePrefix")
queues = self.sqs_backend.list_queues(queue_name_prefix)
template = self.response_template(LIST_QUEUES_RESPONSE)
return template.render(queues=queues, request_url=request_url)
def change_message_visibility(self):
queue_name = self._get_queue_name()
receipt_handle = self._get_param('ReceiptHandle')
receipt_handle = self._get_param("ReceiptHandle")
try:
visibility_timeout = self._get_validated_visibility_timeout()
@ -116,53 +120,64 @@ class SQSResponse(BaseResponse):
self.sqs_backend.change_message_visibility(
queue_name=queue_name,
receipt_handle=receipt_handle,
visibility_timeout=visibility_timeout
visibility_timeout=visibility_timeout,
)
except MessageNotInflight as e:
return "Invalid request: {0}".format(e.description), dict(status=e.status_code)
return (
"Invalid request: {0}".format(e.description),
dict(status=e.status_code),
)
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE)
return template.render()
def change_message_visibility_batch(self):
queue_name = self._get_queue_name()
entries = self._get_list_prefix('ChangeMessageVisibilityBatchRequestEntry')
entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry")
success = []
error = []
for entry in entries:
try:
visibility_timeout = self._get_validated_visibility_timeout(entry['visibility_timeout'])
visibility_timeout = self._get_validated_visibility_timeout(
entry["visibility_timeout"]
)
except ValueError:
error.append({
'Id': entry['id'],
'SenderFault': 'true',
'Code': 'InvalidParameterValue',
'Message': 'Visibility timeout invalid'
})
error.append(
{
"Id": entry["id"],
"SenderFault": "true",
"Code": "InvalidParameterValue",
"Message": "Visibility timeout invalid",
}
)
continue
try:
self.sqs_backend.change_message_visibility(
queue_name=queue_name,
receipt_handle=entry['receipt_handle'],
visibility_timeout=visibility_timeout
receipt_handle=entry["receipt_handle"],
visibility_timeout=visibility_timeout,
)
success.append(entry['id'])
success.append(entry["id"])
except ReceiptHandleIsInvalid as e:
error.append({
'Id': entry['id'],
'SenderFault': 'true',
'Code': 'ReceiptHandleIsInvalid',
'Message': e.description
})
error.append(
{
"Id": entry["id"],
"SenderFault": "true",
"Code": "ReceiptHandleIsInvalid",
"Message": e.description,
}
)
except MessageNotInflight as e:
error.append({
'Id': entry['id'],
'SenderFault': 'false',
'Code': 'AWS.SimpleQueueService.MessageNotInflight',
'Message': e.description
})
error.append(
{
"Id": entry["id"],
"SenderFault": "false",
"Code": "AWS.SimpleQueueService.MessageNotInflight",
"Message": e.description,
}
)
template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE)
return template.render(success=success, errors=error)
@ -170,10 +185,10 @@ class SQSResponse(BaseResponse):
def get_queue_attributes(self):
queue_name = self._get_queue_name()
if self.querystring.get('AttributeNames'):
raise InvalidAttributeName('')
if self.querystring.get("AttributeNames"):
raise InvalidAttributeName("")
attribute_names = self._get_multi_param('AttributeName')
attribute_names = self._get_multi_param("AttributeName")
attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names)
@ -192,14 +207,17 @@ class SQSResponse(BaseResponse):
queue_name = self._get_queue_name()
queue = self.sqs_backend.delete_queue(queue_name)
if not queue:
return "A queue with name {0} does not exist".format(queue_name), dict(status=404)
return (
"A queue with name {0} does not exist".format(queue_name),
dict(status=404),
)
template = self.response_template(DELETE_QUEUE_RESPONSE)
return template.render(queue=queue)
def send_message(self):
message = self._get_param('MessageBody')
delay_seconds = int(self._get_param('DelaySeconds', 0))
message = self._get_param("MessageBody")
delay_seconds = int(self._get_param("DelaySeconds", 0))
message_group_id = self._get_param("MessageGroupId")
message_dedupe_id = self._get_param("MessageDeduplicationId")
@ -219,7 +237,7 @@ class SQSResponse(BaseResponse):
message_attributes=message_attributes,
delay_seconds=delay_seconds,
deduplication_id=message_dedupe_id,
group_id=message_group_id
group_id=message_group_id,
)
template = self.response_template(SEND_MESSAGE_RESPONSE)
return template.render(message=message, message_attributes=message_attributes)
@ -240,25 +258,30 @@ class SQSResponse(BaseResponse):
self.sqs_backend.get_queue(queue_name)
if self.querystring.get('Entries'):
if self.querystring.get("Entries"):
raise EmptyBatchRequest()
entries = {}
for key, value in self.querystring.items():
match = re.match(r'^SendMessageBatchRequestEntry\.(\d+)\.Id', key)
match = re.match(r"^SendMessageBatchRequestEntry\.(\d+)\.Id", key)
if match:
index = match.group(1)
message_attributes = parse_message_attributes(
self.querystring, base='SendMessageBatchRequestEntry.{}.'.format(index))
self.querystring,
base="SendMessageBatchRequestEntry.{}.".format(index),
)
entries[index] = {
'Id': value[0],
'MessageBody': self.querystring.get(
'SendMessageBatchRequestEntry.{}.MessageBody'.format(index))[0],
'DelaySeconds': self.querystring.get(
'SendMessageBatchRequestEntry.{}.DelaySeconds'.format(index), [None])[0],
'MessageAttributes': message_attributes
"Id": value[0],
"MessageBody": self.querystring.get(
"SendMessageBatchRequestEntry.{}.MessageBody".format(index)
)[0],
"DelaySeconds": self.querystring.get(
"SendMessageBatchRequestEntry.{}.DelaySeconds".format(index),
[None],
)[0],
"MessageAttributes": message_attributes,
}
messages = self.sqs_backend.send_message_batch(queue_name, entries)
@ -288,8 +311,9 @@ class SQSResponse(BaseResponse):
message_ids = []
for index in range(1, 11):
# Loop through looking for messages
receipt_key = 'DeleteMessageBatchRequestEntry.{0}.ReceiptHandle'.format(
index)
receipt_key = "DeleteMessageBatchRequestEntry.{0}.ReceiptHandle".format(
index
)
receipt_handle = self.querystring.get(receipt_key)
if not receipt_handle:
# Found all messages
@ -297,8 +321,7 @@ class SQSResponse(BaseResponse):
self.sqs_backend.delete_message(queue_name, receipt_handle[0])
message_user_id_key = 'DeleteMessageBatchRequestEntry.{0}.Id'.format(
index)
message_user_id_key = "DeleteMessageBatchRequestEntry.{0}.Id".format(index)
message_user_id = self.querystring.get(message_user_id_key)[0]
message_ids.append(message_user_id)
@ -327,7 +350,8 @@ class SQSResponse(BaseResponse):
"An error occurred (InvalidParameterValue) when calling "
"the ReceiveMessage operation: Value %s for parameter "
"MaxNumberOfMessages is invalid. Reason: must be between "
"1 and 10, if provided." % message_count)
"1 and 10, if provided." % message_count,
)
try:
wait_time = int(self.querystring.get("WaitTimeSeconds")[0])
@ -340,7 +364,8 @@ class SQSResponse(BaseResponse):
"An error occurred (InvalidParameterValue) when calling "
"the ReceiveMessage operation: Value %s for parameter "
"WaitTimeSeconds is invalid. Reason: must be &lt;= 0 and "
"&gt;= 20 if provided." % wait_time)
"&gt;= 20 if provided." % wait_time,
)
try:
visibility_timeout = self._get_validated_visibility_timeout()
@ -350,7 +375,8 @@ class SQSResponse(BaseResponse):
return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400)
messages = self.sqs_backend.receive_messages(
queue_name, message_count, wait_time, visibility_timeout)
queue_name, message_count, wait_time, visibility_timeout
)
template = self.response_template(RECEIVE_MESSAGE_RESPONSE)
return template.render(messages=messages)
@ -365,9 +391,9 @@ class SQSResponse(BaseResponse):
def add_permission(self):
queue_name = self._get_queue_name()
actions = self._get_multi_param('ActionName')
account_ids = self._get_multi_param('AWSAccountId')
label = self._get_param('Label')
actions = self._get_multi_param("ActionName")
account_ids = self._get_multi_param("AWSAccountId")
label = self._get_param("Label")
self.sqs_backend.add_permission(queue_name, actions, account_ids, label)
@ -376,7 +402,7 @@ class SQSResponse(BaseResponse):
def remove_permission(self):
queue_name = self._get_queue_name()
label = self._get_param('Label')
label = self._get_param("Label")
self.sqs_backend.remove_permission(queue_name, label)
@ -385,7 +411,7 @@ class SQSResponse(BaseResponse):
def tag_queue(self):
queue_name = self._get_queue_name()
tags = self._get_map_prefix('Tag', key_end='.Key', value_end='.Value')
tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value")
self.sqs_backend.tag_queue(queue_name, tags)
@ -394,7 +420,7 @@ class SQSResponse(BaseResponse):
def untag_queue(self):
queue_name = self._get_queue_name()
tag_keys = self._get_multi_param('TagKey')
tag_keys = self._get_multi_param("TagKey")
self.sqs_backend.untag_queue(queue_name, tag_keys)
@ -672,7 +698,8 @@ ERROR_TOO_LONG_RESPONSE = """<ErrorResponse xmlns="http://queue.amazonaws.com/do
</ErrorResponse>"""
ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE = "Invalid request, maximum visibility timeout is {0}".format(
MAXIMUM_VISIBILTY_TIMEOUT)
MAXIMUM_VISIBILTY_TIMEOUT
)
ERROR_INEXISTENT_QUEUE = """<ErrorResponse xmlns="http://queue.amazonaws.com/doc/2012-11-05/">
<Error>

View file

@ -1,13 +1,11 @@
from __future__ import unicode_literals
from .responses import SQSResponse
url_bases = [
"https?://(.*?)(queue|sqs)(.*?).amazonaws.com"
]
url_bases = ["https?://(.*?)(queue|sqs)(.*?).amazonaws.com"]
dispatch = SQSResponse().dispatch
url_paths = {
'{0}/$': dispatch,
'{0}/(?P<account_id>\d+)/(?P<queue_name>[a-zA-Z0-9\-_\.]+)': dispatch,
"{0}/$": dispatch,
"{0}/(?P<account_id>\d+)/(?P<queue_name>[a-zA-Z0-9\-_\.]+)": dispatch,
}

View file

@ -8,46 +8,62 @@ from .exceptions import MessageAttributesInvalid
def generate_receipt_handle():
# http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles
length = 185
return ''.join(random.choice(string.ascii_lowercase) for x in range(length))
return "".join(random.choice(string.ascii_lowercase) for x in range(length))
def parse_message_attributes(querystring, base='', value_namespace='Value.'):
def parse_message_attributes(querystring, base="", value_namespace="Value."):
message_attributes = {}
index = 1
while True:
# Loop through looking for message attributes
name_key = base + 'MessageAttribute.{0}.Name'.format(index)
name_key = base + "MessageAttribute.{0}.Name".format(index)
name = querystring.get(name_key)
if not name:
# Found all attributes
break
data_type_key = base + \
'MessageAttribute.{0}.{1}DataType'.format(index, value_namespace)
data_type_key = base + "MessageAttribute.{0}.{1}DataType".format(
index, value_namespace
)
data_type = querystring.get(data_type_key)
if not data_type:
raise MessageAttributesInvalid(
"The message attribute '{0}' must contain non-empty message attribute value.".format(name[0]))
"The message attribute '{0}' must contain non-empty message attribute value.".format(
name[0]
)
)
data_type_parts = data_type[0].split('.')
if len(data_type_parts) > 2 or data_type_parts[0] not in ['String', 'Binary', 'Number']:
data_type_parts = data_type[0].split(".")
if len(data_type_parts) > 2 or data_type_parts[0] not in [
"String",
"Binary",
"Number",
]:
raise MessageAttributesInvalid(
"The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format(name[0]))
"The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format(
name[0]
)
)
type_prefix = 'String'
if data_type_parts[0] == 'Binary':
type_prefix = 'Binary'
type_prefix = "String"
if data_type_parts[0] == "Binary":
type_prefix = "Binary"
value_key = base + \
'MessageAttribute.{0}.{1}{2}Value'.format(
index, value_namespace, type_prefix)
value_key = base + "MessageAttribute.{0}.{1}{2}Value".format(
index, value_namespace, type_prefix
)
value = querystring.get(value_key)
if not value:
raise MessageAttributesInvalid(
"The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format(name[0], data_type[0]))
"The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format(
name[0], data_type[0]
)
)
message_attributes[name[0]] = {'data_type': data_type[
0], type_prefix.lower() + '_value': value[0]}
message_attributes[name[0]] = {
"data_type": data_type[0],
type_prefix.lower() + "_value": value[0],
}
index += 1