From 1a2fc66f84b527a5f8b65acf0f97a423d2d9925d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 15:15:20 -0400 Subject: [PATCH 01/10] Adding dynamodb2 expression parser and fixing test cases --- moto/dynamodb2/comparisons.py | 1106 +++++++++++++++++++------ moto/dynamodb2/condition.py | 617 ++++++++++++++ moto/dynamodb2/models.py | 43 +- tests/test_dynamodb2/test_dynamodb.py | 63 +- 4 files changed, 1531 insertions(+), 298 deletions(-) create mode 100644 moto/dynamodb2/condition.py diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 6d37345f..ac78d45b 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,6 +1,40 @@ from __future__ import unicode_literals import re import six +import re +import enum +from collections import deque +from collections import namedtuple + + +def get_filter_expression(expr, names, values): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + parser = ConditionExpressionParser(expr, names, values) + return parser.parse() + + +class Op(object): + """ + Base class for a FilterExpression operator + """ + OP = '' + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def expr(self, item): + raise NotImplementedError("Expr not defined for {0}".format(type(self))) + + def __repr__(self): + return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa @@ -49,292 +83,783 @@ class RecursionStopIteration(StopIteration): pass -def get_filter_expression(expr, names, values): - # Examples - # expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' - # expr = 'Id > 5 AND Subs < 7' - if names is None: - names = {} - if values is None: - values = {} +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values - # Do substitutions - for key, value in names.items(): - expr = expr.replace(key, value) + def parse(self): + """Returns a syntax tree for the expression. - # Store correct types of values for use later - values_map = {} - for key, value in values.items(): - if 'N' in value: - values_map[key] = float(value['N']) - elif 'BOOL' in value: - values_map[key] = value['BOOL'] - elif 'S' in value: - values_map[key] = value['S'] - elif 'NS' in value: - values_map[key] = tuple(value['NS']) - elif 'SS' in value: - values_map[key] = tuple(value['SS']) - elif 'L' in value: - values_map[key] = tuple(value['L']) + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises ValueError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + ('FUNCTION', [('LITERAL', 'size'), argument: Operand]) + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return OpDefault(None, None) + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + # NOTE: The docs say that functions should be parsed after + # IN, BETWEEN, and comparisons like <=. + # However, these expressions are invalid as function arguments, + # so it is okay to parse functions first. This needs to be done + # to interpret size() correctly as an operand. + nodes = self._apply_functions(nodes) + nodes = self._apply_comparator(nodes) + nodes = self._apply_in(nodes) + nodes = self._apply_between(nodes) + nodes = self._apply_parens_and_booleans(nodes) + node = nodes[0] + op = self._make_op_condition(node) + return op + + class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + + class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + + Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == self.Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + # TODO: Handle indexing like [1] + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + self.Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + self.Nonterminal.COMPARATOR, re.compile( + '^(' + # Put long expressions first for greedy matching + '<>|' + '<=|' + '>=|' + '=|' + '<|' + '>)'), + ), ( + self.Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') + ), ( + self.Nonterminal.COMMA, re.compile('^,') + ), ( + self.Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + self.Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break else: - raise NotImplementedError() + raise ValueError("Cannot parse condition starting at: " + + remaining_expression) - # Remove all spaces, tbf we could just skip them in the next step. - # The number of known options is really small so we can do a fair bit of cheating - expr = list(expr.strip()) + value = match_text + node = self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) - # DodgyTokenisation stage 1 - def is_value(val): - return val not in ('<', '>', '=', '(', ')') + remaining_expression = remaining_expression[len(match_text):] - def contains_keyword(val): - for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'): - if kw in val: - return kw - return None + return node, remaining_expression - def is_function(val): - return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size') + def _parse_paths(self, nodes): + output = deque() - # Does the main part of splitting between sections of characters - tokens = [] - stack = '' - while len(expr) > 0: - current_char = expr.pop(0) + while nodes: + node = nodes.popleft() - if current_char == ' ': - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif current_char == ',': # Split params , - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif is_value(current_char): - stack += current_char + if node.nonterminal == self.Nonterminal.OPERAND: + path = node.value.replace('[', '.[').split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != self.Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == self.Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output - kw = contains_keyword(stack) - if kw is not None: - # We have a kw in the stack, could be AND or something like 5AND - tmp = stack.replace(kw, '') - if len(tmp) > 0: - tokens.append(tmp) - tokens.append(kw) - stack = '' + def _parse_path_element(self, name): + reserved = { + 'and': self.Nonterminal.AND, + 'or': self.Nonterminal.OR, + 'in': self.Nonterminal.IN, + 'between': self.Nonterminal.BETWEEN, + 'not': self.Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name.lower() in reserved: + # e.g. AND + nonterminal = reserved[name.lower()] + return self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + # e.g. attribute_exists + return self.Node( + nonterminal=self.Nonterminal.FUNCTION_NAME, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + # e.g. :value0 + return self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + # e.g. #name0 + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + elif name.startswith('['): + # e.g. [123] + if not name.endswith(']'): + raise ValueError("Bad path element %s" % name) + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=int(name[1:-1]), + children=[]) else: - if len(stack) > 0: - tokens.append(stack) - tokens.append(current_char) - stack = '' - if len(stack) > 0: - tokens.append(stack) + # e.g. ItemId + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) - def is_op(val): - return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] - # DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier. - def handle_token(token, tokens2, token_iterator): - # ok so this essentially groups up some tokens to make later parsing easier, - # when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised. - if token == ')': - raise RecursionStopIteration() # Should be recursive so this should work - elif token == '(': - temp_list = [] + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] - try: + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(self.Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] while True: - next_token = six.next(token_iterator) - handle_token(next_token, temp_list, token_iterator) - except RecursionStopIteration: - pass # Continue - except StopIteration: - ValueError('Malformed filter expression, type1') - - # Sigh, we only want to group a tuple if it doesnt contain operators - if any([is_op(item) for item in temp_list]): - # Its an expression - tokens2.append('(') - tokens2.extend(temp_list) - tokens2.append(')') + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) else: - tokens2.append(tuple(temp_list)) - elif token == 'BETWEEN': - field = tokens2.pop() - # if values map contains a number, it would be a float - # so we need to int() it anyway - op1 = six.next(token_iterator) - op1 = int(values_map.get(op1, op1)) - and_op = six.next(token_iterator) - assert and_op == 'AND' - op2 = six.next(token_iterator) - op2 = int(values_map.get(op2, op2)) - tokens2.append(['between', field, op1, op2]) - elif is_function(token): - function_list = [token] + output.append(nodes.popleft()) + return output - lbracket = six.next(token_iterator) - assert lbracket == '(' - - next_token = six.next(token_iterator) - while next_token != ')': - if next_token in values_map: - next_token = values_map[next_token] - function_list.append(next_token) - next_token = six.next(token_iterator) - - tokens2.append(function_list) - else: - # Convert tokens back to real types - if token in values_map: - token = values_map[token] - - # Need to join >= <= <> - if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')): - tokens2.append(tokens2.pop() + token) + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) else: - tokens2.append(token) + output.append(nodes.popleft()) + return output - tokens2 = [] - token_iterator = iter(tokens) - for token in token_iterator: - handle_token(token, tokens2, token_iterator) - - # Start of the Shunting-Yard algorithm. <-- Proper beast algorithm! - def is_number(val): - return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') - - OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100} - - def shunting_yard(token_list): - output = [] - op_stack = [] - - # Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no - # ambiguity on which order operators are applied. - while len(token_list) > 0: - token = token_list.pop(0) - - if token == '(': - op_stack.append(token) - elif token == ')': - while len(op_stack) > 0 and op_stack[-1] != '(': - output.append(op_stack.pop()) - lbracket = op_stack.pop() - assert lbracket == '(' - - elif is_number(token): - output.append(token) + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} + expected_argument_kind_map = { + 'attribute_exists': [{self.Kind.PATH}], + 'attribute_not_exists': [{self.Kind.PATH}], + 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [either_kind, either_kind], + 'contains': [either_kind, either_kind], + 'size': [{self.Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + if function_name.value == 'size': + nonterminal = self.Nonterminal.OPERAND + else: + nonterminal = self.Nonterminal.CONDITION + output.append(self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) else: - # Must be operator kw + output.append(nodes.popleft()) + return output - # Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity - while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT': - output.append(op_stack.pop()) - op_stack.append(token) - while len(op_stack) > 0: - output.append(op_stack.pop()) + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == self.Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == self.Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) return output - output = shunting_yard(tokens2) - - # Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent - def to_func(val): - if isinstance(val, list): - func_name = val.pop(0) - # Expand rest of the list to arguments - val = FUNC_CLASS[func_name](*val) - - return val - - # Simple reverse polish notation execution. Builts up a nested filter object. - # The filter object then takes a dynamo item and returns true/false - stack = [] - for token in output: - if is_op(token): - op_cls = OP_CLASS[token] - - if token == 'NOT': - op1 = stack.pop() - op2 = True + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) else: - op2 = stack.pop() - op1 = stack.pop() + output.append(nodes.popleft()) - stack.append(op_cls(op1, op2)) + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_operand(self, node): + if node.kind == self.Kind.PATH: + return AttributePath([child.value for child in node.children]) + elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: + return AttributeValue(node.value) + elif node.kind == self.Kind.FUNCTION: + # size() + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) else: - stack.append(to_func(token)) - - result = stack.pop(0) - if len(stack) > 0: - raise ValueError('Malformed filter expression, type2') - - return result + raise ValueError("Unknown operand: %r" % node) -class Op(object): - """ - Base class for a FilterExpression operator - """ - OP = '' + def _make_op_condition(self, node): + if node.kind == self.Kind.OR: + lhs, rhs = node.children + return OpOr( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.AND: + lhs, rhs = node.children + return OpAnd( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.NOT: + child, = node.children + return OpNot(self._make_op_condition(child), None) + elif node.kind == self.Kind.PARENTHESES: + child, = node.children + return self._make_op_condition(child) + elif node.kind == self.Kind.FUNCTION: + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) + elif node.kind == self.Kind.BETWEEN: + query, low, high = node.children + return FuncBetween( + self._make_operand(query), + self._make_operand(low), + self._make_operand(high)) + elif node.kind == self.Kind.IN: + query, *possible_values = node.children + query = self._make_operand(query) + possible_values = [self._make_operand(v) for v in possible_values] + return FuncIn(query, *possible_values) + elif node.kind == self.Kind.COMPARISON: + lhs, comparator, rhs = node.children + return OP_CLASS[comparator.value]( + self._make_operand(lhs), + self._make_operand(rhs)) + else: + raise ValueError("Unknown expression node kind %r" % node.kind) - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise ValueError(message + " " + " ".join([t.text for t in nodes])) + + +class Operand(object): + def expr(self, item): + raise NotImplementedError + + def get_type(self, item): + raise NotImplementedError + + +class AttributePath(Operand): + def __init__(self, path): + """Initialize the AttributePath. + + Parameters + ---------- + path: list of int/str - def _lhs(self, item): """ - :type item: moto.dynamodb2.models.Item - """ - lhs = self.lhs - if isinstance(self.lhs, (Op, Func)): - lhs = self.lhs.expr(item) - elif isinstance(self.lhs, six.string_types): - try: - lhs = item.attrs[self.lhs].cast_value - except Exception: - pass + assert len(path) >= 1 + self.path = path - return lhs - - def _rhs(self, item): - rhs = self.rhs - if isinstance(self.rhs, (Op, Func)): - rhs = self.rhs.expr(item) - elif isinstance(self.rhs, six.string_types): - try: - rhs = item.attrs[self.rhs].cast_value - except Exception: - pass - return rhs + def _get_attr(self, item): + base = self.path[0] + if base not in item.attrs: + return None + attr = item.attrs[base] + for name in self.path[1:]: + attr = attr.child_attr(name) + if attr is None: + return None + return attr def expr(self, item): - return True + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.cast_value + + def get_type(self, item): + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.type def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return self.path -class Func(object): - """ - Base class for a FilterExpression function - """ - FUNC = 'Unknown' +class AttributeValue(Operand): + def __init__(self, value): + """Initialize the AttributePath. + + Parameters + ---------- + value: dict + e.g. {'N': '1.234'} + + """ + self.type = list(value.keys())[0] + if 'N' in value: + self.value = float(value['N']) + elif 'BOOL' in value: + self.value = value['BOOL'] + elif 'S' in value: + self.value = value['S'] + elif 'NS' in value: + self.value = tuple(value['NS']) + elif 'SS' in value: + self.value = tuple(value['SS']) + elif 'L' in value: + self.value = tuple(value['L']) + else: + # TODO: Handle all attribute types + raise NotImplementedError() def expr(self, item): - return True + return self.value + + def get_type(self, item): + return self.type def __repr__(self): - return 'Func(...)'.format(self.FUNC) + return repr(self.value) + + +class OpDefault(Op): + OP = 'NONE' + + def expr(self, item): + """If no condition is specified, always True.""" + return True class OpNot(Op): OP = 'NOT' def expr(self, item): - lhs = self._lhs(item) - + lhs = self.lhs.expr(item) return not lhs def __str__(self): @@ -345,8 +870,8 @@ class OpAnd(Op): OP = 'AND' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs and rhs @@ -354,8 +879,8 @@ class OpLessThan(Op): OP = '<' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs < rhs @@ -363,8 +888,8 @@ class OpGreaterThan(Op): OP = '>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs > rhs @@ -372,8 +897,8 @@ class OpEqual(Op): OP = '=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs == rhs @@ -381,8 +906,8 @@ class OpNotEqual(Op): OP = '<>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs != rhs @@ -390,8 +915,8 @@ class OpLessThanOrEqual(Op): OP = '<=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs <= rhs @@ -399,8 +924,8 @@ class OpGreaterThanOrEqual(Op): OP = '>=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs >= rhs @@ -408,8 +933,8 @@ class OpOr(Op): OP = 'OR' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs or rhs @@ -417,19 +942,38 @@ class OpIn(Op): OP = 'IN' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs in rhs +class Func(object): + """ + Base class for a FilterExpression function + """ + FUNC = 'Unknown' + + def __init__(self, *arguments): + self.arguments = arguments + + def expr(self, item): + raise NotImplementedError + + def __repr__(self): + return '{0}({1})'.format( + self.FUNC, + " ".join([repr(arg) for arg in self.arguments])) + + class FuncAttrExists(Func): FUNC = 'attribute_exists' def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr in item.attrs + return self.attr.get_type(item) is not None class FuncAttrNotExists(Func): @@ -437,9 +981,10 @@ class FuncAttrNotExists(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr not in item.attrs + return self.attr.get_type(item) is None class FuncAttrType(Func): @@ -448,9 +993,10 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type + super().__init__(attribute, _type) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == self.type + return self.attr.get_type(item) == self.type.expr(item) class FuncBeginsWith(Func): @@ -459,9 +1005,14 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr + super().__init__(attribute, substr) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr) + if self.attr.get_type(item) != 'S': + return False + if self.substr.get_type(item) != 'S': + return False + return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): @@ -470,13 +1021,11 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand + super().__init__(attribute, operand) def expr(self, item): - if self.attr not in item.attrs: - return False - - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand in item.attrs[self.attr].value + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): + return self.operand.expr(item) in self.attr.expr(item) return False @@ -485,29 +1034,44 @@ class FuncSize(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - if self.attr not in item.attrs: + if self.attr.get_type(item) is None: raise ValueError('Invalid attribute name {0}'.format(self.attr)) - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): - return len(item.attrs[self.attr].value) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + return len(self.attr.expr(item)) raise ValueError('Invalid filter expression') class FuncBetween(Func): - FUNC = 'between' + FUNC = 'BETWEEN' def __init__(self, attribute, start, end): self.attr = attribute self.start = start self.end = end + super().__init__(attribute, start, end) def expr(self, item): - if self.attr not in item.attrs: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) + return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) - return self.start <= item.attrs[self.attr].cast_value <= self.end + +class FuncIn(Func): + FUNC = 'IN' + + def __init__(self, attribute, *possible_values): + self.attr = attribute + self.possible_values = possible_values + super().__init__(attribute, *possible_values) + + def expr(self, item): + for possible_value in self.possible_values: + if self.attr.expr(item) == possible_value.expr(item): + return True + + return False OP_CLASS = { diff --git a/moto/dynamodb2/condition.py b/moto/dynamodb2/condition.py new file mode 100644 index 00000000..b50678e2 --- /dev/null +++ b/moto/dynamodb2/condition.py @@ -0,0 +1,617 @@ +import re +import json +import enum +from collections import deque +from collections import namedtuple + + +class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + +class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + +Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values + + def parse(self): + """Returns a syntax tree for the expression. + + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises AssertionError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return None + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + self._print_debug(nodes) + nodes = self._apply_comparator(nodes) + self._print_debug(nodes) + nodes = self._apply_in(nodes) + self._print_debug(nodes) + nodes = self._apply_between(nodes) + self._print_debug(nodes) + nodes = self._apply_functions(nodes) + self._print_debug(nodes) + nodes = self._apply_parens_and_booleans(nodes) + self._print_debug(nodes) + node = nodes[0] + return self._make_node_tree(node) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + Nonterminal.COMPARATOR, re.compile( + '^(' + '=|' + '<>|' + '<|' + '<=|' + '>|' + '>=)'), + ), ( + Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + ')*') + ), ( + Nonterminal.COMMA, re.compile('^,') + ), ( + Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break + else: + raise AssertionError("Cannot parse condition starting at: " + + remaining_expression) + + value = match_text + node = Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) + + remaining_expression = remaining_expression[len(match_text):] + + return node, remaining_expression + + def _parse_paths(self, nodes): + output = deque() + + while nodes: + node = nodes.popleft() + + if node.nonterminal == Nonterminal.OPERAND: + path = node.value.split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output + + def _parse_path_element(self, name): + reserved = { + 'AND': Nonterminal.AND, + 'OR': Nonterminal.OR, + 'IN': Nonterminal.IN, + 'BETWEEN': Nonterminal.BETWEEN, + 'NOT': Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name in reserved: + nonterminal = reserved[name] + return Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + return Node( + nonterminal=Nonterminal.FUNCTION_NAME, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + return Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + else: + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] + + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] + + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) + else: + output.append(nodes.popleft()) + return output + + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) + else: + output.append(nodes.popleft()) + return output + + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + expected_argument_kind_map = { + 'attribute_exists': [{Kind.PATH}], + 'attribute_not_exists': [{Kind.PATH}], + 'attribute_type': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'contains': [{Kind.PATH}, {Kind.PATH, Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'size': [{Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) + else: + output.append(nodes.popleft()) + return output + + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.NOT, + text=" ".join([not_node['text'], value['text']]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_node_tree(self, node): + if len(node.children) > 0: + return ( + node.kind.name, + [ + self._make_node_tree(child) + for child in node.children + ]) + else: + return (node.kind.name, node.value) + + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise AssertionError(message + " " + " ".join([t.text for t in nodes])) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 6bcde41b..300479e9 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -68,10 +68,34 @@ class DynamoType(object): except ValueError: return float(self.value) elif self.is_set(): - return set(self.value) + sub_type = self.type[0] + return set([DynamoType({sub_type: v}).cast_value + for v in self.value]) + elif self.is_list(): + return [DynamoType(v).cast_value for v in self.value] + elif self.is_map(): + return dict([ + (k, DynamoType(v).cast_value) + for k, v in self.value.items()]) else: return self.value + def child_attr(self, key): + """ + Get Map or List children by key. str for Map, int for List. + + Returns DynamoType or None. + """ + if isinstance(key, str) and self.is_map() and key in self.value: + return DynamoType(self.value[key]) + + if isinstance(key, int) and self.is_list(): + idx = key + if idx >= 0 and idx < len(self.value): + return DynamoType(self.value[idx]) + + return None + def to_json(self): return {self.type: self.value} @@ -89,6 +113,12 @@ class DynamoType(object): def is_set(self): return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + def is_list(self): + return self.type == 'L' + + def is_map(self): + return self.type == 'M' + def same_type(self, other): return self.type == other.type @@ -954,10 +984,7 @@ class DynamoDBBackend(BaseBackend): range_values = [DynamoType(range_value) for range_value in range_value_dicts] - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.query(hash_key, range_comparison, range_values, limit, exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) @@ -972,10 +999,8 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 77846de0..932139ee 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -676,44 +676,47 @@ def test_filter_expression(): filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { + ':v0': {'N': '7'}, + ':v1': {'N': '8'}, + ':v2': {'N': '9'}}) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -754,14 +757,26 @@ def test_query_filter(): TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) client.put_item( TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} + 'app': {'S': 'app2'}, + 'nested': {'M': { + 'version': {'S': 'version2'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) @@ -783,6 +798,18 @@ def test_query_filter(): ) assert response['Count'] == 2 + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.version').contains('version') + ) + assert response['Count'] == 2 + + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.contents[0]').eq('value1') + ) + assert response['Count'] == 2 + @mock_dynamodb2 def test_scan_filter(): @@ -1061,7 +1088,7 @@ def test_delete_item(): with assert_raises(ClientError) as ex: table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_NEW') - + # Test deletion and returning old value response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') response['Attributes'].should.contain('client') @@ -1364,7 +1391,7 @@ def test_put_return_attributes(): ReturnValues='NONE' ) assert 'Attributes' not in r - + r = dynamodb.put_item( TableName='moto-test', Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, @@ -1381,7 +1408,7 @@ def test_put_return_attributes(): ex.exception.response['Error']['Code'].should.equal('ValidationException') ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') - + @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): @@ -1489,7 +1516,7 @@ def test_dynamodb_streams_1(): 'StreamViewType': 'NEW_AND_OLD_IMAGES' } ) - + assert 'StreamSpecification' in resp['TableDescription'] assert resp['TableDescription']['StreamSpecification'] == { 'StreamEnabled': True, @@ -1497,11 +1524,11 @@ def test_dynamodb_streams_1(): } assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + resp = conn.delete_table(TableName='test-streams') assert 'StreamSpecification' in resp['TableDescription'] - + @mock_dynamodb2 def test_dynamodb_streams_2(): @@ -1532,7 +1559,7 @@ def test_dynamodb_streams_2(): assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1') From 271265451822d7ca7c1a9d68386f84b5b7323aeb Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 16:23:49 -0400 Subject: [PATCH 02/10] Using Ops for dynamodb expected dicts --- moto/dynamodb2/comparisons.py | 122 ++++++++++++++++++++++++++-------- moto/dynamodb2/models.py | 52 ++------------- 2 files changed, 101 insertions(+), 73 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index ac78d45b..06d99260 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -19,6 +19,63 @@ def get_filter_expression(expr, names, values): return parser.parse() +def get_expected(expected): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + ops = { + 'EQ': OpEqual, + 'NE': OpNotEqual, + 'LE': OpLessThanOrEqual, + 'LT': OpLessThan, + 'GE': OpGreaterThanOrEqual, + 'GT': OpGreaterThan, + 'NOT_NULL': FuncAttrExists, + 'NULL': FuncAttrNotExists, + 'CONTAINS': FuncContains, + 'NOT_CONTAINS': FuncNotContains, + 'BEGINS_WITH': FuncBeginsWith, + 'IN': FuncIn, + 'BETWEEN': FuncBetween, + } + + # NOTE: Always uses ConditionalOperator=AND + conditions = [] + for key, cond in expected.items(): + path = AttributePath([key]) + if 'Exists' in cond: + if cond['Exists']: + conditions.append(FuncAttrExists(path)) + else: + conditions.append(FuncAttrNotExists(path)) + elif 'Value' in cond: + conditions.append(OpEqual(path, AttributeValue(cond['Value']))) + elif 'ComparisonOperator' in cond: + operator_name = cond['ComparisonOperator'] + values = [ + AttributeValue(v) + for v in cond.get("AttributeValueList", [])] + print(path, values) + OpClass = ops[operator_name] + conditions.append(OpClass(path, *values)) + + # NOTE: Ignore ConditionalOperator + ConditionalOp = OpAnd + if conditions: + output = conditions[0] + for condition in conditions[1:]: + output = ConditionalOp(output, condition) + else: + return OpDefault(None, None) + + print("EXPECTED:", expected, output) + return output + + class Op(object): """ Base class for a FilterExpression operator @@ -782,14 +839,19 @@ class AttributePath(Operand): self.path = path def _get_attr(self, item): + if item is None: + return None + base = self.path[0] if base not in item.attrs: return None attr = item.attrs[base] + for name in self.path[1:]: attr = attr.child_attr(name) if attr is None: return None + return attr def expr(self, item): @@ -807,7 +869,7 @@ class AttributePath(Operand): return attr.type def __repr__(self): - return self.path + return ".".join(self.path) class AttributeValue(Operand): @@ -821,23 +883,27 @@ class AttributeValue(Operand): """ self.type = list(value.keys())[0] - if 'N' in value: - self.value = float(value['N']) - elif 'BOOL' in value: - self.value = value['BOOL'] - elif 'S' in value: - self.value = value['S'] - elif 'NS' in value: - self.value = tuple(value['NS']) - elif 'SS' in value: - self.value = tuple(value['SS']) - elif 'L' in value: - self.value = tuple(value['L']) - else: - # TODO: Handle all attribute types - raise NotImplementedError() + self.value = value[self.type] def expr(self, item): + # TODO: Reuse DynamoType code + if self.type == 'N': + try: + return int(self.value) + except ValueError: + return float(self.value) + elif self.type in ['SS', 'NS', 'BS']: + sub_type = self.type[0] + return set([AttributeValue({sub_type: v}).expr(item) + for v in self.value]) + elif self.type == 'L': + return [AttributeValue(v).expr(item) for v in self.value] + elif self.type == 'M': + return dict([ + (k, AttributeValue(v).expr(item)) + for k, v in self.value.items()]) + else: + return self.value return self.value def get_type(self, item): @@ -976,15 +1042,8 @@ class FuncAttrExists(Func): return self.attr.get_type(item) is not None -class FuncAttrNotExists(Func): - FUNC = 'attribute_not_exists' - - def __init__(self, attribute): - self.attr = attribute - super().__init__(attribute) - - def expr(self, item): - return self.attr.get_type(item) is None +def FuncAttrNotExists(attribute): + return OpNot(FuncAttrExists(attribute), None) class FuncAttrType(Func): @@ -1024,13 +1083,20 @@ class FuncContains(Func): super().__init__(attribute, operand) def expr(self, item): - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand.expr(item) in self.attr.expr(item) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + try: + return self.operand.expr(item) in self.attr.expr(item) + except TypeError: + return False return False +def FuncNotContains(attribute, operand): + return OpNot(FuncContains(attribute, operand), None) + + class FuncSize(Func): - FUNC = 'contains' + FUNC = 'size' def __init__(self, attribute): self.attr = attribute diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 300479e9..bdf59df1 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -13,6 +13,9 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError from .comparisons import get_comparison_func, get_filter_expression, Op +from .comparisons import get_comparison_func +from .comparisons import get_filter_expression +from .comparisons import get_expected from .exceptions import InvalidIndexNameError @@ -557,29 +560,9 @@ class Table(BaseModel): self.range_key_type, item_attrs) if not overwrite: - if current is None: - current_attr = {} - elif hasattr(current, 'attrs'): - current_attr = current.attrs - else: - current_attr = current + if not get_expected(expected).expr(current): + raise ValueError('The conditional request failed') - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in current_attr: - raise ValueError("The conditional request failed") - elif key not in current_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item else: @@ -1024,32 +1007,11 @@ class DynamoDBBackend(BaseBackend): item = table.get_item(hash_value, range_value) - if item is None: - item_attr = {} - elif hasattr(item, 'attrs'): - item_attr = item.attrs - else: - item_attr = item - if not expected: expected = {} - for key, val in expected.items(): - if 'Exists' in val and val['Exists'] is False \ - or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL': - if key in item_attr: - raise ValueError("The conditional request failed") - elif key not in item_attr: - raise ValueError("The conditional request failed") - elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value: - raise ValueError("The conditional request failed") - elif 'ComparisonOperator' in val: - dynamo_types = [ - DynamoType(ele) for ele in - val.get("AttributeValueList", []) - ] - if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types): - raise ValueError('The conditional request failed') + if not get_expected(expected).expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: From 57b668c8323761b173276607f1263863207ce053 Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 16:48:00 -0400 Subject: [PATCH 03/10] Using Ops for dynamodb condition expressions --- moto/dynamodb2/comparisons.py | 16 ++++++-------- moto/dynamodb2/models.py | 26 ++++++++++++++++++---- moto/dynamodb2/responses.py | 32 ++++++++++++--------------- tests/test_dynamodb2/test_dynamodb.py | 15 +++++++++++++ 4 files changed, 58 insertions(+), 31 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 06d99260..4095acba 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -59,7 +59,6 @@ def get_expected(expected): values = [ AttributeValue(v) for v in cond.get("AttributeValueList", [])] - print(path, values) OpClass = ops[operator_name] conditions.append(OpClass(path, *values)) @@ -72,7 +71,6 @@ def get_expected(expected): else: return OpDefault(None, None) - print("EXPECTED:", expected, output) return output @@ -486,7 +484,7 @@ class ConditionExpressionParser: lhs = nodes.popleft() comparator = nodes.popleft() rhs = nodes.popleft() - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.COMPARISON, text=" ".join([ @@ -528,7 +526,7 @@ class ConditionExpressionParser: self._assert( False, "Bad IN expression starting at", nodes) - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.IN, text=" ".join([t.text for t in all_children]), @@ -553,7 +551,7 @@ class ConditionExpressionParser: and_node = nodes.popleft() high = nodes.popleft() all_children = [lhs, between_node, low, and_node, high] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.BETWEEN, text=" ".join([t.text for t in all_children]), @@ -613,7 +611,7 @@ class ConditionExpressionParser: nonterminal = self.Nonterminal.OPERAND else: nonterminal = self.Nonterminal.CONDITION - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=nonterminal, kind=self.Kind.FUNCTION, text=" ".join([t.text for t in all_children]), @@ -685,7 +683,7 @@ class ConditionExpressionParser: "Bad NOT expression", list(nodes)[:2]) not_node = nodes.popleft() child = nodes.popleft() - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.NOT, text=" ".join([not_node.text, child.text]), @@ -708,7 +706,7 @@ class ConditionExpressionParser: and_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, and_node, rhs] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.AND, text=" ".join([t.text for t in all_children]), @@ -731,7 +729,7 @@ class ConditionExpressionParser: or_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, or_node, rhs] - output.append(self.Node( + nodes.appendleft(self.Node( nonterminal=self.Nonterminal.CONDITION, kind=self.Kind.OR, text=" ".join([t.text for t in all_children]), diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index bdf59df1..037db3d7 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -537,7 +537,9 @@ class Table(BaseModel): keys.append(range_key) return keys - def put_item(self, item_attrs, expected=None, overwrite=False): + def put_item(self, item_attrs, expected=None, condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -562,6 +564,12 @@ class Table(BaseModel): if not overwrite: if not get_expected(expected).expr(current): raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(current): + raise ValueError('The conditional request failed') if range_value: self.items[hash_value][range_value] = item @@ -907,11 +915,15 @@ class DynamoDBBackend(BaseBackend): table.global_indexes = list(gsis_by_name.values()) return table - def put_item(self, table_name, item_attrs, expected=None, overwrite=False): + def put_item(self, table_name, item_attrs, expected=None, + condition_expression=None, expression_attribute_names=None, + expression_attribute_values=None, overwrite=False): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs, expected, overwrite) + return table.put_item(item_attrs, expected, condition_expression, + expression_attribute_names, + expression_attribute_values, overwrite) def get_table_keys_name(self, table_name, keys): """ @@ -988,7 +1000,7 @@ class DynamoDBBackend(BaseBackend): return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected=None): + expression_attribute_values, expected=None, condition_expression=None): table = self.get_table(table_name) if all([table.hash_key_attr in key, table.range_key_attr in key]): @@ -1012,6 +1024,12 @@ class DynamoDBBackend(BaseBackend): if not get_expected(expected).expr(item): raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(current): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 7eb56574..13dde683 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -288,18 +288,18 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) - if expected: - overwrite = False + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + + if condition_expression: + overwrite = False try: - result = self.dynamodb_backend.put_item(name, item, expected, overwrite) + result = self.dynamodb_backend.put_item( + name, item, expected, condition_expression, + expression_attribute_names, expression_attribute_values, + overwrite) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' return self.error(er, 'A condition specified in the operation could not be evaluated.') @@ -652,13 +652,9 @@ class DynamoHandler(BaseResponse): # Attempt to parse simple ConditionExpressions into an Expected # expression - if not expected: - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expected = condition_expression_to_expected(condition_expression, - expression_attribute_names, - expression_attribute_values) + condition_expression = self.body.get('ConditionExpression') + expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) # Support spaces between operators in an update expression # E.g. `a = b + c` -> `a=b+c` @@ -669,7 +665,7 @@ class DynamoHandler(BaseResponse): try: item = self.dynamodb_backend.update_item( name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected + expression_attribute_values, expected, condition_expression ) except ValueError: er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 932139ee..f87e84fb 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1616,6 +1616,21 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)', + ExpressionAttributeNames={ + '#nonexistent': 'nope', + '#existing': 'existing' + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 6fd47f843fb046305d6379b4f791dbe76569f87a Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:00:02 -0400 Subject: [PATCH 04/10] Test case for #1819 --- tests/test_dynamodb2/test_dynamodb.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index f87e84fb..d2178205 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1631,6 +1631,24 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client BETWEEN :a AND :z', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':a': {'S': 'a'}, + ':z': {'S': 'z'}, + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 8a90971ba152a692ac9f17ef346739630136e6ad Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:02:14 -0400 Subject: [PATCH 05/10] Adding test cases for #1587 --- tests/test_dynamodb2/test_dynamodb.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index d2178205..0ea1d64e 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1649,6 +1649,24 @@ def test_condition_expressions(): } ) + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'match': {'S': 'match'}, + 'existing': {'S': 'existing'}, + }, + ConditionExpression='#client IN (:client1, :client2)', + ExpressionAttributeNames={ + '#client': 'client', + }, + ExpressionAttributeValues={ + ':client1': {'S': 'client1'}, + ':client2': {'S': 'client2'}, + } + ) + with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( TableName='test1', From 94503285274e670acba6ac441539aa7153d6915d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Mon, 1 Apr 2019 17:03:58 -0400 Subject: [PATCH 06/10] Deleting unnecessary dynamodb2 file --- moto/dynamodb2/condition.py | 617 ------------------------------------ 1 file changed, 617 deletions(-) delete mode 100644 moto/dynamodb2/condition.py diff --git a/moto/dynamodb2/condition.py b/moto/dynamodb2/condition.py deleted file mode 100644 index b50678e2..00000000 --- a/moto/dynamodb2/condition.py +++ /dev/null @@ -1,617 +0,0 @@ -import re -import json -import enum -from collections import deque -from collections import namedtuple - - -class Kind(enum.Enum): - """Defines types of nodes in the syntax tree.""" - - # Condition nodes - # --------------- - OR = enum.auto() - AND = enum.auto() - NOT = enum.auto() - PARENTHESES = enum.auto() - FUNCTION = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMPARISON = enum.auto() - - # Operand nodes - # ------------- - EXPRESSION_ATTRIBUTE_VALUE = enum.auto() - PATH = enum.auto() - - # Literal nodes - # -------------- - LITERAL = enum.auto() - - -class Nonterminal(enum.Enum): - """Defines nonterminals for defining productions.""" - CONDITION = enum.auto() - OPERAND = enum.auto() - COMPARATOR = enum.auto() - FUNCTION_NAME = enum.auto() - IDENTIFIER = enum.auto() - AND = enum.auto() - OR = enum.auto() - NOT = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMMA = enum.auto() - LEFT_PAREN = enum.auto() - RIGHT_PAREN = enum.auto() - WHITESPACE = enum.auto() - - -Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) - - -class ConditionExpressionParser: - def __init__(self, condition_expression, expression_attribute_names, - expression_attribute_values): - self.condition_expression = condition_expression - self.expression_attribute_names = expression_attribute_names - self.expression_attribute_values = expression_attribute_values - - def parse(self): - """Returns a syntax tree for the expression. - - The tree, and all of the nodes in the tree are a tuple of - - kind: str - - children/value: - list of nodes for parent nodes - value for leaf nodes - - Raises AssertionError if the condition expression is invalid - Raises KeyError if expression attribute names/values are invalid - - Here are the types of nodes that can be returned. - The types of child nodes are denoted with a colon (:). - An arbitrary number of children is denoted with ... - - Condition: - ('OR', [lhs : Condition, rhs : Condition]) - ('AND', [lhs: Condition, rhs: Condition]) - ('NOT', [argument: Condition]) - ('PARENTHESES', [argument: Condition]) - ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) - ('BETWEEN', [query: Operand, low: Operand, high: Operand]) - ('IN', [query: Operand, possible_value: Operand, ...]) - ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) - - Operand: - ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) - ('PATH', [('LITERAL', path_element: str), ...]) - NOTE: Expression attribute names will be expanded - - Literal: - ('LITERAL', value: str) - - """ - if not self.condition_expression: - return None - nodes = self._lex_condition_expression() - nodes = self._parse_paths(nodes) - self._print_debug(nodes) - nodes = self._apply_comparator(nodes) - self._print_debug(nodes) - nodes = self._apply_in(nodes) - self._print_debug(nodes) - nodes = self._apply_between(nodes) - self._print_debug(nodes) - nodes = self._apply_functions(nodes) - self._print_debug(nodes) - nodes = self._apply_parens_and_booleans(nodes) - self._print_debug(nodes) - node = nodes[0] - return self._make_node_tree(node) - - def _lex_condition_expression(self): - nodes = deque() - remaining_expression = self.condition_expression - while remaining_expression: - node, remaining_expression = \ - self._lex_one_node(remaining_expression) - if node.nonterminal == Nonterminal.WHITESPACE: - continue - nodes.append(node) - return nodes - - def _lex_one_node(self, remaining_expression): - - attribute_regex = '(:|#)?[A-z0-9\-_]+' - patterns = [( - Nonterminal.WHITESPACE, re.compile('^ +') - ), ( - Nonterminal.COMPARATOR, re.compile( - '^(' - '=|' - '<>|' - '<|' - '<=|' - '>|' - '>=)'), - ), ( - Nonterminal.OPERAND, re.compile( - '^' + - attribute_regex + '(\.' + attribute_regex + ')*') - ), ( - Nonterminal.COMMA, re.compile('^,') - ), ( - Nonterminal.LEFT_PAREN, re.compile('^\(') - ), ( - Nonterminal.RIGHT_PAREN, re.compile('^\)') - )] - - for nonterminal, pattern in patterns: - match = pattern.match(remaining_expression) - if match: - match_text = match.group() - break - else: - raise AssertionError("Cannot parse condition starting at: " + - remaining_expression) - - value = match_text - node = Node( - nonterminal=nonterminal, - kind=Kind.LITERAL, - text=match_text, - value=match_text, - children=[]) - - remaining_expression = remaining_expression[len(match_text):] - - return node, remaining_expression - - def _parse_paths(self, nodes): - output = deque() - - while nodes: - node = nodes.popleft() - - if node.nonterminal == Nonterminal.OPERAND: - path = node.value.split('.') - children = [ - self._parse_path_element(name) - for name in path] - if len(children) == 1: - child = children[0] - if child.nonterminal != Nonterminal.IDENTIFIER: - output.append(child) - continue - else: - for child in children: - self._assert( - child.nonterminal == Nonterminal.IDENTIFIER, - "Cannot use %s in path" % child.text, [node]) - output.append(Node( - nonterminal=Nonterminal.OPERAND, - kind=Kind.PATH, - text=node.text, - value=None, - children=children)) - else: - output.append(node) - return output - - def _parse_path_element(self, name): - reserved = { - 'AND': Nonterminal.AND, - 'OR': Nonterminal.OR, - 'IN': Nonterminal.IN, - 'BETWEEN': Nonterminal.BETWEEN, - 'NOT': Nonterminal.NOT, - } - - functions = { - 'attribute_exists', - 'attribute_not_exists', - 'attribute_type', - 'begins_with', - 'contains', - 'size', - } - - - if name in reserved: - nonterminal = reserved[name] - return Node( - nonterminal=nonterminal, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - elif name in functions: - return Node( - nonterminal=Nonterminal.FUNCTION_NAME, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - elif name.startswith(':'): - return Node( - nonterminal=Nonterminal.OPERAND, - kind=Kind.EXPRESSION_ATTRIBUTE_VALUE, - text=name, - value=self._lookup_expression_attribute_value(name), - children=[]) - elif name.startswith('#'): - return Node( - nonterminal=Nonterminal.IDENTIFIER, - kind=Kind.LITERAL, - text=name, - value=self._lookup_expression_attribute_name(name), - children=[]) - else: - return Node( - nonterminal=Nonterminal.IDENTIFIER, - kind=Kind.LITERAL, - text=name, - value=name, - children=[]) - - def _lookup_expression_attribute_value(self, name): - return self.expression_attribute_values[name] - - def _lookup_expression_attribute_name(self, name): - return self.expression_attribute_names[name] - - # NOTE: The following constructions are ordered from high precedence to low precedence - # according to - # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence - # - # = <> < <= > >= - # IN - # BETWEEN - # attribute_exists attribute_not_exists begins_with contains - # Parentheses - # NOT - # AND - # OR - # - # The grammar is taken from - # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax - # - # condition-expression ::= - # operand comparator operand - # operand BETWEEN operand AND operand - # operand IN ( operand (',' operand (, ...) )) - # function - # condition AND condition - # condition OR condition - # NOT condition - # ( condition ) - # - # comparator ::= - # = - # <> - # < - # <= - # > - # >= - # - # function ::= - # attribute_exists (path) - # attribute_not_exists (path) - # attribute_type (path, type) - # begins_with (path, substr) - # contains (path, operand) - # size (path) - - def _matches(self, nodes, production): - """Check if the nodes start with the given production. - - Parameters - ---------- - nodes: list of Node - production: list of str - The name of a Nonterminal, or '*' for anything - - """ - if len(nodes) < len(production): - return False - for i in range(len(production)): - if production[i] == '*': - continue - expected = getattr(Nonterminal, production[i]) - if nodes[i].nonterminal != expected: - return False - return True - - def _apply_comparator(self, nodes): - """Apply condition := operand comparator operand.""" - output = deque() - - while nodes: - if self._matches(nodes, ['*', 'COMPARATOR']): - self._assert( - self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), - "Bad comparison", list(nodes)[:3]) - lhs = nodes.popleft() - comparator = nodes.popleft() - rhs = nodes.popleft() - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.COMPARISON, - text=" ".join([ - lhs.text, - comparator.text, - rhs.text]), - value=None, - children=[lhs, comparator, rhs])) - else: - output.append(nodes.popleft()) - return output - - def _apply_in(self, nodes): - """Apply condition := operand IN ( operand , ... ).""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'IN']): - self._assert( - self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), - "Bad IN expression", list(nodes)[:3]) - lhs = nodes.popleft() - in_node = nodes.popleft() - left_paren = nodes.popleft() - all_children = [lhs, in_node, left_paren] - rhs = [] - while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - rhs.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - rhs.append(operand) - break # Close - else: - self._assert( - False, - "Bad IN expression starting at", nodes) - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.IN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs] + rhs)) - else: - output.append(nodes.popleft()) - return output - - def _apply_between(self, nodes): - """Apply condition := operand BETWEEN operand AND operand.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'BETWEEN']): - self._assert( - self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', - 'AND', 'OPERAND']), - "Bad BETWEEN expression", list(nodes)[:5]) - lhs = nodes.popleft() - between_node = nodes.popleft() - low = nodes.popleft() - and_node = nodes.popleft() - high = nodes.popleft() - all_children = [lhs, between_node, low, and_node, high] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.BETWEEN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, low, high])) - else: - output.append(nodes.popleft()) - return output - - def _apply_functions(self, nodes): - """Apply condition := function_name (operand , ...).""" - output = deque() - expected_argument_kind_map = { - 'attribute_exists': [{Kind.PATH}], - 'attribute_not_exists': [{Kind.PATH}], - 'attribute_type': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'begins_with': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'contains': [{Kind.PATH}, {Kind.PATH, Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'size': [{Kind.PATH}], - } - while nodes: - if self._matches(nodes, ['FUNCTION_NAME']): - self._assert( - self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', - 'OPERAND', '*']), - "Bad function expression at", list(nodes)[:4]) - function_name = nodes.popleft() - left_paren = nodes.popleft() - all_children = [function_name, left_paren] - arguments = [] - while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - arguments.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): - operand = nodes.popleft() - separator = nodes.popleft() - all_children += [operand, separator] - arguments.append(operand) - break # Close paren - else: - self._assert( - False, - "Bad function expression", all_children + list(nodes)[:2]) - expected_kinds = expected_argument_kind_map[function_name.value] - self._assert( - len(arguments) == len(expected_kinds), - "Wrong number of arguments in", all_children) - for i in range(len(expected_kinds)): - self._assert( - arguments[i].kind in expected_kinds[i], - "Wrong type for argument %d in" % i, all_children) - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.FUNCTION, - text=" ".join([t.text for t in all_children]), - value=None, - children=[function_name] + arguments)) - else: - output.append(nodes.popleft()) - return output - - def _apply_parens_and_booleans(self, nodes, left_paren=None): - """Apply condition := ( condition ) and booleans.""" - output = deque() - while nodes: - if self._matches(nodes, ['LEFT_PAREN']): - parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) - self._assert( - len(parsed) >= 1, - "Failed to close parentheses at", nodes) - parens = parsed.popleft() - self._assert( - parens.kind == Kind.PARENTHESES, - "Failed to close parentheses at", nodes) - output.append(parens) - nodes = parsed - elif self._matches(nodes, ['RIGHT_PAREN']): - self._assert( - left_paren is not None, - "Unmatched ) at", nodes) - close_paren = nodes.popleft() - children = self._apply_booleans(output) - all_children = [left_paren, *children, close_paren] - return deque([ - Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.PARENTHESES, - text=" ".join([t.text for t in all_children]), - value=None, - children=list(children), - ), *nodes]) - else: - output.append(nodes.popleft()) - - self._assert( - left_paren is None, - "Unmatched ( at", list(output)) - return self._apply_booleans(output) - - def _apply_booleans(self, nodes): - """Apply and, or, and not constructions.""" - nodes = self._apply_not(nodes) - nodes = self._apply_and(nodes) - nodes = self._apply_or(nodes) - # The expression should reduce to a single condition - self._assert( - len(nodes) == 1, - "Unexpected expression at", list(nodes)[1:]) - self._assert( - nodes[0].nonterminal == Nonterminal.CONDITION, - "Incomplete condition", nodes) - return nodes - - def _apply_not(self, nodes): - """Apply condition := NOT condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['NOT']): - self._assert( - self._matches(nodes, ['NOT', 'CONDITION']), - "Bad NOT expression", list(nodes)[:2]) - not_node = nodes.popleft() - child = nodes.popleft() - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.NOT, - text=" ".join([not_node['text'], value['text']]), - value=None, - children=[child])) - else: - output.append(nodes.popleft()) - - return output - - def _apply_and(self, nodes): - """Apply condition := condition AND condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'AND']): - self._assert( - self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), - "Bad AND expression", list(nodes)[:3]) - lhs = nodes.popleft() - and_node = nodes.popleft() - rhs = nodes.popleft() - all_children = [lhs, and_node, rhs] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.AND, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) - else: - output.append(nodes.popleft()) - - return output - - def _apply_or(self, nodes): - """Apply condition := condition OR condition.""" - output = deque() - while nodes: - if self._matches(nodes, ['*', 'OR']): - self._assert( - self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), - "Bad OR expression", list(nodes)[:3]) - lhs = nodes.popleft() - or_node = nodes.popleft() - rhs = nodes.popleft() - all_children = [lhs, or_node, rhs] - output.append(Node( - nonterminal=Nonterminal.CONDITION, - kind=Kind.OR, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) - else: - output.append(nodes.popleft()) - - return output - - def _make_node_tree(self, node): - if len(node.children) > 0: - return ( - node.kind.name, - [ - self._make_node_tree(child) - for child in node.children - ]) - else: - return (node.kind.name, node.value) - - def _print_debug(self, nodes): - print('ROOT') - for node in nodes: - self._print_node_recursive(node, depth=1) - - def _print_node_recursive(self, node, depth=0): - if len(node.children) > 0: - print(' ' * depth, node.nonterminal, node.kind) - for child in node.children: - self._print_node_recursive(child, depth=depth + 1) - else: - print(' ' * depth, node.nonterminal, node.kind, node.value) - - - - def _assert(self, condition, message, nodes): - if not condition: - raise AssertionError(message + " " + " ".join([t.text for t in nodes])) From 6303d07bac24021ecd0008e78e6a39ab6745d074 Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Fri, 12 Apr 2019 10:13:36 -0400 Subject: [PATCH 07/10] Fixing tests --- moto/dynamodb2/comparisons.py | 125 ++++++++++++++++------------------ moto/dynamodb2/models.py | 16 ++--- 2 files changed, 67 insertions(+), 74 deletions(-) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 4095acba..1a4633e6 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals import re import six import re -import enum from collections import deque from collections import namedtuple @@ -199,46 +198,47 @@ class ConditionExpressionParser: op = self._make_op_condition(node) return op - class Kind(enum.Enum): - """Defines types of nodes in the syntax tree.""" + class Kind: + """Enum defining types of nodes in the syntax tree.""" # Condition nodes # --------------- - OR = enum.auto() - AND = enum.auto() - NOT = enum.auto() - PARENTHESES = enum.auto() - FUNCTION = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMPARISON = enum.auto() + OR = 'OR' + AND = 'AND' + NOT = 'NOT' + PARENTHESES = 'PARENTHESES' + FUNCTION = 'FUNCTION' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMPARISON = 'COMPARISON' # Operand nodes # ------------- - EXPRESSION_ATTRIBUTE_VALUE = enum.auto() - PATH = enum.auto() + EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE' + PATH = 'PATH' # Literal nodes # -------------- - LITERAL = enum.auto() + LITERAL = 'LITERAL' - class Nonterminal(enum.Enum): - """Defines nonterminals for defining productions.""" - CONDITION = enum.auto() - OPERAND = enum.auto() - COMPARATOR = enum.auto() - FUNCTION_NAME = enum.auto() - IDENTIFIER = enum.auto() - AND = enum.auto() - OR = enum.auto() - NOT = enum.auto() - BETWEEN = enum.auto() - IN = enum.auto() - COMMA = enum.auto() - LEFT_PAREN = enum.auto() - RIGHT_PAREN = enum.auto() - WHITESPACE = enum.auto() + class Nonterminal: + """Enum defining nonterminals for productions.""" + + CONDITION = 'CONDITION' + OPERAND = 'OPERAND' + COMPARATOR = 'COMPARATOR' + FUNCTION_NAME = 'FUNCTION_NAME' + IDENTIFIER = 'IDENTIFIER' + AND = 'AND' + OR = 'OR' + NOT = 'NOT' + BETWEEN = 'BETWEEN' + IN = 'IN' + COMMA = 'COMMA' + LEFT_PAREN = 'LEFT_PAREN' + RIGHT_PAREN = 'RIGHT_PAREN' + WHITESPACE = 'WHITESPACE' Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) @@ -286,7 +286,7 @@ class ConditionExpressionParser: if match: match_text = match.group() break - else: + else: # pragma: no cover raise ValueError("Cannot parse condition starting at: " + remaining_expression) @@ -387,7 +387,7 @@ class ConditionExpressionParser: children=[]) elif name.startswith('['): # e.g. [123] - if not name.endswith(']'): + if not name.endswith(']'): # pragma: no cover raise ValueError("Bad path element %s" % name) return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, @@ -642,7 +642,7 @@ class ConditionExpressionParser: "Unmatched ) at", nodes) close_paren = nodes.popleft() children = self._apply_booleans(output) - all_children = [left_paren, *children, close_paren] + all_children = [left_paren] + list(children) + [close_paren] return deque([ self.Node( nonterminal=self.Nonterminal.CONDITION, @@ -650,7 +650,7 @@ class ConditionExpressionParser: text=" ".join([t.text for t in all_children]), value=None, children=list(children), - ), *nodes]) + )] + list(nodes)) else: output.append(nodes.popleft()) @@ -747,11 +747,12 @@ class ConditionExpressionParser: return AttributeValue(node.value) elif node.kind == self.Kind.FUNCTION: # size() - function_node, *arguments = node.children + function_node = node.children[0] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) - else: + else: # pragma: no cover raise ValueError("Unknown operand: %r" % node) @@ -768,12 +769,13 @@ class ConditionExpressionParser: self._make_op_condition(rhs)) elif node.kind == self.Kind.NOT: child, = node.children - return OpNot(self._make_op_condition(child), None) + return OpNot(self._make_op_condition(child)) elif node.kind == self.Kind.PARENTHESES: child, = node.children return self._make_op_condition(child) elif node.kind == self.Kind.FUNCTION: - function_node, *arguments = node.children + function_node = node.children[0] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) @@ -784,24 +786,25 @@ class ConditionExpressionParser: self._make_operand(low), self._make_operand(high)) elif node.kind == self.Kind.IN: - query, *possible_values = node.children + query = node.children[0] + possible_values = node.children[1:] query = self._make_operand(query) possible_values = [self._make_operand(v) for v in possible_values] return FuncIn(query, *possible_values) elif node.kind == self.Kind.COMPARISON: lhs, comparator, rhs = node.children - return OP_CLASS[comparator.value]( + return COMPARATOR_CLASS[comparator.value]( self._make_operand(lhs), self._make_operand(rhs)) - else: + else: # pragma: no cover raise ValueError("Unknown expression node kind %r" % node.kind) - def _print_debug(self, nodes): + def _print_debug(self, nodes): # pragma: no cover print('ROOT') for node in nodes: self._print_node_recursive(node, depth=1) - def _print_node_recursive(self, node, depth=0): + def _print_node_recursive(self, node, depth=0): # pragma: no cover if len(node.children) > 0: print(' ' * depth, node.nonterminal, node.kind) for child in node.children: @@ -922,6 +925,9 @@ class OpDefault(Op): class OpNot(Op): OP = 'NOT' + def __init__(self, lhs): + super(OpNot, self).__init__(lhs, None) + def expr(self, item): lhs = self.lhs.expr(item) return not lhs @@ -1002,15 +1008,6 @@ class OpOr(Op): return lhs or rhs -class OpIn(Op): - OP = 'IN' - - def expr(self, item): - lhs = self.lhs.expr(item) - rhs = self.rhs.expr(item) - return lhs in rhs - - class Func(object): """ Base class for a FilterExpression function @@ -1034,14 +1031,14 @@ class FuncAttrExists(Func): def __init__(self, attribute): self.attr = attribute - super().__init__(attribute) + super(FuncAttrExists, self).__init__(attribute) def expr(self, item): return self.attr.get_type(item) is not None def FuncAttrNotExists(attribute): - return OpNot(FuncAttrExists(attribute), None) + return OpNot(FuncAttrExists(attribute)) class FuncAttrType(Func): @@ -1050,7 +1047,7 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type - super().__init__(attribute, _type) + super(FuncAttrType, self).__init__(attribute, _type) def expr(self, item): return self.attr.get_type(item) == self.type.expr(item) @@ -1062,7 +1059,7 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr - super().__init__(attribute, substr) + super(FuncBeginsWith, self).__init__(attribute, substr) def expr(self, item): if self.attr.get_type(item) != 'S': @@ -1078,7 +1075,7 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand - super().__init__(attribute, operand) + super(FuncContains, self).__init__(attribute, operand) def expr(self, item): if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): @@ -1090,7 +1087,7 @@ class FuncContains(Func): def FuncNotContains(attribute, operand): - return OpNot(FuncContains(attribute, operand), None) + return OpNot(FuncContains(attribute, operand)) class FuncSize(Func): @@ -1098,7 +1095,7 @@ class FuncSize(Func): def __init__(self, attribute): self.attr = attribute - super().__init__(attribute) + super(FuncSize, self).__init__(attribute) def expr(self, item): if self.attr.get_type(item) is None: @@ -1116,7 +1113,7 @@ class FuncBetween(Func): self.attr = attribute self.start = start self.end = end - super().__init__(attribute, start, end) + super(FuncBetween, self).__init__(attribute, start, end) def expr(self, item): return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) @@ -1128,7 +1125,7 @@ class FuncIn(Func): def __init__(self, attribute, *possible_values): self.attr = attribute self.possible_values = possible_values - super().__init__(attribute, *possible_values) + super(FuncIn, self).__init__(attribute, *possible_values) def expr(self, item): for possible_value in self.possible_values: @@ -1138,11 +1135,7 @@ class FuncIn(Func): return False -OP_CLASS = { - 'NOT': OpNot, - 'AND': OpAnd, - 'OR': OpOr, - 'IN': OpIn, +COMPARATOR_CLASS = { '<': OpLessThan, '>': OpGreaterThan, '<=': OpLessThanOrEqual, diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 037db3d7..1f2c6deb 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -6,6 +6,7 @@ import decimal import json import re import uuid +import six import boto3 from moto.compat import OrderedDict @@ -89,7 +90,7 @@ class DynamoType(object): Returns DynamoType or None. """ - if isinstance(key, str) and self.is_map() and key in self.value: + if isinstance(key, six.string_types) and self.is_map() and key in self.value: return DynamoType(self.value[key]) if isinstance(key, int) and self.is_list(): @@ -994,7 +995,6 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) @@ -1024,12 +1024,12 @@ class DynamoDBBackend(BaseBackend): if not get_expected(expected).expr(item): raise ValueError('The conditional request failed') - condition_op = get_filter_expression( - condition_expression, - expression_attribute_names, - expression_attribute_values) - if not condition_op.expr(current): - raise ValueError('The conditional request failed') + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values) + if not condition_op.expr(item): + raise ValueError('The conditional request failed') # Update does not fail on new items, so create one if item is None: From 83082df4d907293438c7b2cd9f622ca8da06450d Mon Sep 17 00:00:00 2001 From: Matthew Stevens Date: Sun, 14 Apr 2019 19:37:43 -0400 Subject: [PATCH 08/10] Adding update_item and attribute_not_exists test --- tests/test_dynamodb2/test_dynamodb.py | 36 +++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 0ea1d64e..dc41e367 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1719,6 +1719,42 @@ def test_condition_expressions(): } ) + # Make sure update_item honors ConditionExpression as well + dynamodb.update_item( + TableName='test1', + Key={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_exists(#existing)', + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ExpressionAttributeValues={ + ':match': {'S': 'match'} + } + ) + + with assert_raises(dynamodb.exceptions.ConditionalCheckFailedException): + dynamodb.update_item( + TableName='test1', + Key={ + 'client': { 'S': 'client1'}, + 'app': { 'S': 'app1'}, + }, + UpdateExpression='set #match=:match', + ConditionExpression='attribute_not_exists(#existing)', + ExpressionAttributeValues={ + ':match': {'S': 'match'} + }, + ExpressionAttributeNames={ + '#existing': 'existing', + '#match': 'match', + }, + ) + @mock_dynamodb2 def test_query_gsi_with_range_key(): From 467f669c1e6e48d8158a6b35474ccabe56aab3da Mon Sep 17 00:00:00 2001 From: Garrett Heel Date: Wed, 26 Jun 2019 23:13:01 +0100 Subject: [PATCH 09/10] add test for attr doesn't exist --- moto/dynamodb2/models.py | 1 - tests/test_dynamodb2/test_dynamodb.py | 54 +++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 1f2c6deb..6d3a4b95 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -13,7 +13,6 @@ from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time from moto.core.exceptions import JsonRESTError -from .comparisons import get_comparison_func, get_filter_expression, Op from .comparisons import get_comparison_func from .comparisons import get_filter_expression from .comparisons import get_expected diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index dc41e367..a4d79f4d 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -1563,7 +1563,6 @@ def test_dynamodb_streams_2(): @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') # Create the DynamoDB table. client.create_table( @@ -1720,7 +1719,7 @@ def test_condition_expressions(): ) # Make sure update_item honors ConditionExpression as well - dynamodb.update_item( + client.update_item( TableName='test1', Key={ 'client': {'S': 'client1'}, @@ -1737,8 +1736,8 @@ def test_condition_expressions(): } ) - with assert_raises(dynamodb.exceptions.ConditionalCheckFailedException): - dynamodb.update_item( + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( TableName='test1', Key={ 'client': { 'S': 'client1'}, @@ -1756,6 +1755,53 @@ def test_condition_expressions(): ) +@mock_dynamodb2 +def test_condition_expression__attr_doesnt_exist(): + client = boto3.client('dynamodb', region_name='us-east-1') + + client.create_table( + TableName='test', + KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], + AttributeDefinitions=[ + {'AttributeName': 'forum_name', 'AttributeType': 'S'}, + ], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ) + + client.put_item( + TableName='test', + Item={ + 'forum_name': {'S': 'foo'}, + 'ttl': {'N': 'bar'}, + } + ) + + + def update_if_attr_doesnt_exist(): + # Test nonexistent top-level attribute. + client.update_item( + TableName='test', + Key={ + 'forum_name': {'S': 'the-key'}, + 'subject': {'S': 'the-subject'}, + }, + UpdateExpression='set #new_state=:new_state, #ttl=:ttl', + ConditionExpression='attribute_not_exists(#new_state)', + ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'}, + ExpressionAttributeValues={ + ':new_state': {'S': 'some-value'}, + ':ttl': {'N': '12345.67'}, + }, + ReturnValues='ALL_NEW', + ) + + update_if_attr_doesnt_exist() + + # Second time should fail + with assert_raises(client.exceptions.ConditionalCheckFailedException): + update_if_attr_doesnt_exist() + + @mock_dynamodb2 def test_query_gsi_with_range_key(): dynamodb = boto3.client('dynamodb', region_name='us-east-1') From ba95c945f9b16f692b217f89816eae36fccab11c Mon Sep 17 00:00:00 2001 From: Garrett Heel Date: Tue, 9 Jul 2019 09:20:35 -0400 Subject: [PATCH 10/10] remove dead code --- moto/dynamodb2/responses.py | 61 ------------------------------------- 1 file changed, 61 deletions(-) diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 13dde683..12260384 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -32,67 +32,6 @@ def get_empty_str_error(): )) -def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values): - """ - Limited condition expression syntax parsing. - Supports Global Negation ex: NOT(inner expressions). - Supports simple AND conditions ex: cond_a AND cond_b and cond_c. - Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value. - """ - expected = {} - if condition_expression and 'OR' not in condition_expression: - reverse_re = re.compile('^NOT\s*\((.*)\)$') - reverse_m = reverse_re.match(condition_expression.strip()) - - reverse = False - if reverse_m: - reverse = True - condition_expression = reverse_m.group(1) - - cond_items = [c.strip() for c in condition_expression.split('AND')] - if cond_items: - exists_re = re.compile('^attribute_exists\s*\((.*)\)$') - not_exists_re = re.compile( - '^attribute_not_exists\s*\((.*)\)$') - equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)') - - for cond in cond_items: - exists_m = exists_re.match(cond) - not_exists_m = not_exists_re.match(cond) - equals_m = equals_re.match(cond) - - if exists_m: - attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': True if not reverse else False} - elif not_exists_m: - attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names) - expected[attribute_name] = {'Exists': False if not reverse else True} - elif equals_m: - attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names) - attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values) - expected[attribute_name] = { - 'AttributeValueList': [attribute_value], - 'ComparisonOperator': 'EQ' if not reverse else 'NEQ'} - - return expected - - -def expression_attribute_names_lookup(attribute_name, expression_attribute_names): - if attribute_name.startswith('#') and attribute_name in expression_attribute_names: - return expression_attribute_names[attribute_name] - else: - return attribute_name - - -def expression_attribute_values_lookup(attribute_value, expression_attribute_values): - if isinstance(attribute_value, six.string_types) and \ - attribute_value.startswith(':') and\ - attribute_value in expression_attribute_values: - return expression_attribute_values[attribute_value] - else: - return attribute_value - - class DynamoHandler(BaseResponse): def get_endpoint_name(self, headers):