From 75812eb8382c1b7adcdf77929dde572042eb776c Mon Sep 17 00:00:00 2001 From: Adrian Galera Date: Fri, 11 Jan 2019 10:44:30 +0100 Subject: [PATCH 01/24] Enable SES feedback via SNS --- moto/ses/feedback.py | 81 +++++++++++++++++++ moto/ses/models.py | 86 +++++++++++++++++++- moto/ses/responses.py | 23 +++++- tests/test_ses/test_ses_sns_boto3.py | 114 +++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 5 deletions(-) create mode 100644 moto/ses/feedback.py create mode 100644 tests/test_ses/test_ses_sns_boto3.py diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py new file mode 100644 index 00000000..2d32f9ce --- /dev/null +++ b/moto/ses/feedback.py @@ -0,0 +1,81 @@ +""" +SES Feedback messages +Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html +""" +COMMON_MAIL = { + "notificationType": "Bounce, Complaint, or Delivery.", + "mail": { + "timestamp": "2018-10-08T14:05:45 +0000", + "messageId": "000001378603177f-7a5433e7-8edb-42ae-af10-f0181f34d6ee-000000", + "source": "sender@example.com", + "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com", + "sourceIp": "127.0.3.0", + "sendingAccountId": "123456789012", + "destination": [ + "recipient@example.com" + ], + "headersTruncated": False, + "headers": [ + { + "name": "From", + "value": "\"Sender Name\" " + }, + { + "name": "To", + "value": "\"Recipient Name\" " + } + ], + "commonHeaders": { + "from": [ + "Sender Name " + ], + "date": "Mon, 08 Oct 2018 14:05:45 +0000", + "to": [ + "Recipient Name " + ], + "messageId": " custom-message-ID", + "subject": "Message sent using Amazon SES" + } + } +} +BOUNCE = { + "bounceType": "Permanent", + "bounceSubType": "General", + "bouncedRecipients": [ + { + "status": "5.0.0", + "action": "failed", + "diagnosticCode": "smtp; 550 user unknown", + "emailAddress": "recipient1@example.com" + }, + { + "status": "4.0.0", + "action": "delayed", + "emailAddress": "recipient2@example.com" + } + ], + "reportingMTA": "example.com", + "timestamp": "2012-05-25T14:59:38.605Z", + "feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000", + "remoteMtaIp": "127.0.2.0" +} +COMPLAINT = { + "userAgent": "AnyCompany Feedback Loop (V0.01)", + "complainedRecipients": [ + { + "emailAddress": "recipient1@example.com" + } + ], + "complaintFeedbackType": "abuse", + "arrivalDate": "2009-12-03T04:24:21.000-05:00", + "timestamp": "2012-05-25T14:59:38.623Z", + "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000" +} +DELIVERY = { + "timestamp": "2014-05-28T22:41:01.184Z", + "processingTimeMillis": 546, + "recipients": ["success@simulator.amazonses.com"], + "smtpResponse": "250 ok: Message 64111812 accepted", + "reportingMTA": "a8-70.smtp-out.amazonses.com", + "remoteMtaIp": "127.0.2.0" +} diff --git a/moto/ses/models.py b/moto/ses/models.py index 71fe9d9a..77cd5719 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -4,12 +4,39 @@ import email from email.utils import parseaddr from moto.core import BaseBackend, BaseModel +from moto.sns.models import sns_backends from .exceptions import MessageRejectedError from .utils import get_random_message_id - +from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY RECIPIENT_LIMIT = 50 +class SESFeedback(BaseModel): + + BOUNCE = "Bounce" + COMPLAINT = "Complaint" + DELIVERY = "Delivery" + + SUCCESS_ADDR = "success" + BOUNCE_ADDR = "bounce" + COMPLAINT_ADDR = "complaint" + + FEEDBACK_SUCCESS_MSG = {"test": "success"} + FEEDBACK_BOUNCE_MSG = {"test": "bounce"} + FEEDBACK_COMPLAINT_MSG = {"test": "complaint"} + + @staticmethod + def generate_message(msg_type): + msg = dict(COMMON_MAIL) + if msg_type == SESFeedback.BOUNCE: + msg["bounce"] = BOUNCE + elif msg_type == SESFeedback.COMPLAINT: + msg["complaint"] = COMPLAINT + elif msg_type == SESFeedback.DELIVERY: + msg["delivery"] = DELIVERY + + return msg + class Message(BaseModel): @@ -48,6 +75,7 @@ class SESBackend(BaseBackend): self.domains = [] self.sent_messages = [] self.sent_message_count = 0 + self.sns_topics = {} def _is_verified_address(self, source): _, address = parseaddr(source) @@ -77,7 +105,7 @@ class SESBackend(BaseBackend): else: self.domains.remove(identity) - def send_email(self, source, subject, body, destinations): + def send_email(self, source, subject, body, destinations, region): recipient_count = sum(map(len, destinations.values())) if recipient_count > RECIPIENT_LIMIT: raise MessageRejectedError('Too many recipients.') @@ -86,13 +114,52 @@ class SESBackend(BaseBackend): "Email address not verified %s" % source ) + self.__process_sns_feedback__(source, destinations, region) + message_id = get_random_message_id() message = Message(message_id, source, subject, body, destinations) self.sent_messages.append(message) self.sent_message_count += recipient_count return message - def send_raw_email(self, source, destinations, raw_data): + def __type_of_message__(self, destinations): + """Checks the destination for any special address that could indicate delivery, complaint or bounce + like in SES simualtor""" + alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", []) + for addr in alladdress: + if SESFeedback.SUCCESS_ADDR in addr: + return SESFeedback.DELIVERY + elif SESFeedback.COMPLAINT_ADDR in addr: + return SESFeedback.COMPLAINT + elif SESFeedback.BOUNCE_ADDR in addr: + return SESFeedback.BOUNCE + + return None + + def __generate_feedback__(self, msg_type): + """Generates the SNS message for the feedback""" + return SESFeedback.generate_message(msg_type) + + def __process_sns_feedback__(self, source, destinations, region): + domain = str(source) + if "@" in domain: + domain = domain.split("@")[1] + print(domain, self.sns_topics) + if domain in self.sns_topics: + print("SNS Feedback configured for %s => %s" % (domain, self.sns_topics[domain])) + msg_type = self.__type_of_message__(destinations) + print("Message type for destinations %s => %s" % (destinations, msg_type)) + if msg_type is not None: + sns_topic = self.sns_topics[domain].get(msg_type, None) + if sns_topic is not None: + message = self.__generate_feedback__(msg_type) + if message: + print("Message generated for %s => %s" % (message, msg_type)) + sns_backends[region].publish(sns_topic, message) + else: + print("SNS Feedback not configured") + + def send_raw_email(self, source, destinations, raw_data, region): if source is not None: _, source_email_address = parseaddr(source) if source_email_address not in self.addresses: @@ -122,6 +189,8 @@ class SESBackend(BaseBackend): if recipient_count > RECIPIENT_LIMIT: raise MessageRejectedError('Too many recipients.') + self.__process_sns_feedback__(source, destinations, region) + self.sent_message_count += recipient_count message_id = get_random_message_id() message = RawMessage(message_id, source, destinations, raw_data) @@ -131,5 +200,16 @@ class SESBackend(BaseBackend): def get_send_quota(self): return SESQuota(self.sent_message_count) + def set_identity_notification_topic(self, identity, notification_type, sns_topic): + identity_sns_topics = self.sns_topics.get(identity, {}) + if sns_topic is None: + del identity_sns_topics[notification_type] + else: + identity_sns_topics[notification_type] = sns_topic + + self.sns_topics[identity] = identity_sns_topics + + return {} + ses_backend = SESBackend() diff --git a/moto/ses/responses.py b/moto/ses/responses.py index bdf87383..d2dda55f 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -70,7 +70,7 @@ class EmailResponse(BaseResponse): break destinations[dest_type].append(address[0]) - message = ses_backend.send_email(source, subject, body, destinations) + message = ses_backend.send_email(source, subject, body, destinations, self.region) template = self.response_template(SEND_EMAIL_RESPONSE) return template.render(message=message) @@ -92,7 +92,7 @@ class EmailResponse(BaseResponse): break destinations.append(address[0]) - message = ses_backend.send_raw_email(source, destinations, raw_data) + message = ses_backend.send_raw_email(source, destinations, raw_data, self.region) template = self.response_template(SEND_RAW_EMAIL_RESPONSE) return template.render(message=message) @@ -101,6 +101,18 @@ class EmailResponse(BaseResponse): template = self.response_template(GET_SEND_QUOTA_RESPONSE) return template.render(quota=quota) + def set_identity_notification_topic(self): + + identity = self.querystring.get("Identity")[0] + not_type = self.querystring.get("NotificationType")[0] + sns_topic = self.querystring.get("SnsTopic") + if sns_topic: + sns_topic = sns_topic[0] + + ses_backend.set_identity_notification_topic(identity, not_type, sns_topic) + template = self.response_template(SET_IDENTITY_NOTIFICATION_TOPIC_RESPONSE) + return template.render() + VERIFY_EMAIL_IDENTITY = """ @@ -200,3 +212,10 @@ GET_SEND_QUOTA_RESPONSE = """ + + + 47e0ef1a-9bf2-11e1-9279-0100e8cf109a + +""" diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py new file mode 100644 index 00000000..37f79a8b --- /dev/null +++ b/tests/test_ses/test_ses_sns_boto3.py @@ -0,0 +1,114 @@ +from __future__ import unicode_literals + +import boto3 +import json +from botocore.exceptions import ClientError +from six.moves.email_mime_multipart import MIMEMultipart +from six.moves.email_mime_text import MIMEText + +import sure # noqa +from nose import tools +from moto import mock_ses, mock_sns, mock_sqs +from moto.ses.models import SESFeedback + + +@mock_ses +def test_enable_disable_ses_sns_communication(): + conn = boto3.client('ses', region_name='us-east-1') + conn.set_identity_notification_topic( + Identity='test.com', + NotificationType='Bounce', + SnsTopic='the-arn' + ) + conn.set_identity_notification_topic( + Identity='test.com', + NotificationType='Bounce' + ) + + +def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg): + """Setup the AWS environment to test the SES SNS Feedback""" + # Environment setup + # Create SQS queue + sqs_conn.create_queue(QueueName=queue) + # Create SNS topic + create_topic_response = sns_conn.create_topic(Name=topic) + topic_arn = create_topic_response["TopicArn"] + # Subscribe the SNS topic to the SQS queue + sns_conn.subscribe(TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue)) + # Verify SES domain + ses_conn.verify_domain_identity(Domain=domain) + # Setup SES notification topic + if expected_msg is not None: + ses_conn.set_identity_notification_topic( + Identity=domain, + NotificationType=expected_msg, + SnsTopic=topic_arn + ) + + +def __test_sns_feedback__(addr, expected_msg): + region_name = "us-east-1" + ses_conn = boto3.client('ses', region_name=region_name) + sns_conn = boto3.client('sns', region_name=region_name) + sqs_conn = boto3.resource('sqs', region_name=region_name) + domain = "example.com" + topic = "bounce-arn-feedback" + queue = "feedback-test-queue" + + __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg) + + # Send the message + kwargs = dict( + Source="test@" + domain, + Destination={ + "ToAddresses": [addr + "@" + domain], + "CcAddresses": ["test_cc@" + domain], + "BccAddresses": ["test_bcc@" + domain], + }, + Message={ + "Subject": {"Data": "test subject"}, + "Body": {"Text": {"Data": "test body"}} + } + ) + ses_conn.send_email(**kwargs) + + # Wait for messages in the queues + queue = sqs_conn.get_queue_by_name(QueueName=queue) + messages = queue.receive_messages(MaxNumberOfMessages=1) + if expected_msg is not None: + msg = messages[0].body + msg = json.loads(msg) + assert msg["Message"] == SESFeedback.generate_message(expected_msg) + else: + assert len(messages) == 0 + + +@mock_sqs +@mock_sns +@mock_ses +def test_no_sns_feedback(): + __test_sns_feedback__("test", None) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_bounce(): + __test_sns_feedback__(SESFeedback.BOUNCE_ADDR, SESFeedback.BOUNCE) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_complaint(): + __test_sns_feedback__(SESFeedback.COMPLAINT_ADDR, SESFeedback.COMPLAINT) + + +@mock_sqs +@mock_sns +@mock_ses +def test_sns_feedback_delivery(): + __test_sns_feedback__(SESFeedback.SUCCESS_ADDR, SESFeedback.DELIVERY) From 53f8feca55d93c1a259a5f9f796fa991811f5c87 Mon Sep 17 00:00:00 2001 From: Adrian Galera Date: Fri, 11 Jan 2019 13:35:18 +0100 Subject: [PATCH 02/24] apply linter changes --- moto/ses/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/moto/ses/models.py b/moto/ses/models.py index 77cd5719..0b69b8cd 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -11,6 +11,7 @@ from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY RECIPIENT_LIMIT = 50 + class SESFeedback(BaseModel): BOUNCE = "Bounce" From 016dec6435b1efbdd54260cea3f95f7fca6bd46e Mon Sep 17 00:00:00 2001 From: Adrian Galera Date: Fri, 11 Jan 2019 13:45:34 +0100 Subject: [PATCH 03/24] Cleanup prints --- moto/ses/models.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/moto/ses/models.py b/moto/ses/models.py index 0b69b8cd..0544ac27 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -145,20 +145,14 @@ class SESBackend(BaseBackend): domain = str(source) if "@" in domain: domain = domain.split("@")[1] - print(domain, self.sns_topics) if domain in self.sns_topics: - print("SNS Feedback configured for %s => %s" % (domain, self.sns_topics[domain])) msg_type = self.__type_of_message__(destinations) - print("Message type for destinations %s => %s" % (destinations, msg_type)) if msg_type is not None: sns_topic = self.sns_topics[domain].get(msg_type, None) if sns_topic is not None: message = self.__generate_feedback__(msg_type) if message: - print("Message generated for %s => %s" % (message, msg_type)) sns_backends[region].publish(sns_topic, message) - else: - print("SNS Feedback not configured") def send_raw_email(self, source, destinations, raw_data, region): if source is not None: From a86ec26e46edc16c59e20b8540ed0838f7b8c894 Mon Sep 17 00:00:00 2001 From: William Richard Date: Tue, 22 Jan 2019 16:20:15 -0500 Subject: [PATCH 04/24] Add support for redirect actions on ELBv2 listeners --- moto/elbv2/exceptions.py | 2 +- moto/elbv2/models.py | 58 ++++++++++---- moto/elbv2/responses.py | 42 +++++++++- tests/test_elbv2/test_elbv2.py | 140 +++++++++++++++++++++++++++++++++ 4 files changed, 223 insertions(+), 19 deletions(-) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index 0bf9649d..11dcbcb2 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -131,7 +131,7 @@ class InvalidActionTypeError(ELBClientError): def __init__(self, invalid_name, index): super(InvalidActionTypeError, self).__init__( "ValidationError", - "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward]" % (invalid_name, index) + "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index) ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 3925fa95..8d98f187 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -204,8 +204,20 @@ class FakeListener(BaseModel): # transform default actions to confirm with the rest of the code and XML templates if "DefaultActions" in properties: default_actions = [] - for action in properties['DefaultActions']: - default_actions.append({'type': action['Type'], 'target_group_arn': action['TargetGroupArn']}) + for i, action in enumerate(properties['DefaultActions']): + action_type = action['Type'] + if action_type == 'forward': + default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) + elif action_type == 'redirect': + redirect_action = {'type': action_type, } + for redirect_config_key, redirect_config_value in action['RedirectConfig'].items(): + # need to match the output of _get_list_prefix + if redirect_config_key == 'StatusCode': + redirect_config_key = 'status_code' + redirect_action['redirect_config._' + redirect_config_key.lower()] = redirect_config_value + default_actions.append(redirect_action) + else: + raise InvalidActionTypeError(action_type, i + 1) else: default_actions = None @@ -417,11 +429,15 @@ class ELBv2Backend(BaseBackend): for i, action in enumerate(actions): index = i + 1 action_type = action['type'] - if action_type not in ['forward']: + if action_type == 'forward': + action_target_group_arn = action['target_group_arn'] + if action_target_group_arn not in target_group_arns: + raise ActionTargetGroupNotFoundError(action_target_group_arn) + elif action_type == 'redirect': + # nothing to do + pass + else: raise InvalidActionTypeError(action_type, index) - action_target_group_arn = action['target_group_arn'] - if action_target_group_arn not in target_group_arns: - raise ActionTargetGroupNotFoundError(action_target_group_arn) # TODO: check for error 'TooManyRegistrationsForTargetId' # TODO: check for error 'TooManyRules' @@ -483,10 +499,18 @@ class ELBv2Backend(BaseBackend): arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self)) listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions) balancer.listeners[listener.arn] = listener - for action in default_actions: - if action['target_group_arn'] in self.target_groups.keys(): - target_group = self.target_groups[action['target_group_arn']] - target_group.load_balancer_arns.append(load_balancer_arn) + for i, action in enumerate(default_actions): + action_type = action['type'] + if action_type == 'forward': + if action['target_group_arn'] in self.target_groups.keys(): + target_group = self.target_groups[action['target_group_arn']] + target_group.load_balancer_arns.append(load_balancer_arn) + elif action_type == 'redirect': + # nothing to do + pass + else: + raise InvalidActionTypeError(action_type, i + 1) + return listener def describe_load_balancers(self, arns, names): @@ -649,11 +673,15 @@ class ELBv2Backend(BaseBackend): for i, action in enumerate(actions): index = i + 1 action_type = action['type'] - if action_type not in ['forward']: + if action_type == 'forward': + action_target_group_arn = action['target_group_arn'] + if action_target_group_arn not in target_group_arns: + raise ActionTargetGroupNotFoundError(action_target_group_arn) + elif action_type == 'redirect': + # nothing to do + pass + else: raise InvalidActionTypeError(action_type, index) - action_target_group_arn = action['target_group_arn'] - if action_target_group_arn not in target_group_arns: - raise ActionTargetGroupNotFoundError(action_target_group_arn) # TODO: check for error 'TooManyRegistrationsForTargetId' # TODO: check for error 'TooManyRules' @@ -873,7 +901,7 @@ class ELBv2Backend(BaseBackend): # Its already validated in responses.py listener.ssl_policy = ssl_policy - if default_actions is not None: + if default_actions is not None and default_actions != []: # Is currently not validated listener.default_actions = default_actions diff --git a/moto/elbv2/responses.py b/moto/elbv2/responses.py index 1814f127..3ca53240 100644 --- a/moto/elbv2/responses.py +++ b/moto/elbv2/responses.py @@ -704,7 +704,11 @@ CREATE_RULE_TEMPLATE = """ + {% if action["type"] == "forward" %} {{ action["target_group_arn"] }} + {% elif action["type"] == "redirect" %} + {{ action["redirect_config"] }} + {% endif %} {% endfor %} @@ -772,7 +776,15 @@ CREATE_LISTENER_TEMPLATE = """{{ action["target_group_arn"] }} + {% elif action["type"] == "redirect" %} + + {{ action["redirect_config._protocol"] }} + {{ action["redirect_config._port"] }} + {{ action["redirect_config._status_code"] }} + + {% endif %} {% endfor %} @@ -877,7 +889,15 @@ DESCRIBE_RULES_TEMPLATE = """ + {% if action["type"] == "forward" %} {{ action["target_group_arn"] }} + {% elif action["type"] == "redirect" %} + + {{ action["redirect_config._protocol"] }} + {{ action["redirect_config._port"] }} + {{ action["redirect_config._status_code"] }} + + {% endif %} {% endfor %} @@ -970,7 +990,15 @@ DESCRIBE_LISTENERS_TEMPLATE = """{{ action["target_group_arn"] }}m + {% elif action["type"] == "redirect" %} + + {{ action["redirect_config._protocol"] }} + {{ action["redirect_config._port"] }} + {{ action["redirect_config._status_code"] }} + + {% endif %} {% endfor %} @@ -1399,7 +1427,15 @@ MODIFY_LISTENER_TEMPLATE = """{{ action["target_group_arn"] }} + {% elif action["type"] == "redirect" %} + + {{ action["redirect_config._protocol"] }} + {{ action["redirect_config._port"] }} + {{ action["redirect_config._status_code"] }} + + {% endif %} {% endfor %} diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index b58345fd..2010e384 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -1586,3 +1586,143 @@ def test_create_target_groups_through_cloudformation(): assert len( [tg for tg in target_group_dicts if tg['TargetGroupName'].startswith('test-stack')] ) == 2 + + +@mock_elbv2 +@mock_ec2 +def test_redirect_action_listener_rule(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + response = conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[ + {'Type': 'redirect', + 'RedirectConfig': { + 'Protocol': 'HTTPS', + 'Port': '443', + 'StatusCode': 'HTTP_301' + }}]) + + listener = response.get('Listeners')[0] + expected_default_actions = [{ + 'Type': 'redirect', + 'RedirectConfig': { + 'Protocol': 'HTTPS', + 'Port': '443', + 'StatusCode': 'HTTP_301' + } + }] + listener.get('DefaultActions').should.equal(expected_default_actions) + listener_arn = listener.get('ListenerArn') + + describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) + describe_rules_response['Rules'][0]['Actions'].should.equal(expected_default_actions) + + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) + describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'] + describe_listener_actions.should.equal(expected_default_actions) + + modify_listener_response = conn.modify_listener(ListenerArn=listener_arn, Port=81) + modify_listener_actions = modify_listener_response['Listeners'][0]['DefaultActions'] + modify_listener_actions.should.equal(expected_default_actions) + + +@mock_elbv2 +@mock_cloudformation +def test_redirect_action_listener_rule_cloudformation(): + cnf_conn = boto3.client('cloudformation', region_name='us-east-1') + elbv2_client = boto3.client('elbv2', region_name='us-east-1') + + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Description": "ECS Cluster Test CloudFormation", + "Resources": { + "testVPC": { + "Type": "AWS::EC2::VPC", + "Properties": { + "CidrBlock": "10.0.0.0/16", + }, + }, + "subnet1": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.0.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "subnet2": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.1.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "testLb": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "my-lb", + "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], + "Type": "application", + "SecurityGroups": [], + } + }, + "testListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": {"Ref": "testLb"}, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [{ + "Type": "redirect", + "RedirectConfig": { + "Port": "443", + "Protocol": "HTTPS", + "StatusCode": "HTTP_301", + } + }] + } + + } + } + } + template_json = json.dumps(template) + cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) + + describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) + describe_load_balancers_response['LoadBalancers'].should.have.length_of(1) + load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] + + describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + + describe_listeners_response['Listeners'].should.have.length_of(1) + describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ + 'Type': 'redirect', + 'RedirectConfig': { + 'Port': '443', 'Protocol': 'HTTPS', 'StatusCode': 'HTTP_301', + } + },]) From a8384c0416d0ae59cba5e14a1c3fac2fae6588d5 Mon Sep 17 00:00:00 2001 From: William Richard Date: Wed, 27 Feb 2019 15:15:50 -0500 Subject: [PATCH 05/24] Fix serial number field https://github.com/spulec/moto/pull/2077/files#diff-5fa8d19b019905e97d955f78d3dd1b99 --- moto/acm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/acm/models.py b/moto/acm/models.py index 39be8945..15a1bd44 100644 --- a/moto/acm/models.py +++ b/moto/acm/models.py @@ -243,7 +243,7 @@ class CertBundle(BaseModel): 'KeyAlgorithm': key_algo, 'NotAfter': datetime_to_epoch(self._cert.not_valid_after), 'NotBefore': datetime_to_epoch(self._cert.not_valid_before), - 'Serial': self._cert.serial, + 'Serial': self._cert.serial_number, 'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''), 'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. 'Subject': 'CN={0}'.format(self.common_name), From e01d91b2d62cb0c704c3bcf66e2835ac08d44680 Mon Sep 17 00:00:00 2001 From: William Richard Date: Mon, 15 Apr 2019 23:07:14 -0400 Subject: [PATCH 06/24] Set the physical resource ID property for the lambda model --- moto/awslambda/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index a37a15e2..960570c9 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -231,6 +231,10 @@ class LambdaFunction(BaseModel): config.update({"VpcId": "vpc-123abc"}) return config + @property + def physical_resource_id(self): + return self.function_name + def __repr__(self): return json.dumps(self.get_configuration()) From 9bd15b5a090d0e92bee79f7427ff54eb3b107661 Mon Sep 17 00:00:00 2001 From: Elliott Butler Date: Thu, 21 Jun 2018 21:09:04 -0500 Subject: [PATCH 07/24] Fix route53 alias response. This commit * includes the work by @elliotmb in #1694, * removes the AliasTarget.DNSName copy into a RecordSet.Value, * fixes and adds tests. --- moto/route53/models.py | 9 ++++++ moto/route53/responses.py | 5 +--- tests/test_route53/test_route53.py | 47 ++++++++++++++++++++++++++---- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/moto/route53/models.py b/moto/route53/models.py index 3760d381..071dcfac 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -85,6 +85,7 @@ class RecordSet(BaseModel): self.health_check = kwargs.get('HealthCheckId') self.hosted_zone_name = kwargs.get('HostedZoneName') self.hosted_zone_id = kwargs.get('HostedZoneId') + self.alias_target = kwargs.get('AliasTarget') @classmethod def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): @@ -143,6 +144,13 @@ class RecordSet(BaseModel): {% if record_set.ttl %} {{ record_set.ttl }} {% endif %} + {% if record_set.alias_target %} + + {{ record_set.alias_target['HostedZoneId'] }} + {{ record_set.alias_target['DNSName'] }} + {{ record_set.alias_target['EvaluateTargetHealth'] }} + + {% else %} {% for record in record_set.records %} @@ -150,6 +158,7 @@ class RecordSet(BaseModel): {% endfor %} + {% endif %} {% if record_set.health_check %} {{ record_set.health_check }} {% endif %} diff --git a/moto/route53/responses.py b/moto/route53/responses.py index 98ffa4c4..981362b1 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -134,10 +134,7 @@ class Route53(BaseResponse): # Depending on how many records there are, this may # or may not be a list resource_records = [resource_records] - record_values = [x['Value'] for x in resource_records] - elif 'AliasTarget' in record_set: - record_values = [record_set['AliasTarget']['DNSName']] - record_set['ResourceRecords'] = record_values + record_set['ResourceRecords'] = [x['Value'] for x in resource_records] if action == 'CREATE': the_zone.add_rrset(record_set) else: diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index d730f8dc..97cd82d2 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -172,14 +172,16 @@ def test_alias_rrset(): changes.commit() rrsets = conn.get_all_rrsets(zoneid, type="A") - rrset_records = [(rr_set.name, rr) for rr_set in rrsets for rr in rr_set.resource_records] - rrset_records.should.have.length_of(2) - rrset_records.should.contain(('foo.alias.testdns.aws.com.', 'foo.testdns.aws.com')) - rrset_records.should.contain(('bar.alias.testdns.aws.com.', 'bar.testdns.aws.com')) - rrsets[0].resource_records[0].should.equal('foo.testdns.aws.com') + alias_targets = [rr_set.alias_dns_name for rr_set in rrsets] + alias_targets.should.have.length_of(2) + alias_targets.should.contain('foo.testdns.aws.com') + alias_targets.should.contain('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com') + rrsets[0].resource_records.should.have.length_of(0) rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com') + rrsets[0].resource_records.should.have.length_of(0) @mock_route53_deprecated @@ -582,6 +584,39 @@ def test_change_resource_record_sets_crud_valid(): cname_record_detail['TTL'].should.equal(60) cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}]) + # Update to add Alias. + cname_alias_record_endpoint_payload = { + 'Comment': 'Update to Alias prod.redis.db', + 'Changes': [ + { + 'Action': 'UPSERT', + 'ResourceRecordSet': { + 'Name': 'prod.redis.db.', + 'Type': 'A', + 'TTL': 60, + 'AliasTarget': { + 'HostedZoneId': hosted_zone_id, + 'DNSName': 'prod.redis.alias.', + 'EvaluateTargetHealth': False, + } + } + } + ] + } + conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload) + + response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) + cname_alias_record_detail = response['ResourceRecordSets'][0] + cname_alias_record_detail['Name'].should.equal('prod.redis.db.') + cname_alias_record_detail['Type'].should.equal('A') + cname_alias_record_detail['TTL'].should.equal(60) + cname_alias_record_detail['AliasTarget'].should.equal({ + 'HostedZoneId': hosted_zone_id, + 'DNSName': 'prod.redis.alias.', + 'EvaluateTargetHealth': False, + }) + cname_alias_record_detail.should_not.contain('ResourceRecords') + # Delete record. delete_payload = { 'Comment': 'delete prod.redis.db', From 1a2fc66f84b527a5f8b65acf0f97a423d2d9925d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 15:15:20 -0400 Subject: [PATCH 08/24] Adding dynamodb2 expression parser and fixing test cases --- moto/dynamodb2/comparisons.py | 1106 +++++++++++++++++++------ moto/dynamodb2/condition.py | 617 ++++++++++++++ moto/dynamodb2/models.py | 43 +- tests/test_dynamodb2/test_dynamodb.py | 63 +- 4 files changed, 1531 insertions(+), 298 deletions(-) create mode 100644 moto/dynamodb2/condition.py diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 6d37345f..ac78d45b 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,6 +1,40 @@ from __future__ import unicode_literals import re import six +import re +import enum +from collections import deque +from collections import namedtuple + + +def get_filter_expression(expr, names, values): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + parser = ConditionExpressionParser(expr, names, values) + return parser.parse() + + +class Op(object): + """ + Base class for a FilterExpression operator + """ + OP = '' + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def expr(self, item): + raise NotImplementedError("Expr not defined for {0}".format(type(self))) + + def __repr__(self): + return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa @@ -49,292 +83,783 @@ class RecursionStopIteration(StopIteration): pass -def get_filter_expression(expr, names, values): - # Examples - # expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' - # expr = 'Id > 5 AND Subs < 7' - if names is None: - names = {} - if values is None: - values = {} +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values - # Do substitutions - for key, value in names.items(): - expr = expr.replace(key, value) + def parse(self): + """Returns a syntax tree for the expression. - # Store correct types of values for use later - values_map = {} - for key, value in values.items(): - if 'N' in value: - values_map[key] = float(value['N']) - elif 'BOOL' in value: - values_map[key] = value['BOOL'] - elif 'S' in value: - values_map[key] = value['S'] - elif 'NS' in value: - values_map[key] = tuple(value['NS']) - elif 'SS' in value: - values_map[key] = tuple(value['SS']) - elif 'L' in value: - values_map[key] = tuple(value['L']) + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises ValueError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + ('FUNCTION', [('LITERAL', 'size'), argument: Operand]) + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return OpDefault(None, None) + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + # NOTE: The docs say that functions should be parsed after + # IN, BETWEEN, and comparisons like <=. + # However, these expressions are invalid as function arguments, + # so it is okay to parse functions first. This needs to be done + # to interpret size() correctly as an operand. + nodes = self._apply_functions(nodes) + nodes = self._apply_comparator(nodes) + nodes = self._apply_in(nodes) + nodes = self._apply_between(nodes) + nodes = self._apply_parens_and_booleans(nodes) + node = nodes[0] + op = self._make_op_condition(node) + return op + + class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + + class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + + Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == self.Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + # TODO: Handle indexing like [1] + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + self.Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + self.Nonterminal.COMPARATOR, re.compile( + '^(' + # Put long expressions first for greedy matching + '<>|' + '<=|' + '>=|' + '=|' + '<|' + '>)'), + ), ( + self.Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') + ), ( + self.Nonterminal.COMMA, re.compile('^,') + ), ( + self.Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + self.Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break else: - raise NotImplementedError() + raise ValueError("Cannot parse condition starting at: " + + remaining_expression) - # Remove all spaces, tbf we could just skip them in the next step. - # The number of known options is really small so we can do a fair bit of cheating - expr = list(expr.strip()) + value = match_text + node = self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) - # DodgyTokenisation stage 1 - def is_value(val): - return val not in ('<', '>', '=', '(', ')') + remaining_expression = remaining_expression[len(match_text):] - def contains_keyword(val): - for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'): - if kw in val: - return kw - return None + return node, remaining_expression - def is_function(val): - return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size') + def _parse_paths(self, nodes): + output = deque() - # Does the main part of splitting between sections of characters - tokens = [] - stack = '' - while len(expr) > 0: - current_char = expr.pop(0) + while nodes: + node = nodes.popleft() - if current_char == ' ': - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif current_char == ',': # Split params , - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif is_value(current_char): - stack += current_char + if node.nonterminal == self.Nonterminal.OPERAND: + path = node.value.replace('[', '.[').split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != self.Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == self.Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output - kw = contains_keyword(stack) - if kw is not None: - # We have a kw in the stack, could be AND or something like 5AND - tmp = stack.replace(kw, '') - if len(tmp) > 0: - tokens.append(tmp) - tokens.append(kw) - stack = '' + def _parse_path_element(self, name): + reserved = { + 'and': self.Nonterminal.AND, + 'or': self.Nonterminal.OR, + 'in': self.Nonterminal.IN, + 'between': self.Nonterminal.BETWEEN, + 'not': self.Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name.lower() in reserved: + # e.g. AND + nonterminal = reserved[name.lower()] + return self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + # e.g. attribute_exists + return self.Node( + nonterminal=self.Nonterminal.FUNCTION_NAME, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + # e.g. :value0 + return self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + # e.g. #name0 + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + elif name.startswith('['): + # e.g. [123] + if not name.endswith(']'): + raise ValueError("Bad path element %s" % name) + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=int(name[1:-1]), + children=[]) else: - if len(stack) > 0: - tokens.append(stack) - tokens.append(current_char) - stack = '' - if len(stack) > 0: - tokens.append(stack) + # e.g. ItemId + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) - def is_op(val): - return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] - # DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier. - def handle_token(token, tokens2, token_iterator): - # ok so this essentially groups up some tokens to make later parsing easier, - # when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised. - if token == ')': - raise RecursionStopIteration() # Should be recursive so this should work - elif token == '(': - temp_list = [] + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] - try: + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(self.Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] while True: - next_token = six.next(token_iterator) - handle_token(next_token, temp_list, token_iterator) - except RecursionStopIteration: - pass # Continue - except StopIteration: - ValueError('Malformed filter expression, type1') - - # Sigh, we only want to group a tuple if it doesnt contain operators - if any([is_op(item) for item in temp_list]): - # Its an expression - tokens2.append('(') - tokens2.extend(temp_list) - tokens2.append(')') + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) else: - tokens2.append(tuple(temp_list)) - elif token == 'BETWEEN': - field = tokens2.pop() - # if values map contains a number, it would be a float - # so we need to int() it anyway - op1 = six.next(token_iterator) - op1 = int(values_map.get(op1, op1)) - and_op = six.next(token_iterator) - assert and_op == 'AND' - op2 = six.next(token_iterator) - op2 = int(values_map.get(op2, op2)) - tokens2.append(['between', field, op1, op2]) - elif is_function(token): - function_list = [token] + output.append(nodes.popleft()) + return output - lbracket = six.next(token_iterator) - assert lbracket == '(' - - next_token = six.next(token_iterator) - while next_token != ')': - if next_token in values_map: - next_token = values_map[next_token] - function_list.append(next_token) - next_token = six.next(token_iterator) - - tokens2.append(function_list) - else: - # Convert tokens back to real types - if token in values_map: - token = values_map[token] - - # Need to join >= <= <> - if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')): - tokens2.append(tokens2.pop() + token) + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) else: - tokens2.append(token) + output.append(nodes.popleft()) + return output - tokens2 = [] - token_iterator = iter(tokens) - for token in token_iterator: - handle_token(token, tokens2, token_iterator) - - # Start of the Shunting-Yard algorithm. <-- Proper beast algorithm! - def is_number(val): - return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') - - OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100} - - def shunting_yard(token_list): - output = [] - op_stack = [] - - # Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no - # ambiguity on which order operators are applied. - while len(token_list) > 0: - token = token_list.pop(0) - - if token == '(': - op_stack.append(token) - elif token == ')': - while len(op_stack) > 0 and op_stack[-1] != '(': - output.append(op_stack.pop()) - lbracket = op_stack.pop() - assert lbracket == '(' - - elif is_number(token): - output.append(token) + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} + expected_argument_kind_map = { + 'attribute_exists': [{self.Kind.PATH}], + 'attribute_not_exists': [{self.Kind.PATH}], + 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [either_kind, either_kind], + 'contains': [either_kind, either_kind], + 'size': [{self.Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + if function_name.value == 'size': + nonterminal = self.Nonterminal.OPERAND + else: + nonterminal = self.Nonterminal.CONDITION + output.append(self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) else: - # Must be operator kw + output.append(nodes.popleft()) + return output - # Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity - while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT': - output.append(op_stack.pop()) - op_stack.append(token) - while len(op_stack) > 0: - output.append(op_stack.pop()) + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == self.Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == self.Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) return output - output = shunting_yard(tokens2) - - # Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent - def to_func(val): - if isinstance(val, list): - func_name = val.pop(0) - # Expand rest of the list to arguments - val = FUNC_CLASS[func_name](*val) - - return val - - # Simple reverse polish notation execution. Builts up a nested filter object. - # The filter object then takes a dynamo item and returns true/false - stack = [] - for token in output: - if is_op(token): - op_cls = OP_CLASS[token] - - if token == 'NOT': - op1 = stack.pop() - op2 = True + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) else: - op2 = stack.pop() - op1 = stack.pop() + output.append(nodes.popleft()) - stack.append(op_cls(op1, op2)) + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_operand(self, node): + if node.kind == self.Kind.PATH: + return AttributePath([child.value for child in node.children]) + elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: + return AttributeValue(node.value) + elif node.kind == self.Kind.FUNCTION: + # size() + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) else: - stack.append(to_func(token)) - - result = stack.pop(0) - if len(stack) > 0: - raise ValueError('Malformed filter expression, type2') - - return result + raise ValueError("Unknown operand: %r" % node) -class Op(object): - """ - Base class for a FilterExpression operator - """ - OP = '' + def _make_op_condition(self, node): + if node.kind == self.Kind.OR: + lhs, rhs = node.children + return OpOr( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.AND: + lhs, rhs = node.children + return OpAnd( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.NOT: + child, = node.children + return OpNot(self._make_op_condition(child), None) + elif node.kind == self.Kind.PARENTHESES: + child, = node.children + return self._make_op_condition(child) + elif node.kind == self.Kind.FUNCTION: + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) + elif node.kind == self.Kind.BETWEEN: + query, low, high = node.children + return FuncBetween( + self._make_operand(query), + self._make_operand(low), + self._make_operand(high)) + elif node.kind == self.Kind.IN: + query, *possible_values = node.children + query = self._make_operand(query) + possible_values = [self._make_operand(v) for v in possible_values] + return FuncIn(query, *possible_values) + elif node.kind == self.Kind.COMPARISON: + lhs, comparator, rhs = node.children + return OP_CLASS[comparator.value]( + self._make_operand(lhs), + self._make_operand(rhs)) + else: + raise ValueError("Unknown expression node kind %r" % node.kind) - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise ValueError(message + " " + " ".join([t.text for t in nodes])) + + +class Operand(object): + def expr(self, item): + raise NotImplementedError + + def get_type(self, item): + raise NotImplementedError + + +class AttributePath(Operand): + def __init__(self, path): + """Initialize the AttributePath. + + Parameters + ---------- + path: list of int/str - def _lhs(self, item): """ - :type item: moto.dynamodb2.models.Item - """ - lhs = self.lhs - if isinstance(self.lhs, (Op, Func)): - lhs = self.lhs.expr(item) - elif isinstance(self.lhs, six.string_types): - try: - lhs = item.attrs[self.lhs].cast_value - except Exception: - pass + assert len(path) >= 1 + self.path = path - return lhs - - def _rhs(self, item): - rhs = self.rhs - if isinstance(self.rhs, (Op, Func)): - rhs = self.rhs.expr(item) - elif isinstance(self.rhs, six.string_types): - try: - rhs = item.attrs[self.rhs].cast_value - except Exception: - pass - return rhs + def _get_attr(self, item): + base = self.path[0] + if base not in item.attrs: + return None + attr = item.attrs[base] + for name in self.path[1:]: + attr = attr.child_attr(name) + if attr is None: + return None + return attr def expr(self, item): - return True + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.cast_value + + def get_type(self, item): + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.type def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return self.path -class Func(object): - """ - Base class for a FilterExpression function - """ - FUNC = 'Unknown' +class AttributeValue(Operand): + def __init__(self, value): + """Initialize the AttributePath. + + Parameters + ---------- + value: dict + e.g. {'N': '1.234'} + + """ + self.type = list(value.keys())[0] + if 'N' in value: + self.value = float(value['N']) + elif 'BOOL' in value: + self.value = value['BOOL'] + elif 'S' in value: + self.value = value['S'] + elif 'NS' in value: + self.value = tuple(value['NS']) + elif 'SS' in value: + self.value = tuple(value['SS']) + elif 'L' in value: + self.value = tuple(value['L']) + else: + # TODO: Handle all attribute types + raise NotImplementedError() def expr(self, item): - return True + return self.value + + def get_type(self, item): + return self.type def __repr__(self): - return 'Func(...)'.format(self.FUNC) + return repr(self.value) + + +class OpDefault(Op): + OP = 'NONE' + + def expr(self, item): + """If no condition is specified, always True.""" + return True class OpNot(Op): OP = 'NOT' def expr(self, item): - lhs = self._lhs(item) - + lhs = self.lhs.expr(item) return not lhs def __str__(self): @@ -345,8 +870,8 @@ class OpAnd(Op): OP = 'AND' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs and rhs @@ -354,8 +879,8 @@ class OpLessThan(Op): OP = '<' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs < rhs @@ -363,8 +888,8 @@ class OpGreaterThan(Op): OP = '>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs > rhs @@ -372,8 +897,8 @@ class OpEqual(Op): OP = '=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs == rhs @@ -381,8 +906,8 @@ class OpNotEqual(Op): OP = '<>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs != rhs @@ -390,8 +915,8 @@ class OpLessThanOrEqual(Op): OP = '<=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs <= rhs @@ -399,8 +924,8 @@ class OpGreaterThanOrEqual(Op): OP = '>=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs >= rhs @@ -408,8 +933,8 @@ class OpOr(Op): OP = 'OR' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs or rhs @@ -417,19 +942,38 @@ class OpIn(Op): OP = 'IN' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs in rhs +class Func(object): + """ + Base class for a FilterExpression function + """ + FUNC = 'Unknown' + + def __init__(self, *arguments): + self.arguments = arguments + + def expr(self, item): + raise NotImplementedError + + def __repr__(self): + return '{0}({1})'.format( + self.FUNC, + " ".join([repr(arg) for arg in self.arguments])) + + class FuncAttrExists(Func): FUNC = 'attribute_exists' def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr in item.attrs + return self.attr.get_type(item) is not None class FuncAttrNotExists(Func): @@ -437,9 +981,10 @@ class FuncAttrNotExists(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr not in item.attrs + return self.attr.get_type(item) is None class FuncAttrType(Func): @@ -448,9 +993,10 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type + super().__init__(attribute, _type) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == self.type + return self.attr.get_type(item) == self.type.expr(item) class FuncBeginsWith(Func): @@ -459,9 +1005,14 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr + super().__init__(attribute, substr) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr) + if self.attr.get_type(item) != 'S': + return False + if self.substr.get_type(item) != 'S': + return False + return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): @@ -470,13 +1021,11 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand + super().__init__(attribute, operand) def expr(self, item): - if self.attr not in item.attrs: - return False - - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand in item.attrs[self.attr].value + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): + return self.operand.expr(item) in self.attr.expr(item) return False @@ -485,29 +1034,44 @@ class FuncSize(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - if self.attr not in item.attrs: + if self.attr.get_type(item) is None: raise ValueError('Invalid attribute name {0}'.format(self.attr)) - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): - return len(item.attrs[self.attr].value) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + return len(self.attr.expr(item)) raise ValueError('Invalid filter expression') class FuncBetween(Func): - FUNC = 'between' + FUNC = 'BETWEEN' def __init__(self, attribute, start, end): self.attr = attribute self.start = start self.end = end + super().__init__(attribute, start, end) def expr(self, item): - if self.attr not in item.attrs: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) + return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) - return self.start <= item.attrs[self.attr].cast_value <= self.end + +class FuncIn(Func): + FUNC = 'IN' + + def __init__(self, attribute, *possible_values): + self.attr = attribute + self.possible_values = possible_values + super().__init__(attribute, *possible_values) + + def expr(self, item): + for possible_value in self.possible_values: + if self.attr.expr(item) == possible_value.expr(item): + return True + + return False OP_CLASS = { diff --git a/moto/dynamodb2/condition.py b/moto/dynamodb2/condition.py new file mode 100644 index 00000000..b50678e2 --- /dev/null +++ b/moto/dynamodb2/condition.py @@ -0,0 +1,617 @@ +import re +import json +import enum +from collections import deque +from collections import namedtuple + + +class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + +class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + +Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values + + def parse(self): + """Returns a syntax tree for the expression. + + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises AssertionError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return None + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + self._print_debug(nodes) + nodes = self._apply_comparator(nodes) + self._print_debug(nodes) + nodes = self._apply_in(nodes) + self._print_debug(nodes) + nodes = self._apply_between(nodes) + self._print_debug(nodes) + nodes = self._apply_functions(nodes) + self._print_debug(nodes) + nodes = self._apply_parens_and_booleans(nodes) + self._print_debug(nodes) + node = nodes[0] + return self._make_node_tree(node) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + Nonterminal.COMPARATOR, re.compile( + '^(' + '=|' + '<>|' + '<|' + '<=|' + '>|' + '>=)'), + ), ( + Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + ')*') + ), ( + Nonterminal.COMMA, re.compile('^,') + ), ( + Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break + else: + raise AssertionError("Cannot parse condition starting at: " + + remaining_expression) + + value = match_text + node = Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) + + remaining_expression = remaining_expression[len(match_text):] + + return node, remaining_expression + + def _parse_paths(self, nodes): + output = deque() + + while nodes: + node = nodes.popleft() + + if node.nonterminal == Nonterminal.OPERAND: + path = node.value.split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output + + def _parse_path_element(self, name): + reserved = { + 'AND': Nonterminal.AND, + 'OR': Nonterminal.OR, + 'IN': Nonterminal.IN, + 'BETWEEN': Nonterminal.BETWEEN, + 'NOT': Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name in reserved: + nonterminal = reserved[name] + return Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + return Node( + nonterminal=Nonterminal.FUNCTION_NAME, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + return Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + else: + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] + + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] + + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) + else: + output.append(nodes.popleft()) + return output + + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) + else: + output.append(nodes.popleft()) + return output + + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + expected_argument_kind_map = { + 'attribute_exists': [{Kind.PATH}], + 'attribute_not_exists': [{Kind.PATH}], + 'attribute_type': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'contains': [{Kind.PATH}, {Kind.PATH, Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'size': [{Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) + else: + output.append(nodes.popleft()) + return output + + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.NOT, + text=" ".join([not_node['text'], value['text']]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_node_tree(self, node): + if len(node.children) > 0: + return ( + node.kind.name, + [ + self._make_node_tree(child) + for child in node.children + ]) + else: + return (node.kind.name, node.value) + + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise AssertionError(message + " " + " ".join([t.text for t in nodes])) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 6bcde41b..300479e9 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -68,10 +68,34 @@ class DynamoType(object): except ValueError: return float(self.value) elif self.is_set(): - return set(self.value) + sub_type = self.type[0] + return set([DynamoType({sub_type: v}).cast_value + for v in self.value]) + elif self.is_list(): + return [DynamoType(v).cast_value for v in self.value] + elif self.is_map(): + return dict([ + (k, DynamoType(v).cast_value) + for k, v in self.value.items()]) else: return self.value + def child_attr(self, key): + """ + Get Map or List children by key. str for Map, int for List. + + Returns DynamoType or None. + """ + if isinstance(key, str) and self.is_map() and key in self.value: + return DynamoType(self.value[key]) + + if isinstance(key, int) and self.is_list(): + idx = key + if idx >= 0 and idx < len(self.value): + return DynamoType(self.value[idx]) + + return None + def to_json(self): return {self.type: self.value} @@ -89,6 +113,12 @@ class DynamoType(object): def is_set(self): return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + def is_list(self): + return self.type == 'L' + + def is_map(self): + return self.type == 'M' + def same_type(self, other): return self.type == other.type @@ -954,10 +984,7 @@ class DynamoDBBackend(BaseBackend): range_values = [DynamoType(range_value) for range_value in range_value_dicts] - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.query(hash_key, range_comparison, range_values, limit, exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) @@ -972,10 +999,8 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 77846de0..932139ee 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -676,44 +676,47 @@ def test_filter_expression(): filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { + ':v0': {'N': '7'}, + ':v1': {'N': '8'}, + ':v2': {'N': '9'}}) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -754,14 +757,26 @@ def test_query_filter(): TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) client.put_item( TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} + 'app': {'S': 'app2'}, + 'nested': {'M': { + 'version': {'S': 'version2'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) @@ -783,6 +798,18 @@ def test_query_filter(): ) assert response['Count'] == 2 + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.version').contains('version') + ) + assert response['Count'] == 2 + + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.contents[0]').eq('value1') + ) + assert response['Count'] == 2 + @mock_dynamodb2 def test_scan_filter(): @@ -1061,7 +1088,7 @@ def test_delete_item(): with assert_raises(ClientError) as ex: table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_NEW') - + # Test deletion and returning old value response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') response['Attributes'].should.contain('client') @@ -1364,7 +1391,7 @@ def test_put_return_attributes(): ReturnValues='NONE' ) assert 'Attributes' not in r - + r = dynamodb.put_item( TableName='moto-test', Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, @@ -1381,7 +1408,7 @@ def test_put_return_attributes(): ex.exception.response['Error']['Code'].should.equal('ValidationException') ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') - + @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): @@ -1489,7 +1516,7 @@ def test_dynamodb_streams_1(): 'StreamViewType': 'NEW_AND_OLD_IMAGES' } ) - + assert 'StreamSpecification' in resp['TableDescription'] assert resp['TableDescription']['StreamSpecification'] == { 'StreamEnabled': True, @@ -1497,11 +1524,11 @@ def test_dynamodb_streams_1(): } assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + resp = conn.delete_table(TableName='test-streams') assert 'StreamSpecification' in resp['TableDescription'] - + @mock_dynamodb2 def test_dynamodb_streams_2(): @@ -1532,7 +1559,7 @@ def test_dynamodb_streams_2(): assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1') From 271265451822d7ca7c1a9d68386f84b5b7323aeb Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 16:23:49 -0400 Subject: [PATCH 09/24] Using Ops for dynamodb expected dicts --- moto/dynamodb2/comparisons.py | 122 ++++++++++++++++++++++++++-------- moto/dynamodb2/models.py | 52 ++------------- 2 files changed, 101 insertions(+), 73 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index ac78d45b..06d99260 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -19,6 +19,63 @@ def get_filter_expression(expr, names, values): return parser.parse() +def get_expected(expected): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + ops = { + 'EQ': OpEqual, + 'NE': OpNotEqual, + 'LE': OpLessThanOrEqual, + 'LT': OpLessThan, + 'GE': OpGreaterThanOrEqual, + 'GT': OpGreaterThan, + 'NOT_NULL': FuncAttrExists, + 'NULL': FuncAttrNotExists, + 'CONTAINS': FuncContains, + 'NOT_CONTAINS': FuncNotContains, + 'BEGINS_WITH': FuncBeginsWith, + 'IN': FuncIn, + 'BETWEEN': FuncBetween, + } + + # NOTE: Always uses ConditionalOperator=AND + conditions = [] + for key, cond in expected.items(): + path = AttributePath([key]) + if 'Exists' in cond: + if cond['Exists']: + conditions.append(FuncAttrExists(path)) + else: + conditions.append(FuncAttrNotExists(path)) + elif 'Value' in cond: + conditions.append(OpEqual(path, AttributeValue(cond['Value']))) + elif 'ComparisonOperator' in cond: + operator_name = cond['ComparisonOperator'] + values = [ + AttributeValue(v) + for v in cond.get("AttributeValueList", [])] + print(path, values) + OpClass = ops[operator_name] + conditions.append(OpClass(path, *values)) + + # NOTE: Ignore ConditionalOperator + ConditionalOp = OpAnd + if conditions: + output = conditions[0] + for condition in conditions[1:]: + output = ConditionalOp(output, condition) + else: + return OpDefault(None, None) + + print("EXPECTED:", expected, output) + return output + + class Op(object): """ Base class for a FilterExpression operator @@ -782,14 +839,19 @@ class AttributePath(Operand): self.path = path def _get_attr(self, item): + if item is None: + return None + base = self.path[0] if base not in item.attrs: return None attr = item.attrs[base] + for name in self.path[1:]: attr = attr.child_attr(name) if attr is None: return None + return attr def expr(self, item): @@ -807,7 +869,7 @@ class AttributePath(Operand): return attr.type def __repr__(self): - return self.path + return ".".join(self.path) class AttributeValue(Operand): @@ -821,23 +883,27 @@ class AttributeValue(Operand): """ self.type = list(value.keys())[0] - if 'N' in value: - self.value = float(value['N']) - elif 'BOOL' in value: - self.value = value['BOOL'] - elif 'S' in value: - self.value = value['S'] - elif 'NS' in value: - self.value = tuple(value['NS']) - elif 'SS' in value: - self.value = tuple(value['SS']) - elif 'L' in value: - self.value = tuple(value['L']) - else: - # TODO: Handle all attribute types - raise NotImplementedError() + self.value = value[self.type] def expr(self, item): + # TODO: Reuse DynamoType code + if self.type == 'N': + try: + return int(self.value) + except ValueError: + return float(self.value) + elif self.type in ['SS', 'NS', 'BS']: + sub_type = self.type[0] + return set([AttributeValue({sub_type: v}).expr(item) + for v in self.value]) + elif self.type == 'L': + return [AttributeValue(v).expr(item) for v in self.value] + elif self.type == 'M': + return dict([ + (k, AttributeValue(v).expr(item)) + for k, v in self.value.items()]) + else: + return self.value return self.value def get_type(self, item): @@ -976,15 +1042,8 @@ class FuncAttrExists(Func): return self.attr.get_type(item) is not None -class FuncAttrNotExists(Func): - FUNC = 'attribute_not_exists' - - def __init__(self, attribute): - self.attr = attribute - super().__init__(attribute) - - def expr(self, item): - return self.attr.get_type(item) is None +def FuncAttrNotExists(attribute): + return OpNot(FuncAttrExists(attribute), None) class FuncAttrType(Func): @@ -1024,13 +1083,20 @@ class FuncContains(Func): super().__init__(attribute, operand) def expr(self, item): - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand.expr(item) in self.attr.expr(item) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + try: + return self.operand.expr(item) in self.attr.expr(item) + except TypeError: + return False return False +def FuncNotContains(attribute, operand): + return OpNot(FuncContains(attribute, operand), None) + + class FuncSize(Func): - FUNC = 'contains' + FUNC = 'size' def __init__(self, attribute): self.attr = attribute diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 300479e9..bdf59df1 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -13,6 +13,9 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError from .comparisons import get_comparison_func, get_filter_expression, Op +from .comparisons import get_comparison_func +from .comparisons import get_filter_expression +from .comparisons import get_expected from .exceptions import InvalidIndexNameError @@ -557,29 +560,9 @@ class Table(BaseModel): self.range_key_type, item_attrs) if not overwrite: - if current is None: - current_attr = {} - elif hasattr(current, 'attrs'): - current_attr = current.attrs - else: - current_attr = current + if not get_expected(expected).expr(current): + raise ValueError('The conditional request failed') - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in current_attr: - raise ValueError("The conditional request failed") - elif key not in current_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item else: @@ -1024,32 +1007,11 @@ class DynamoDBBackend(BaseBackend): item = table.get_item(hash_value, range_value) - if item is None: - item_attr = {} - elif hasattr(item, 'attrs'): - item_attr = item.attrs - else: - item_attr = item - if not expected: expected = {} - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in item_attr: - raise ValueError("The conditional request failed") - elif key not in item_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') + if not get_expected(expected).expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: From 57b668c8323761b173276607f1263863207ce053 Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 16:48:00 -0400 Subject: [PATCH 10/24] Using Ops for dynamodb condition expressions --- moto/dynamodb2/comparisons.py | 16 ++++++-------- moto/dynamodb2/models.py | 26 ++++++++++++++++++---- moto/dynamodb2/responses.py | 32 ++++++++++++--------------- tests/test_dynamodb2/test_dynamodb.py | 15 +++++++++++++ 4 files changed, 58 insertions(+), 31 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 06d99260..4095acba 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -59,7 +59,6 @@ def get_expected(expected): values = [ AttributeValue(v) for v in cond.get("AttributeValueList", [])] - print(path, values) OpClass = ops[operator_name] conditions.append(OpClass(path, *values)) @@ -72,7 +71,6 @@ def get_expected(expected): else: return OpDefault(None, None) - print("EXPECTED:", expected, output) return output @@ -486,7 +484,7 @@ class ConditionExpressionParser: lhs = nodes.popleft() comparator = nodes.popleft() rhs = nodes.popleft() - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.COMPARISON, text=" ".join([ @@ -528,7 +526,7 @@ class ConditionExpressionParser: self._assert( False, "Bad IN expression starting at", nodes) - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.IN, text=" ".join([t.text for t in all_children]), @@ -553,7 +551,7 @@ class ConditionExpressionParser: and_node = nodes.popleft() high = nodes.popleft() all_children = [lhs, between_node, low, and_node, high] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.BETWEEN, text=" ".join([t.text for t in all_children]), @@ -613,7 +611,7 @@ class ConditionExpressionParser: nonterminal = self.Nonterminal.OPERAND else: nonterminal = self.Nonterminal.CONDITION - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=nonterminal, kind=self.Kind.FUNCTION, text=" ".join([t.text for t in all_children]), @@ -685,7 +683,7 @@ class ConditionExpressionParser: "Bad NOT expression", list(nodes)[:2]) not_node = nodes.popleft() child = nodes.popleft() - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.NOT, text=" ".join([not_node.text, child.text]), @@ -708,7 +706,7 @@ class ConditionExpressionParser: and_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, and_node, rhs] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.AND, text=" ".join([t.text for t in all_children]), @@ -731,7 +729,7 @@ class ConditionExpressionParser: or_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, or_node, rhs] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.OR, text=" ".join([t.text for t in all_children]), diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index bdf59df1..037db3d7 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -537,7 +537,9 @@ class Table(BaseModel): keys.append(range_key) return keys - def put_item(self, item_attrs, expected=None, overwrite=False): + def put_item(self, item_attrs, expected=None, condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -562,6 +564,12 @@ class Table(BaseModel): if not overwrite: if not get_expected(expected).expr(current): raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(current): + raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item @@ -907,11 +915,15 @@ class DynamoDBBackend(BaseBackend): table.global_indexes = list(gsis_by_name.values()) return table - def put_item(self, table_name, item_attrs, expected=None, overwrite=False): + def put_item(self, table_name, item_attrs, expected=None, + condition_expression=None, expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs, expected, overwrite) + return table.put_item(item_attrs, expected, condition_expression, + expression_attribute_names, + expression_attribute_values, overwrite) def get_table_keys_name(self, table_name, keys): """ @@ -988,7 +1000,7 @@ class DynamoDBBackend(BaseBackend): return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected=None): + expression_attribute_values, expected=None, condition_expression=None): table = self.get_table(table_name) if all([table.hash_key_attr in key, table.range_key_attr in key]): @@ -1012,6 +1024,12 @@ class DynamoDBBackend(BaseBackend): if not get_expected(expected).expr(item): raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(current): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 7eb56574..13dde683 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -288,18 +288,18 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) - if expected: - overwrite = False + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + + if condition_expression: + overwrite = False try: - result = self.dynamodb_backend.put_item(name, item, expected, overwrite) + result = self.dynamodb_backend.put_item( + name, item, expected, condition_expression, + expression_attribute_names, expression_attribute_values, + overwrite) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' return self.error(er, 'A condition specified in the operation could not be evaluated.') @@ -652,13 +652,9 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) # Support spaces between operators in an update expression # E.g. `a = b + c` -> `a=b+c` @@ -669,7 +665,7 @@ class DynamoHandler(BaseResponse): try: item = self.dynamodb_backend.update_item( name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected + expression_attribute_values, expected, condition_expression ) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 932139ee..f87e84fb 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1616,6 +1616,21 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)', + ExpressionAttributeNames={ + '#nonexistent': 'nope', + '#existing': 'existing' + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 6fd47f843fb046305d6379b4f791dbe76569f87a Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:00:02 -0400 Subject: [PATCH 11/24] Test case for #1819 --- tests/test_dynamodb2/test_dynamodb.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index f87e84fb..d2178205 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1631,6 +1631,24 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client BETWEEN :a AND :z', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':a': {'S': 'a'}, + ':z': {'S': 'z'}, + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 8a90971ba152a692ac9f17ef346739630136e6ad Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:02:14 -0400 Subject: [PATCH 12/24] Adding test cases for #1587 --- tests/test_dynamodb2/test_dynamodb.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index d2178205..0ea1d64e 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1649,6 +1649,24 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client IN (:client1, :client2)', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':client1': {'S': 'client1'}, + ':client2': {'S': 'client2'}, + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 94503285274e670acba6ac441539aa7153d6915d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:03:58 -0400 Subject: [PATCH 13/24] Deleting unnecessary dynamodb2 file --- moto/dynamodb2/condition.py | 617 ------------------------------------ 1 file changed, 617 deletions(-) delete mode 100644 moto/dynamodb2/condition.py diff --git a/moto/dynamodb2/condition.py b/moto/dynamodb2/condition.py deleted file mode 100644 index b50678e2..00000000 --- a/moto/dynamodb2/condition.py +++ /dev/null @@ -1,617 +0,0 @@ -import re -import json -import enum -from collections import deque -from collections import namedtuple - - -class Kind(enum.Enum): - """Defines types of nodes in the syntax tree.""" - - # Condition nodes - # --------------- - OR = enum.auto() - AND = enum.auto() - NOT = enum.auto() - PARENTHESES = enum.auto() - FUNCTION = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMPARISON = enum.auto() - - # Operand nodes - # ------------- - EXPRESSION_ATTRIBUTE_VALUE = enum.auto() - PATH = enum.auto() - - # Literal nodes - # -------------- - LITERAL = enum.auto() - - -class Nonterminal(enum.Enum): - """Defines nonterminals for defining productions.""" - CONDITION = enum.auto() - OPERAND = enum.auto() - COMPARATOR = enum.auto() - FUNCTION_NAME = enum.auto() - IDENTIFIER = enum.auto() - AND = enum.auto() - OR = enum.auto() - NOT = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMMA = enum.auto() - LEFT_PAREN = enum.auto() - RIGHT_PAREN = enum.auto() - WHITESPACE = enum.auto() - - -Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) - - -class ConditionExpressionParser: - def __init__(self, condition_expression, expression_attribute_names, - expression_attribute_values): - self.condition_expression = condition_expression - self.expression_attribute_names = expression_attribute_names - self.expression_attribute_values = expression_attribute_values - - def parse(self): - """Returns a syntax tree for the expression. - - The tree, and all of the nodes in the tree are a tuple of - - kind: str - - children/value: - list of nodes for parent nodes - value for leaf nodes - - Raises AssertionError if the condition expression is invalid - Raises KeyError if expression attribute names/values are invalid - - Here are the types of nodes that can be returned. - The types of child nodes are denoted with a colon (:). - An arbitrary number of children is denoted with ... - - Condition: - ('OR', [lhs : Condition, rhs : Condition]) - ('AND', [lhs: Condition, rhs: Condition]) - ('NOT', [argument: Condition]) - ('PARENTHESES', [argument: Condition]) - ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) - ('BETWEEN', [query: Operand, low: Operand, high: Operand]) - ('IN', [query: Operand, possible_value: Operand, ...]) - ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) - - Operand: - ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) - ('PATH', [('LITERAL', path_element: str), ...]) - NOTE: Expression attribute names will be expanded - - Literal: - ('LITERAL', value: str) - - """ - if not self.condition_expression: - return None - nodes = self._lex_condition_expression() - nodes = self._parse_paths(nodes) - self._print_debug(nodes) - nodes = self._apply_comparator(nodes) - self._print_debug(nodes) - nodes = self._apply_in(nodes) - self._print_debug(nodes) - nodes = self._apply_between(nodes) - self._print_debug(nodes) - nodes = self._apply_functions(nodes) - self._print_debug(nodes) - nodes = self._apply_parens_and_booleans(nodes) - self._print_debug(nodes) - node = nodes[0] - return self._make_node_tree(node) - - def _lex_condition_expression(self): - nodes = deque() - remaining_expression = self.condition_expression - while remaining_expression: - node, remaining_expression = \ - self._lex_one_node(remaining_expression) - if node.nonterminal == Nonterminal.WHITESPACE: - continue - nodes.append(node) - return nodes - - def _lex_one_node(self, remaining_expression): - - attribute_regex = '(:|#)?[A-z0-9\-_]+' - patterns = [( - Nonterminal.WHITESPACE, re.compile('^ +') - ), ( - Nonterminal.COMPARATOR, re.compile( - '^(' - '=|' - '<>|' - '<|' - '<=|' - '>|' - '>=)'), - ), ( - Nonterminal.OPERAND, re.compile( - '^' + - attribute_regex + '(\.' + attribute_regex + ')*') - ), ( - Nonterminal.COMMA, re.compile('^,') - ), ( - Nonterminal.LEFT_PAREN, re.compile('^\(') - ), ( - Nonterminal.RIGHT_PAREN, re.compile('^\)') - )] - - for nonterminal, pattern in patterns: - match = pattern.match(remaining_expression) - if match: - match_text = match.group() - break - else: - raise AssertionError("Cannot parse condition starting at: " + - remaining_expression) - - value = match_text - node = Node( - nonterminal=nonterminal, - kind=Kind.LITERAL, - text=match_text, - value=match_text, - children=[]) - - remaining_expression = remaining_expression[len(match_text):] - - return node, remaining_expression - - def _parse_paths(self, nodes): - output = deque() - - while nodes: - node = nodes.popleft() - - if node.nonterminal == Nonterminal.OPERAND: - path = node.value.split('.') - children = [ - self._parse_path_element(name) - for name in path] - if len(children) == 1: - child = children[0] - if child.nonterminal != Nonterminal.IDENTIFIER: - output.append(child) - continue - else: - for child in children: - self._assert( - child.nonterminal == Nonterminal.IDENTIFIER, - "Cannot use %s in path" % child.text, [node]) - output.append(Node( - nonterminal=Nonterminal.OPERAND, - kind=Kind.PATH, - text=node.text, - value=None, - children=children)) - else: - output.append(node) - return output - - def _parse_path_element(self, name): - reserved = { - 'AND': Nonterminal.AND, - 'OR': Nonterminal.OR, - 'IN': Nonterminal.IN, - 'BETWEEN': Nonterminal.BETWEEN, - 'NOT': Nonterminal.NOT, - } - - functions = { - 'attribute_exists', - 'attribute_not_exists', - 'attribute_type', - 'begins_with', - 'contains', - 'size', - } - - - if name in reserved: - nonterminal = reserved[name] - return Node( - nonterminal=nonterminal, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - elif name in functions: - return Node( - nonterminal=Nonterminal.FUNCTION_NAME, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - elif name.startswith(':'): - return Node( - nonterminal=Nonterminal.OPERAND, - kind=Kind.EXPRESSION_ATTRIBUTE_VALUE, - text=name, - value=self._lookup_expression_attribute_value(name), - children=[]) - elif name.startswith('#'): - return Node( - nonterminal=Nonterminal.IDENTIFIER, - kind=Kind.LITERAL, - text=name, - value=self._lookup_expression_attribute_name(name), - children=[]) - else: - return Node( - nonterminal=Nonterminal.IDENTIFIER, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - - def _lookup_expression_attribute_value(self, name): - return self.expression_attribute_values[name] - - def _lookup_expression_attribute_name(self, name): - return self.expression_attribute_names[name] - - # NOTE: The following constructions are ordered from high precedence to low precedence - # according to - # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence - # - # = <> < <= > >= - # IN - # BETWEEN - # attribute_exists attribute_not_exists begins_with contains - # Parentheses - # NOT - # AND - # OR - # - # The grammar is taken from - # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax - # - # condition-expression ::= - # operand comparator operand - # operand BETWEEN operand AND operand - # operand IN ( operand (',' operand (, ...) )) - # function - # condition AND condition - # condition OR condition - # NOT condition - # ( condition ) - # - # comparator ::= - # = - # <> - # < - # <= - # > - # >= - # - # function ::= - # attribute_exists (path) - # attribute_not_exists (path) - # attribute_type (path, type) - # begins_with (path, substr) - # contains (path, operand) - # size (path) - - def _matches(self, nodes, production): - """Check if the nodes start with the given production. - - Parameters - ---------- - nodes: list of Node - production: list of str - The name of a Nonterminal, or '*' for anything - - """ - if len(nodes) < len(production): - return False - for i in range(len(production)): - if production[i] == '*': - continue - expected = getattr(Nonterminal, production[i]) - if nodes[i].nonterminal != expected: - return False - return True - - def _apply_comparator(self, nodes): - """Apply condition := operand comparator operand.""" - output = deque() - - while nodes: - if self._matches(nodes, ['*', 'COMPARATOR']): - self._assert( - self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), - "Bad comparison", list(nodes)[:3]) - lhs = nodes.popleft() - comparator = nodes.popleft() - rhs = nodes.popleft() - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.COMPARISON, - text=" ".join([ - lhs.text, - comparator.text, - rhs.text]), - value=None, - children=[lhs, comparator, rhs])) - else: - output.append(nodes.popleft()) - return output - - def _apply_in(self, nodes): - """Apply condition := operand IN ( operand , ... ).""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'IN']): - self._assert( - self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), - "Bad IN expression", list(nodes)[:3]) - lhs = nodes.popleft() - in_node = nodes.popleft() - left_paren = nodes.popleft() - all_children = [lhs, in_node, left_paren] - rhs = [] - while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - rhs.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - rhs.append(operand) - break # Close - else: - self._assert( - False, - "Bad IN expression starting at", nodes) - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.IN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs] + rhs)) - else: - output.append(nodes.popleft()) - return output - - def _apply_between(self, nodes): - """Apply condition := operand BETWEEN operand AND operand.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'BETWEEN']): - self._assert( - self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', - 'AND', 'OPERAND']), - "Bad BETWEEN expression", list(nodes)[:5]) - lhs = nodes.popleft() - between_node = nodes.popleft() - low = nodes.popleft() - and_node = nodes.popleft() - high = nodes.popleft() - all_children = [lhs, between_node, low, and_node, high] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.BETWEEN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, low, high])) - else: - output.append(nodes.popleft()) - return output - - def _apply_functions(self, nodes): - """Apply condition := function_name (operand , ...).""" - output = deque() - expected_argument_kind_map = { - 'attribute_exists': [{Kind.PATH}], - 'attribute_not_exists': [{Kind.PATH}], - 'attribute_type': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'begins_with': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'contains': [{Kind.PATH}, {Kind.PATH, Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'size': [{Kind.PATH}], - } - while nodes: - if self._matches(nodes, ['FUNCTION_NAME']): - self._assert( - self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', - 'OPERAND', '*']), - "Bad function expression at", list(nodes)[:4]) - function_name = nodes.popleft() - left_paren = nodes.popleft() - all_children = [function_name, left_paren] - arguments = [] - while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - arguments.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - arguments.append(operand) - break # Close paren - else: - self._assert( - False, - "Bad function expression", all_children + list(nodes)[:2]) - expected_kinds = expected_argument_kind_map[function_name.value] - self._assert( - len(arguments) == len(expected_kinds), - "Wrong number of arguments in", all_children) - for i in range(len(expected_kinds)): - self._assert( - arguments[i].kind in expected_kinds[i], - "Wrong type for argument %d in" % i, all_children) - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.FUNCTION, - text=" ".join([t.text for t in all_children]), - value=None, - children=[function_name] + arguments)) - else: - output.append(nodes.popleft()) - return output - - def _apply_parens_and_booleans(self, nodes, left_paren=None): - """Apply condition := ( condition ) and booleans.""" - output = deque() - while nodes: - if self._matches(nodes, ['LEFT_PAREN']): - parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) - self._assert( - len(parsed) >= 1, - "Failed to close parentheses at", nodes) - parens = parsed.popleft() - self._assert( - parens.kind == Kind.PARENTHESES, - "Failed to close parentheses at", nodes) - output.append(parens) - nodes = parsed - elif self._matches(nodes, ['RIGHT_PAREN']): - self._assert( - left_paren is not None, - "Unmatched ) at", nodes) - close_paren = nodes.popleft() - children = self._apply_booleans(output) - all_children = [left_paren, *children, close_paren] - return deque([ - Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.PARENTHESES, - text=" ".join([t.text for t in all_children]), - value=None, - children=list(children), - ), *nodes]) - else: - output.append(nodes.popleft()) - - self._assert( - left_paren is None, - "Unmatched ( at", list(output)) - return self._apply_booleans(output) - - def _apply_booleans(self, nodes): - """Apply and, or, and not constructions.""" - nodes = self._apply_not(nodes) - nodes = self._apply_and(nodes) - nodes = self._apply_or(nodes) - # The expression should reduce to a single condition - self._assert( - len(nodes) == 1, - "Unexpected expression at", list(nodes)[1:]) - self._assert( - nodes[0].nonterminal == Nonterminal.CONDITION, - "Incomplete condition", nodes) - return nodes - - def _apply_not(self, nodes): - """Apply condition := NOT condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['NOT']): - self._assert( - self._matches(nodes, ['NOT', 'CONDITION']), - "Bad NOT expression", list(nodes)[:2]) - not_node = nodes.popleft() - child = nodes.popleft() - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.NOT, - text=" ".join([not_node['text'], value['text']]), - value=None, - children=[child])) - else: - output.append(nodes.popleft()) - - return output - - def _apply_and(self, nodes): - """Apply condition := condition AND condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'AND']): - self._assert( - self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), - "Bad AND expression", list(nodes)[:3]) - lhs = nodes.popleft() - and_node = nodes.popleft() - rhs = nodes.popleft() - all_children = [lhs, and_node, rhs] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.AND, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) - else: - output.append(nodes.popleft()) - - return output - - def _apply_or(self, nodes): - """Apply condition := condition OR condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'OR']): - self._assert( - self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), - "Bad OR expression", list(nodes)[:3]) - lhs = nodes.popleft() - or_node = nodes.popleft() - rhs = nodes.popleft() - all_children = [lhs, or_node, rhs] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.OR, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) - else: - output.append(nodes.popleft()) - - return output - - def _make_node_tree(self, node): - if len(node.children) > 0: - return ( - node.kind.name, - [ - self._make_node_tree(child) - for child in node.children - ]) - else: - return (node.kind.name, node.value) - - def _print_debug(self, nodes): - print('ROOT') - for node in nodes: - self._print_node_recursive(node, depth=1) - - def _print_node_recursive(self, node, depth=0): - if len(node.children) > 0: - print(' ' * depth, node.nonterminal, node.kind) - for child in node.children: - self._print_node_recursive(child, depth=depth + 1) - else: - print(' ' * depth, node.nonterminal, node.kind, node.value) - - - - def _assert(self, condition, message, nodes): - if not condition: - raise AssertionError(message + " " + " ".join([t.text for t in nodes])) From 6303d07bac24021ecd0008e78e6a39ab6745d074 Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Fri, 12 Apr 2019 10:13:36 -0400 Subject: [PATCH 14/24] Fixing tests --- moto/dynamodb2/comparisons.py | 125 ++++++++++++++++------------------ moto/dynamodb2/models.py | 16 ++--- 2 files changed, 67 insertions(+), 74 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 4095acba..1a4633e6 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals import re import six import re -import enum from collections import deque from collections import namedtuple @@ -199,46 +198,47 @@ class ConditionExpressionParser: op = self._make_op_condition(node) return op - class Kind(enum.Enum): - """Defines types of nodes in the syntax tree.""" + class Kind: + """Enum defining types of nodes in the syntax tree.""" # Condition nodes # --------------- - OR = enum.auto() - AND = enum.auto() - NOT = enum.auto() - PARENTHESES = enum.auto() - FUNCTION = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMPARISON = enum.auto() + OR = 'OR' + AND = 'AND' + NOT = 'NOT' + PARENTHESES = 'PARENTHESES' + FUNCTION = 'FUNCTION' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMPARISON = 'COMPARISON' # Operand nodes # ------------- - EXPRESSION_ATTRIBUTE_VALUE = enum.auto() - PATH = enum.auto() + EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE' + PATH = 'PATH' # Literal nodes # -------------- - LITERAL = enum.auto() + LITERAL = 'LITERAL' - class Nonterminal(enum.Enum): - """Defines nonterminals for defining productions.""" - CONDITION = enum.auto() - OPERAND = enum.auto() - COMPARATOR = enum.auto() - FUNCTION_NAME = enum.auto() - IDENTIFIER = enum.auto() - AND = enum.auto() - OR = enum.auto() - NOT = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMMA = enum.auto() - LEFT_PAREN = enum.auto() - RIGHT_PAREN = enum.auto() - WHITESPACE = enum.auto() + class Nonterminal: + """Enum defining nonterminals for productions.""" + + CONDITION = 'CONDITION' + OPERAND = 'OPERAND' + COMPARATOR = 'COMPARATOR' + FUNCTION_NAME = 'FUNCTION_NAME' + IDENTIFIER = 'IDENTIFIER' + AND = 'AND' + OR = 'OR' + NOT = 'NOT' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMMA = 'COMMA' + LEFT_PAREN = 'LEFT_PAREN' + RIGHT_PAREN = 'RIGHT_PAREN' + WHITESPACE = 'WHITESPACE' Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) @@ -286,7 +286,7 @@ class ConditionExpressionParser: if match: match_text = match.group() break - else: + else: # pragma: no cover raise ValueError("Cannot parse condition starting at: " + remaining_expression) @@ -387,7 +387,7 @@ class ConditionExpressionParser: children=[]) elif name.startswith('['): # e.g. [123] - if not name.endswith(']'): + if not name.endswith(']'): # pragma: no cover raise ValueError("Bad path element %s" % name) return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, @@ -642,7 +642,7 @@ class ConditionExpressionParser: "Unmatched ) at", nodes) close_paren = nodes.popleft() children = self._apply_booleans(output) - all_children = [left_paren, *children, close_paren] + all_children = [left_paren] + list(children) + [close_paren] return deque([ self.Node( nonterminal=self.Nonterminal.CONDITION, @@ -650,7 +650,7 @@ class ConditionExpressionParser: text=" ".join([t.text for t in all_children]), value=None, children=list(children), - ), *nodes]) + )] + list(nodes)) else: output.append(nodes.popleft()) @@ -747,11 +747,12 @@ class ConditionExpressionParser: return AttributeValue(node.value) elif node.kind == self.Kind.FUNCTION: # size() - function_node, *arguments = node.children + function_node = node.children[0] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) - else: + else: # pragma: no cover raise ValueError("Unknown operand: %r" % node) @@ -768,12 +769,13 @@ class ConditionExpressionParser: self._make_op_condition(rhs)) elif node.kind == self.Kind.NOT: child, = node.children - return OpNot(self._make_op_condition(child), None) + return OpNot(self._make_op_condition(child)) elif node.kind == self.Kind.PARENTHESES: child, = node.children return self._make_op_condition(child) elif node.kind == self.Kind.FUNCTION: - function_node, *arguments = node.children + function_node = node.children[0] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) @@ -784,24 +786,25 @@ class ConditionExpressionParser: self._make_operand(low), self._make_operand(high)) elif node.kind == self.Kind.IN: - query, *possible_values = node.children + query = node.children[0] + possible_values = node.children[1:] query = self._make_operand(query) possible_values = [self._make_operand(v) for v in possible_values] return FuncIn(query, *possible_values) elif node.kind == self.Kind.COMPARISON: lhs, comparator, rhs = node.children - return OP_CLASS[comparator.value]( + return COMPARATOR_CLASS[comparator.value]( self._make_operand(lhs), self._make_operand(rhs)) - else: + else: # pragma: no cover raise ValueError("Unknown expression node kind %r" % node.kind) - def _print_debug(self, nodes): + def _print_debug(self, nodes): # pragma: no cover print('ROOT') for node in nodes: self._print_node_recursive(node, depth=1) - def _print_node_recursive(self, node, depth=0): + def _print_node_recursive(self, node, depth=0): # pragma: no cover if len(node.children) > 0: print(' ' * depth, node.nonterminal, node.kind) for child in node.children: @@ -922,6 +925,9 @@ class OpDefault(Op): class OpNot(Op): OP = 'NOT' + def __init__(self, lhs): + super(OpNot, self).__init__(lhs, None) + def expr(self, item): lhs = self.lhs.expr(item) return not lhs @@ -1002,15 +1008,6 @@ class OpOr(Op): return lhs or rhs -class OpIn(Op): - OP = 'IN' - - def expr(self, item): - lhs = self.lhs.expr(item) - rhs = self.rhs.expr(item) - return lhs in rhs - - class Func(object): """ Base class for a FilterExpression function @@ -1034,14 +1031,14 @@ class FuncAttrExists(Func): def __init__(self, attribute): self.attr = attribute - super().__init__(attribute) + super(FuncAttrExists, self).__init__(attribute) def expr(self, item): return self.attr.get_type(item) is not None def FuncAttrNotExists(attribute): - return OpNot(FuncAttrExists(attribute), None) + return OpNot(FuncAttrExists(attribute)) class FuncAttrType(Func): @@ -1050,7 +1047,7 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type - super().__init__(attribute, _type) + super(FuncAttrType, self).__init__(attribute, _type) def expr(self, item): return self.attr.get_type(item) == self.type.expr(item) @@ -1062,7 +1059,7 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr - super().__init__(attribute, substr) + super(FuncBeginsWith, self).__init__(attribute, substr) def expr(self, item): if self.attr.get_type(item) != 'S': @@ -1078,7 +1075,7 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand - super().__init__(attribute, operand) + super(FuncContains, self).__init__(attribute, operand) def expr(self, item): if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): @@ -1090,7 +1087,7 @@ class FuncContains(Func): def FuncNotContains(attribute, operand): - return OpNot(FuncContains(attribute, operand), None) + return OpNot(FuncContains(attribute, operand)) class FuncSize(Func): @@ -1098,7 +1095,7 @@ class FuncSize(Func): def __init__(self, attribute): self.attr = attribute - super().__init__(attribute) + super(FuncSize, self).__init__(attribute) def expr(self, item): if self.attr.get_type(item) is None: @@ -1116,7 +1113,7 @@ class FuncBetween(Func): self.attr = attribute self.start = start self.end = end - super().__init__(attribute, start, end) + super(FuncBetween, self).__init__(attribute, start, end) def expr(self, item): return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) @@ -1128,7 +1125,7 @@ class FuncIn(Func): def __init__(self, attribute, *possible_values): self.attr = attribute self.possible_values = possible_values - super().__init__(attribute, *possible_values) + super(FuncIn, self).__init__(attribute, *possible_values) def expr(self, item): for possible_value in self.possible_values: @@ -1138,11 +1135,7 @@ class FuncIn(Func): return False -OP_CLASS = { - 'NOT': OpNot, - 'AND': OpAnd, - 'OR': OpOr, - 'IN': OpIn, +COMPARATOR_CLASS = { '<': OpLessThan, '>': OpGreaterThan, '<=': OpLessThanOrEqual, diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 037db3d7..1f2c6deb 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -6,6 +6,7 @@ import decimal import json import re import uuid +import six import boto3 from moto.compat import OrderedDict @@ -89,7 +90,7 @@ class DynamoType(object): Returns DynamoType or None. """ - if isinstance(key, str) and self.is_map() and key in self.value: + if isinstance(key, six.string_types) and self.is_map() and key in self.value: return DynamoType(self.value[key]) if isinstance(key, int) and self.is_list(): @@ -994,7 +995,6 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) @@ -1024,12 +1024,12 @@ class DynamoDBBackend(BaseBackend): if not get_expected(expected).expr(item): raise ValueError('The conditional request failed') - condition_op = get_filter_expression( - condition_expression, - expression_attribute_names, - expression_attribute_values) - if not condition_op.expr(current): - raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: From 83082df4d907293438c7b2cd9f622ca8da06450d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Sun, 14 Apr 2019 19:37:43 -0400 Subject: [PATCH 15/24] Adding update_item and attribute_not_exists test --- tests/test_dynamodb2/test_dynamodb.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 0ea1d64e..dc41e367 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1719,6 +1719,42 @@ def test_condition_expressions(): } ) + # Make sure update_item honors ConditionExpression as well + dynamodb.update_item( + TableName='test1', + Key={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_exists(#existing)', + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ExpressionAttributeValues={ + ':match': {'S': 'match'} + } + ) + + with assert_raises(dynamodb.exceptions.ConditionalCheckFailedException): + dynamodb.update_item( + TableName='test1', + Key={ + 'client': { 'S': 'client1'}, + 'app': { 'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_not_exists(#existing)', + ExpressionAttributeValues={ + ':match': {'S': 'match'} + }, + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ) + @mock_dynamodb2 def test_query_gsi_with_range_key(): From 467f669c1e6e48d8158a6b35474ccabe56aab3da Mon Sep 17 00:00:00 2001 From: Garrett Heel Date: Wed, 26 Jun 2019 23:13:01 +0100 Subject: [PATCH 16/24] add test for attr doesn't exist --- moto/dynamodb2/models.py | 1 - tests/test_dynamodb2/test_dynamodb.py | 54 +++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 1f2c6deb..6d3a4b95 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -13,7 +13,6 @@ from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError -from .comparisons import get_comparison_func, get_filter_expression, Op from .comparisons import get_comparison_func from .comparisons import get_filter_expression from .comparisons import get_expected diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index dc41e367..a4d79f4d 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1563,7 +1563,6 @@ def test_dynamodb_streams_2(): @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') # Create the DynamoDB table. client.create_table( @@ -1720,7 +1719,7 @@ def test_condition_expressions(): ) # Make sure update_item honors ConditionExpression as well - dynamodb.update_item( + client.update_item( TableName='test1', Key={ 'client': {'S': 'client1'}, @@ -1737,8 +1736,8 @@ def test_condition_expressions(): } ) - with assert_raises(dynamodb.exceptions.ConditionalCheckFailedException): - dynamodb.update_item( + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( TableName='test1', Key={ 'client': { 'S': 'client1'}, @@ -1756,6 +1755,53 @@ def test_condition_expressions(): ) +@mock_dynamodb2 +def test_condition_expression__attr_doesnt_exist(): + client = boto3.client('dynamodb', region_name='us-east-1') + + client.create_table( + TableName='test', + KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], + AttributeDefinitions=[ + {'AttributeName': 'forum_name', 'AttributeType': 'S'}, + ], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ) + + client.put_item( + TableName='test', + Item={ + 'forum_name': {'S': 'foo'}, + 'ttl': {'N': 'bar'}, + } + ) + + + def update_if_attr_doesnt_exist(): + # Test nonexistent top-level attribute. + client.update_item( + TableName='test', + Key={ + 'forum_name': {'S': 'the-key'}, + 'subject': {'S': 'the-subject'}, + }, + UpdateExpression='set #new_state=:new_state, #ttl=:ttl', + ConditionExpression='attribute_not_exists(#new_state)', + ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'}, + ExpressionAttributeValues={ + ':new_state': {'S': 'some-value'}, + ':ttl': {'N': '12345.67'}, + }, + ReturnValues='ALL_NEW', + ) + + update_if_attr_doesnt_exist() + + # Second time should fail + with assert_raises(client.exceptions.ConditionalCheckFailedException): + update_if_attr_doesnt_exist() + + @mock_dynamodb2 def test_query_gsi_with_range_key(): dynamodb = boto3.client('dynamodb', region_name='us-east-1') From b2adbf1f48a23612eed85359765693ba8e228cc6 Mon Sep 17 00:00:00 2001 From: Aden Khan Date: Wed, 3 Jul 2019 11:35:56 -0400 Subject: [PATCH 17/24] Adding the functionality and test so that the If-Modified-Since header is honored in GET Object Signed-off-by: Aden Khan --- moto/s3/responses.py | 7 ++++++- tests/test_s3/test_s3.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index e0366666..40449dbf 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -657,7 +657,7 @@ class ResponseObject(_TemplateEnvironmentMixin): body = b'' if method == 'GET': - return self._key_response_get(bucket_name, query, key_name, headers) + return self._key_response_get(bucket_name, query, key_name, headers=request.headers) elif method == 'PUT': return self._key_response_put(request, body, bucket_name, query, key_name, headers) elif method == 'HEAD': @@ -684,10 +684,15 @@ class ResponseObject(_TemplateEnvironmentMixin): parts=parts ) version_id = query.get('versionId', [None])[0] + if_modified_since = headers.get('If-Modified-Since', None) key = self.backend.get_key( bucket_name, key_name, version_id=version_id) if key is None: raise MissingKey(key_name) + if if_modified_since: + if_modified_since = str_to_rfc_1123_datetime(if_modified_since) + if if_modified_since and key.last_modified < if_modified_since: + return 304, response_headers, 'Not Modified' if 'acl' in query: template = self.response_template(S3_OBJECT_ACL_RESPONSE) return 200, response_headers, template.render(obj=key) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index f26964ab..697c4786 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1596,6 +1596,28 @@ def test_boto3_delete_versioned_bucket(): client.delete_bucket(Bucket='blah') +@mock_s3 +def test_boto3_get_object_if_modified_since(): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = "blah" + s3.create_bucket(Bucket=bucket_name) + + key = 'hello.txt' + + s3.put_object( + Bucket=bucket_name, + Key=key, + Body='test' + ) + + with assert_raises(botocore.exceptions.ClientError) as err: + s3.get_object( + Bucket=bucket_name, + Key=key, + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + ) + e = err.exception + e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) @mock_s3 def test_boto3_head_object_if_modified_since(): From c51ce76ee9ccd64760157d2dde8158462dd7420b Mon Sep 17 00:00:00 2001 From: Berislav Kovacki Date: Tue, 9 Jul 2019 02:10:33 +0200 Subject: [PATCH 18/24] Add InstanceCreateTime to DBInstance --- moto/rds2/models.py | 4 +++- tests/test_rds2/test_rds2.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/moto/rds2/models.py b/moto/rds2/models.py index fee004f7..498f9b12 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -70,6 +70,7 @@ class Database(BaseModel): self.port = Database.default_port(self.engine) self.db_instance_identifier = kwargs.get('db_instance_identifier') self.db_name = kwargs.get("db_name") + self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.publicly_accessible = kwargs.get("publicly_accessible") if self.publicly_accessible is None: self.publicly_accessible = True @@ -148,6 +149,7 @@ class Database(BaseModel): {{ database.db_instance_identifier }} {{ database.dbi_resource_id }} + {{ database.instance_create_time }} 03:50-04:20 wed:06:38-wed:07:08 @@ -373,7 +375,7 @@ class Database(BaseModel): "Address": "{{ database.address }}", "Port": "{{ database.port }}" }, - "InstanceCreateTime": null, + "InstanceCreateTime": "{{ database.instance_create_time }}", "Iops": null, "ReadReplicaDBInstanceIdentifiers": [{%- for replica in database.replicas -%} {%- if not loop.first -%},{%- endif -%} diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index a25b5319..cf5c9a90 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -34,6 +34,7 @@ def test_create_database(): db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) db_instance['DbiResourceId'].should.contain("db-") db_instance['CopyTagsToSnapshot'].should.equal(False) + db_instance['InstanceCreateTime'].should.be.a("datetime.datetime") @mock_rds2 From e77c4e3d09f39e6415fbbc86cfb98709480eb1a9 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Mon, 8 Jul 2019 20:42:24 -0500 Subject: [PATCH 19/24] Add getting setup to contributing. --- CONTRIBUTING.md | 4 ++++ Makefile | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f2808322..40da55cc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,6 +2,10 @@ Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project. +## Running the tests locally + +Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests. + ## Is there a missing feature? Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services. diff --git a/Makefile b/Makefile index de08c6f7..2a724976 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ endif init: @python setup.py develop - @pip install -r requirements.txt + @pip install -r requirements-dev.txt lint: flake8 moto From ba95c945f9b16f692b217f89816eae36fccab11c Mon Sep 17 00:00:00 2001 From: Garrett Heel Date: Tue, 9 Jul 2019 09:20:35 -0400 Subject: [PATCH 20/24] remove dead code --- moto/dynamodb2/responses.py | 61 ------------------------------------- 1 file changed, 61 deletions(-) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 13dde683..12260384 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -32,67 +32,6 @@ def get_empty_str_error(): )) -def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values): - """ - Limited condition expression syntax parsing. - Supports Global Negation ex: NOT(inner expressions). - Supports simple AND conditions ex: cond_a AND cond_b and cond_c. - Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value. - """ - expected = {} - if condition_expression and 'OR' not in condition_expression: - reverse_re = re.compile('^NOT\s*\((.*)\)$') - reverse_m = reverse_re.match(condition_expression.strip()) - - reverse = False - if reverse_m: - reverse = True - condition_expression = reverse_m.group(1) - - cond_items = [c.strip() for c in condition_expression.split('AND')] - if cond_items: - exists_re = re.compile('^attribute_exists\s*\((.*)\)$') - not_exists_re = re.compile( - '^attribute_not_exists\s*\((.*)\)$') - equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)') - - for cond in cond_items: - exists_m = exists_re.match(cond) - not_exists_m = not_exists_re.match(cond) - equals_m = equals_re.match(cond) - - if exists_m: - attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': True if not reverse else False} - elif not_exists_m: - attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': False if not reverse else True} - elif equals_m: - attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names) - attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values) - expected[attribute_name] = { - 'AttributeValueList': [attribute_value], - 'ComparisonOperator': 'EQ' if not reverse else 'NEQ'} - - return expected - - -def expression_attribute_names_lookup(attribute_name, expression_attribute_names): - if attribute_name.startswith('#') and attribute_name in expression_attribute_names: - return expression_attribute_names[attribute_name] - else: - return attribute_name - - -def expression_attribute_values_lookup(attribute_value, expression_attribute_values): - if isinstance(attribute_value, six.string_types) and \ - attribute_value.startswith(':') and\ - attribute_value in expression_attribute_values: - return expression_attribute_values[attribute_value] - else: - return attribute_value - - class DynamoHandler(BaseResponse): def get_endpoint_name(self, headers): From 53f8997d622694153b6a74766da295179845a7d3 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Tue, 9 Jul 2019 18:21:00 -0500 Subject: [PATCH 21/24] Fix for UpdateExpression with newline. Closes #2275. --- moto/dynamodb2/responses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 5dde432d..e345d886 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -626,7 +626,7 @@ class DynamoHandler(BaseResponse): name = self.body['TableName'] key = self.body['Key'] return_values = self.body.get('ReturnValues', 'NONE') - update_expression = self.body.get('UpdateExpression') + update_expression = self.body.get('UpdateExpression', '').strip() attribute_updates = self.body.get('AttributeUpdates') expression_attribute_names = self.body.get( 'ExpressionAttributeNames', {}) From ab67c1b26e63734f400ad31a8ae0ec3cfb65b149 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Wed, 10 Jul 2019 22:04:31 -0500 Subject: [PATCH 22/24] 1.3.10 --- moto/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/__init__.py b/moto/__init__.py index 9c974f00..6337f046 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -3,7 +3,7 @@ import logging # logging.getLogger('boto').setLevel(logging.CRITICAL) __title__ = 'moto' -__version__ = '1.3.9' +__version__ = '1.3.10' from .acm import mock_acm # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa From 108dc6b049d8a63a6d2c46d47f97beae427eedc0 Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Wed, 10 Jul 2019 22:17:47 -0500 Subject: [PATCH 23/24] Prep for 1.3.11 --- moto/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moto/__init__.py b/moto/__init__.py index 6337f046..a6f35069 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -3,7 +3,7 @@ import logging # logging.getLogger('boto').setLevel(logging.CRITICAL) __title__ = 'moto' -__version__ = '1.3.10' +__version__ = '1.3.11' from .acm import mock_acm # flake8: noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa From 4fd0b5c7109d588d6f82d78a6151b9a927d2020e Mon Sep 17 00:00:00 2001 From: Berislav Kovacki Date: Thu, 11 Jul 2019 22:43:42 +0200 Subject: [PATCH 24/24] Add support for OptionGroupName in create_db_instance --- moto/rds2/exceptions.py | 9 +++++++++ moto/rds2/models.py | 31 +++++++++++++++++++------------ moto/rds2/responses.py | 2 +- tests/test_rds2/test_rds2.py | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 13 deletions(-) diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py index 0e716310..e82ae707 100644 --- a/moto/rds2/exceptions.py +++ b/moto/rds2/exceptions.py @@ -60,6 +60,15 @@ class DBParameterGroupNotFoundError(RDSClientError): 'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) +class OptionGroupNotFoundFaultError(RDSClientError): + + def __init__(self, option_group_name): + super(OptionGroupNotFoundFaultError, self).__init__( + 'OptionGroupNotFoundFault', + 'Specified OptionGroupName: {0} not found.'.format(option_group_name) + ) + + class InvalidDBClusterStateFaultError(RDSClientError): def __init__(self, database_identifier): diff --git a/moto/rds2/models.py b/moto/rds2/models.py index 498f9b12..81b346fd 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -20,6 +20,7 @@ from .exceptions import (RDSClientError, DBSecurityGroupNotFoundError, DBSubnetGroupNotFoundError, DBParameterGroupNotFoundError, + OptionGroupNotFoundFaultError, InvalidDBClusterStateFaultError, InvalidDBInstanceStateError, SnapshotQuotaExceededError, @@ -100,6 +101,8 @@ class Database(BaseModel): 'preferred_backup_window', '13:14-13:44') self.license_model = kwargs.get('license_model', 'general-public-license') self.option_group_name = kwargs.get('option_group_name', None) + if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups: + raise OptionGroupNotFoundFaultError(self.option_group_name) self.default_option_groups = {"MySQL": "default.mysql5.6", "mysql": "default.mysql5.6", "postgres": "default.postgres9.3" @@ -175,6 +178,10 @@ class Database(BaseModel): {{ database.license_model }} {{ database.engine_version }} + + {{ database.option_group_name }} + in-sync + {% for db_parameter_group in database.db_parameter_groups() %} @@ -875,13 +882,16 @@ class RDS2Backend(BaseBackend): def create_option_group(self, option_group_kwargs): option_group_id = option_group_kwargs['name'] - valid_option_group_engines = {'mysql': ['5.6'], - 'oracle-se1': ['11.2'], - 'oracle-se': ['11.2'], - 'oracle-ee': ['11.2'], + valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'], + 'mysql': ['5.5', '5.6', '5.7', '8.0'], + 'oracle-se2': ['11.2', '12.1', '12.2'], + 'oracle-se1': ['11.2', '12.1', '12.2'], + 'oracle-se': ['11.2', '12.1', '12.2'], + 'oracle-ee': ['11.2', '12.1', '12.2'], 'sqlserver-se': ['10.50', '11.00'], - 'sqlserver-ee': ['10.50', '11.00'] - } + 'sqlserver-ee': ['10.50', '11.00'], + 'sqlserver-ex': ['10.50', '11.00'], + 'sqlserver-web': ['10.50', '11.00']} if option_group_kwargs['name'] in self.option_groups: raise RDSClientError('OptionGroupAlreadyExistsFault', 'An option group named {0} already exists.'.format(option_group_kwargs['name'])) @@ -907,8 +917,7 @@ class RDS2Backend(BaseBackend): if option_group_name in self.option_groups: return self.option_groups.pop(option_group_name) else: - raise RDSClientError( - 'OptionGroupNotFoundFault', 'Specified OptionGroupName: {0} not found.'.format(option_group_name)) + raise OptionGroupNotFoundFaultError(option_group_name) def describe_option_groups(self, option_group_kwargs): option_group_list = [] @@ -937,8 +946,7 @@ class RDS2Backend(BaseBackend): else: option_group_list.append(option_group) if not len(option_group_list): - raise RDSClientError('OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_kwargs['name'])) + raise OptionGroupNotFoundFaultError(option_group_kwargs['name']) return option_group_list[marker:max_records + marker] @staticmethod @@ -967,8 +975,7 @@ class RDS2Backend(BaseBackend): def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): if option_group_name not in self.option_groups: - raise RDSClientError('OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_name)) + raise OptionGroupNotFoundFaultError(option_group_name) if not options_to_include and not options_to_remove: raise RDSClientError('InvalidParameterValue', 'At least one option must be added, modified, or removed.') diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index 66d4e0c5..e9262563 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -34,7 +34,7 @@ class RDS2Response(BaseResponse): "master_user_password": self._get_param('MasterUserPassword'), "master_username": self._get_param('MasterUsername'), "multi_az": self._get_bool_param("MultiAZ"), - # OptionGroupName + "option_group_name": self._get_param("OptionGroupName"), "port": self._get_param('Port'), # PreferredBackupWindow # PreferredMaintenanceWindow diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index cf5c9a90..8ea296c2 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -37,6 +37,38 @@ def test_create_database(): db_instance['InstanceCreateTime'].should.be.a("datetime.datetime") +@mock_rds2 +def test_create_database_non_existing_option_group(): + conn = boto3.client('rds', region_name='us-west-2') + database = conn.create_db_instance.when.called_with( + DBInstanceIdentifier='db-master-1', + AllocatedStorage=10, + Engine='postgres', + DBName='staging-postgres', + DBInstanceClass='db.m1.small', + OptionGroupName='non-existing').should.throw(ClientError) + + +@mock_rds2 +def test_create_database_with_option_group(): + conn = boto3.client('rds', region_name='us-west-2') + conn.create_option_group(OptionGroupName='my-og', + EngineName='mysql', + MajorEngineVersion='5.6', + OptionGroupDescription='test option group') + database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', + AllocatedStorage=10, + Engine='postgres', + DBName='staging-postgres', + DBInstanceClass='db.m1.small', + OptionGroupName='my-og') + db_instance = database['DBInstance'] + db_instance['AllocatedStorage'].should.equal(10) + db_instance['DBInstanceClass'].should.equal('db.m1.small') + db_instance['DBName'].should.equal('staging-postgres') + db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og') + + @mock_rds2 def test_stop_database(): conn = boto3.client('rds', region_name='us-west-2') @@ -205,6 +237,7 @@ def test_get_databases_paginated(): resp3 = conn.describe_db_instances(MaxRecords=100) resp3["DBInstances"].should.have.length_of(51) + @mock_rds2 def test_describe_non_existant_database(): conn = boto3.client('rds', region_name='us-west-2')