From f9e0595e124896ff2344d5c1c6228eb2b1fbf711 Mon Sep 17 00:00:00 2001 From: Kai Date: Thu, 13 May 2021 23:06:54 +0900 Subject: [PATCH] Fix sqs message retention logic (#3924) * Fix sqs message retention logic * Apply lint to moto/sqs/models.py * Fix failed tests because of freezing time * Fix freezing time in test_publish_to_sqs_in_different_region --- moto/sqs/models.py | 13 ++++++++----- tests/test_sns/test_publishing.py | 6 ++++-- tests/test_sns/test_publishing_boto3.py | 12 ++++++++---- tests/test_sqs/test_sqs.py | 23 +++++++++++++++++++++++ 4 files changed, 43 insertions(+), 11 deletions(-) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 85ba4452..1d0e8610 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -836,7 +836,9 @@ class SQSBackend(BaseBackend): queue.pending_messages.add(message) message.mark_received(visibility_timeout=visibility_timeout) _filter_message_attributes(message, message_attribute_names) - if not self.is_message_valid_based_on_retention_period(queue_name): + if not self.is_message_valid_based_on_retention_period( + queue_name, message + ): break result.append(message) if len(result) >= count: @@ -1015,11 +1017,12 @@ class SQSBackend(BaseBackend): def list_queue_tags(self, queue_name): return self.get_queue(queue_name) - def is_message_valid_based_on_retention_period(self, queue_name): + def is_message_valid_based_on_retention_period(self, queue_name, message): message_attributes = self.get_queue_attributes(queue_name, []) - retain_until = message_attributes.get( - "MessageRetentionPeriod" - ) + message_attributes.get("CreatedTimestamp") + retain_until = ( + message_attributes.get("MessageRetentionPeriod") + + message.sent_timestamp / 1000 + ) if retain_until <= unix_time(): return False return True diff --git a/tests/test_sns/test_publishing.py b/tests/test_sns/test_publishing.py index cc7dbb8d..0a17d90f 100644 --- a/tests/test_sns/test_publishing.py +++ b/tests/test_sns/test_publishing.py @@ -46,7 +46,8 @@ def test_publish_to_sqs(): ] queue = sqs_conn.get_queue("test-queue") - message = queue.read(1) + with freeze_time("2015-01-01 12:00:01"): + message = queue.read(1) expected = MESSAGE_FROM_SQS_TEMPLATE % ( message_to_publish, published_message_id, @@ -89,7 +90,8 @@ def test_publish_to_sqs_in_different_region(): ] queue = sqs_conn.get_queue("test-queue") - message = queue.read(1) + with freeze_time("2015-01-01 12:00:01"): + message = queue.read(1) expected = MESSAGE_FROM_SQS_TEMPLATE % ( message_to_publish, published_message_id, diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 6ee5ef1c..4a1a647c 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -46,7 +46,8 @@ def test_publish_to_sqs(): published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages(MaxNumberOfMessages=1) expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-east-1") acquired_message = re.sub( r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", @@ -77,7 +78,8 @@ def test_publish_to_sqs_raw(): with freeze_time("2015-01-01 12:00:00"): topic.publish(Message=message) - messages = queue.receive_messages(MaxNumberOfMessages=1) + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages(MaxNumberOfMessages=1) messages[0].body.should.equal(message) @@ -279,7 +281,8 @@ def test_publish_to_sqs_dump_json(): published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages(MaxNumberOfMessages=1) escaped = message.replace('"', '\\"') expected = MESSAGE_FROM_SQS_TEMPLATE % (escaped, published_message_id, "us-east-1") @@ -314,7 +317,8 @@ def test_publish_to_sqs_in_different_region(): published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) + with freeze_time("2015-01-01 12:00:01"): + messages = queue.receive_messages(MaxNumberOfMessages=1) expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-west-1") acquired_message = re.sub( r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index 24720d94..2288cec7 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -313,6 +313,29 @@ def test_message_retention_period(): assert len(messages) == 0 +@mock_sqs +def test_queue_retention_period(): + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="blah", Attributes={"MessageRetentionPeriod": "3"} + ) + + time.sleep(5) + + queue.send_message( + MessageBody="derp", + MessageAttributes={ + "SOME_Valid.attribute-Name": { + "StringValue": "1493147359900", + "DataType": "Number", + } + }, + ) + + messages = queue.receive_messages() + assert len(messages) == 1 + + @mock_sqs def test_message_with_invalid_attributes(): sqs = boto3.resource("sqs", region_name="us-east-1")