From 85d94ad6edb91303a3526b4627978f510aea73d2 Mon Sep 17 00:00:00 2001 From: usmangani1 Date: Sat, 1 May 2021 12:18:39 +0530 Subject: [PATCH] Fix:SQS Added support for system attributes in sqs (#3878) * Adding SQS system attributes * Fix Comments * Change template in response --- moto/sqs/models.py | 6 ++++-- moto/sqs/responses.py | 10 ++++++++++ moto/sqs/utils.py | 14 +++++++------- tests/test_sqs/test_sqs.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 0ad838a5..85ba4452 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -70,7 +70,7 @@ DEDUPLICATION_TIME_IN_SECONDS = 300 class Message(BaseModel): - def __init__(self, message_id, body): + def __init__(self, message_id, body, system_attributes={}): self.id = message_id self._body = body self.message_attributes = {} @@ -84,6 +84,7 @@ class Message(BaseModel): self.sequence_number = None self.visible_at = 0 self.delayed_until = 0 + self.system_attributes = system_attributes @property def body_md5(self): @@ -673,6 +674,7 @@ class SQSBackend(BaseBackend): delay_seconds=None, deduplication_id=None, group_id=None, + system_attributes=None, ): queue = self.get_queue(queue_name) @@ -689,7 +691,7 @@ class SQSBackend(BaseBackend): delay_seconds = queue.delay_seconds message_id = get_random_message_id() - message = Message(message_id, message_body) + message = Message(message_id, message_body, system_attributes) # if content based deduplication is set then set sha256 hash of the message # as the deduplication_id diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index 623c3174..7e879ed8 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -228,6 +228,9 @@ class SQSResponse(BaseResponse): return ERROR_TOO_LONG_RESPONSE, dict(status=400) message_attributes = parse_message_attributes(self.querystring) + system_message_attributes = parse_message_attributes( + self.querystring, key="MessageSystemAttribute" + ) queue_name = self._get_queue_name() @@ -246,6 +249,7 @@ class SQSResponse(BaseResponse): delay_seconds=delay_seconds, deduplication_id=message_dedupe_id, group_id=message_group_id, + system_attributes=system_message_attributes, ) template = self.response_template(SEND_MESSAGE_RESPONSE) return template.render(message=message, message_attributes=message_attributes) @@ -596,6 +600,12 @@ RECEIVE_MESSAGE_RESPONSE = """ {{ message.group_id }} {% endif %} + {% if message.system_attributes and message.system_attributes.get('AWSTraceHeader') is not none %} + + AWSTraceHeader + {{ message.system_attributes.get('AWSTraceHeader',{}).get('string_value') }} + + {% endif %} {% if attributes.sequence_number and message.sequence_number is not none %} SequenceNumber diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 876d6b40..b490813c 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -26,20 +26,20 @@ def extract_input_message_attributes(querystring): return message_attributes -def parse_message_attributes(querystring, base="", value_namespace="Value."): +def parse_message_attributes( + querystring, key="MessageAttribute", base="", value_namespace="Value." +): message_attributes = {} index = 1 while True: # Loop through looking for message attributes - name_key = base + "MessageAttribute.{0}.Name".format(index) + name_key = base + "{0}.{1}.Name".format(key, index) name = querystring.get(name_key) if not name: # Found all attributes break - data_type_key = base + "MessageAttribute.{0}.{1}DataType".format( - index, value_namespace - ) + data_type_key = base + "{0}.{1}.{2}DataType".format(key, index, value_namespace) data_type = querystring.get(data_type_key) if not data_type: raise MessageAttributesInvalid( @@ -64,8 +64,8 @@ def parse_message_attributes(querystring, base="", value_namespace="Value."): if data_type_parts[0] == "Binary": type_prefix = "Binary" - value_key = base + "MessageAttribute.{0}.{1}{2}Value".format( - index, value_namespace, type_prefix + value_key = base + "{0}.{1}.{2}{3}Value".format( + key, index, value_namespace, type_prefix ) value = querystring.get(value_key) if not value: diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 4a4449f9..24720d94 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -2745,3 +2745,31 @@ def test_fifo_send_message_when_same_group_id_is_in_dlq(): msg_queue.send_message(MessageBody="second", MessageGroupId="1") messages = msg_queue.receive_messages() messages.should.have.length_of(1) + + +@mock_sqs +def test_message_attributes_in_receive_message(): + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue") + queue = sqs.Queue("test-queue") + body_one = "this is a test message" + + queue.send_message( + MessageBody=body_one, + MessageSystemAttributes={ + "AWSTraceHeader": { + "StringValue": "Root=1-3152b799-8954dae64eda91bc9a23a7e8;Parent=7fa8c0f79203be72;Sampled=1", + "DataType": "String", + } + }, + ) + + messages = conn.receive_message( + QueueUrl=queue.url, MaxNumberOfMessages=2, MessageAttributeNames=["All"] + )["Messages"] + + assert ( + messages[0]["Attributes"]["AWSTraceHeader"] + == "Root=1-3152b799-8954dae64eda91bc9a23a7e8;Parent=7fa8c0f79203be72;Sampled=1" + )