diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 6d37345f..ac78d45b 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,6 +1,40 @@ from __future__ import unicode_literals import re import six +import re +import enum +from collections import deque +from collections import namedtuple + + +def get_filter_expression(expr, names, values): + """ + Parse a filter expression into an Op. + + Examples + expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' + expr = 'Id > 5 AND Subs < 7' + """ + parser = ConditionExpressionParser(expr, names, values) + return parser.parse() + + +class Op(object): + """ + Base class for a FilterExpression operator + """ + OP = '' + + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + def expr(self, item): + raise NotImplementedError("Expr not defined for {0}".format(type(self))) + + def __repr__(self): + return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa @@ -49,292 +83,783 @@ class RecursionStopIteration(StopIteration): pass -def get_filter_expression(expr, names, values): - # Examples - # expr = 'Id > 5 AND attribute_exists(test) AND Id BETWEEN 5 AND 6 OR length < 6 AND contains(test, 1) AND 5 IN (4,5, 6) OR (Id < 5 AND 5 > Id)' - # expr = 'Id > 5 AND Subs < 7' - if names is None: - names = {} - if values is None: - values = {} +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values - # Do substitutions - for key, value in names.items(): - expr = expr.replace(key, value) + def parse(self): + """Returns a syntax tree for the expression. - # Store correct types of values for use later - values_map = {} - for key, value in values.items(): - if 'N' in value: - values_map[key] = float(value['N']) - elif 'BOOL' in value: - values_map[key] = value['BOOL'] - elif 'S' in value: - values_map[key] = value['S'] - elif 'NS' in value: - values_map[key] = tuple(value['NS']) - elif 'SS' in value: - values_map[key] = tuple(value['SS']) - elif 'L' in value: - values_map[key] = tuple(value['L']) + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises ValueError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + ('FUNCTION', [('LITERAL', 'size'), argument: Operand]) + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return OpDefault(None, None) + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + # NOTE: The docs say that functions should be parsed after + # IN, BETWEEN, and comparisons like <=. + # However, these expressions are invalid as function arguments, + # so it is okay to parse functions first. This needs to be done + # to interpret size() correctly as an operand. + nodes = self._apply_functions(nodes) + nodes = self._apply_comparator(nodes) + nodes = self._apply_in(nodes) + nodes = self._apply_between(nodes) + nodes = self._apply_parens_and_booleans(nodes) + node = nodes[0] + op = self._make_op_condition(node) + return op + + class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + + class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + + Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == self.Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + # TODO: Handle indexing like [1] + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + self.Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + self.Nonterminal.COMPARATOR, re.compile( + '^(' + # Put long expressions first for greedy matching + '<>|' + '<=|' + '>=|' + '=|' + '<|' + '>)'), + ), ( + self.Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') + ), ( + self.Nonterminal.COMMA, re.compile('^,') + ), ( + self.Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + self.Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break else: - raise NotImplementedError() + raise ValueError("Cannot parse condition starting at: " + + remaining_expression) - # Remove all spaces, tbf we could just skip them in the next step. - # The number of known options is really small so we can do a fair bit of cheating - expr = list(expr.strip()) + value = match_text + node = self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) - # DodgyTokenisation stage 1 - def is_value(val): - return val not in ('<', '>', '=', '(', ')') + remaining_expression = remaining_expression[len(match_text):] - def contains_keyword(val): - for kw in ('BETWEEN', 'IN', 'AND', 'OR', 'NOT'): - if kw in val: - return kw - return None + return node, remaining_expression - def is_function(val): - return val in ('attribute_exists', 'attribute_not_exists', 'attribute_type', 'begins_with', 'contains', 'size') + def _parse_paths(self, nodes): + output = deque() - # Does the main part of splitting between sections of characters - tokens = [] - stack = '' - while len(expr) > 0: - current_char = expr.pop(0) + while nodes: + node = nodes.popleft() - if current_char == ' ': - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif current_char == ',': # Split params , - if len(stack) > 0: - tokens.append(stack) - stack = '' - elif is_value(current_char): - stack += current_char + if node.nonterminal == self.Nonterminal.OPERAND: + path = node.value.replace('[', '.[').split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != self.Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == self.Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output - kw = contains_keyword(stack) - if kw is not None: - # We have a kw in the stack, could be AND or something like 5AND - tmp = stack.replace(kw, '') - if len(tmp) > 0: - tokens.append(tmp) - tokens.append(kw) - stack = '' + def _parse_path_element(self, name): + reserved = { + 'and': self.Nonterminal.AND, + 'or': self.Nonterminal.OR, + 'in': self.Nonterminal.IN, + 'between': self.Nonterminal.BETWEEN, + 'not': self.Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name.lower() in reserved: + # e.g. AND + nonterminal = reserved[name.lower()] + return self.Node( + nonterminal=nonterminal, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + # e.g. attribute_exists + return self.Node( + nonterminal=self.Nonterminal.FUNCTION_NAME, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + # e.g. :value0 + return self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + # e.g. #name0 + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + elif name.startswith('['): + # e.g. [123] + if not name.endswith(']'): + raise ValueError("Bad path element %s" % name) + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=int(name[1:-1]), + children=[]) else: - if len(stack) > 0: - tokens.append(stack) - tokens.append(current_char) - stack = '' - if len(stack) > 0: - tokens.append(stack) + # e.g. ItemId + return self.Node( + nonterminal=self.Nonterminal.IDENTIFIER, + kind=self.Kind.LITERAL, + text=name, + value=name, + children=[]) - def is_op(val): - return val in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] - # DodgyTokenisation stage 2, it groups together some elements to make RPN'ing it later easier. - def handle_token(token, tokens2, token_iterator): - # ok so this essentially groups up some tokens to make later parsing easier, - # when it encounters brackets it will recurse and then unrecurse when RecursionStopIteration is raised. - if token == ')': - raise RecursionStopIteration() # Should be recursive so this should work - elif token == '(': - temp_list = [] + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] - try: + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(self.Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] while True: - next_token = six.next(token_iterator) - handle_token(next_token, temp_list, token_iterator) - except RecursionStopIteration: - pass # Continue - except StopIteration: - ValueError('Malformed filter expression, type1') - - # Sigh, we only want to group a tuple if it doesnt contain operators - if any([is_op(item) for item in temp_list]): - # Its an expression - tokens2.append('(') - tokens2.extend(temp_list) - tokens2.append(')') + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) else: - tokens2.append(tuple(temp_list)) - elif token == 'BETWEEN': - field = tokens2.pop() - # if values map contains a number, it would be a float - # so we need to int() it anyway - op1 = six.next(token_iterator) - op1 = int(values_map.get(op1, op1)) - and_op = six.next(token_iterator) - assert and_op == 'AND' - op2 = six.next(token_iterator) - op2 = int(values_map.get(op2, op2)) - tokens2.append(['between', field, op1, op2]) - elif is_function(token): - function_list = [token] + output.append(nodes.popleft()) + return output - lbracket = six.next(token_iterator) - assert lbracket == '(' - - next_token = six.next(token_iterator) - while next_token != ')': - if next_token in values_map: - next_token = values_map[next_token] - function_list.append(next_token) - next_token = six.next(token_iterator) - - tokens2.append(function_list) - else: - # Convert tokens back to real types - if token in values_map: - token = values_map[token] - - # Need to join >= <= <> - if len(tokens2) > 0 and ((tokens2[-1] == '>' and token == '=') or (tokens2[-1] == '<' and token == '=') or (tokens2[-1] == '<' and token == '>')): - tokens2.append(tokens2.pop() + token) + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) else: - tokens2.append(token) + output.append(nodes.popleft()) + return output - tokens2 = [] - token_iterator = iter(tokens) - for token in token_iterator: - handle_token(token, tokens2, token_iterator) - - # Start of the Shunting-Yard algorithm. <-- Proper beast algorithm! - def is_number(val): - return val not in ('<', '>', '=', '>=', '<=', '<>', 'BETWEEN', 'IN', 'AND', 'OR', 'NOT') - - OPS = {'<': 5, '>': 5, '=': 5, '>=': 5, '<=': 5, '<>': 5, 'IN': 8, 'AND': 11, 'OR': 12, 'NOT': 10, 'BETWEEN': 9, '(': 100, ')': 100} - - def shunting_yard(token_list): - output = [] - op_stack = [] - - # Basically takes in an infix notation calculation, converts it to a reverse polish notation where there is no - # ambiguity on which order operators are applied. - while len(token_list) > 0: - token = token_list.pop(0) - - if token == '(': - op_stack.append(token) - elif token == ')': - while len(op_stack) > 0 and op_stack[-1] != '(': - output.append(op_stack.pop()) - lbracket = op_stack.pop() - assert lbracket == '(' - - elif is_number(token): - output.append(token) + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} + expected_argument_kind_map = { + 'attribute_exists': [{self.Kind.PATH}], + 'attribute_not_exists': [{self.Kind.PATH}], + 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [either_kind, either_kind], + 'contains': [either_kind, either_kind], + 'size': [{self.Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + if function_name.value == 'size': + nonterminal = self.Nonterminal.OPERAND + else: + nonterminal = self.Nonterminal.CONDITION + output.append(self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) else: - # Must be operator kw + output.append(nodes.popleft()) + return output - # Cheat, NOT is our only RIGHT associative operator, should really have dict of operator associativity - while len(op_stack) > 0 and OPS[op_stack[-1]] <= OPS[token] and op_stack[-1] != 'NOT': - output.append(op_stack.pop()) - op_stack.append(token) - while len(op_stack) > 0: - output.append(op_stack.pop()) + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == self.Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == self.Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) return output - output = shunting_yard(tokens2) - - # Hacky function to convert dynamo functions (which are represented as lists) to their Class equivalent - def to_func(val): - if isinstance(val, list): - func_name = val.pop(0) - # Expand rest of the list to arguments - val = FUNC_CLASS[func_name](*val) - - return val - - # Simple reverse polish notation execution. Builts up a nested filter object. - # The filter object then takes a dynamo item and returns true/false - stack = [] - for token in output: - if is_op(token): - op_cls = OP_CLASS[token] - - if token == 'NOT': - op1 = stack.pop() - op2 = True + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) else: - op2 = stack.pop() - op1 = stack.pop() + output.append(nodes.popleft()) - stack.append(op_cls(op1, op2)) + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_operand(self, node): + if node.kind == self.Kind.PATH: + return AttributePath([child.value for child in node.children]) + elif node.kind == self.Kind.EXPRESSION_ATTRIBUTE_VALUE: + return AttributeValue(node.value) + elif node.kind == self.Kind.FUNCTION: + # size() + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) else: - stack.append(to_func(token)) - - result = stack.pop(0) - if len(stack) > 0: - raise ValueError('Malformed filter expression, type2') - - return result + raise ValueError("Unknown operand: %r" % node) -class Op(object): - """ - Base class for a FilterExpression operator - """ - OP = '' + def _make_op_condition(self, node): + if node.kind == self.Kind.OR: + lhs, rhs = node.children + return OpOr( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.AND: + lhs, rhs = node.children + return OpAnd( + self._make_op_condition(lhs), + self._make_op_condition(rhs)) + elif node.kind == self.Kind.NOT: + child, = node.children + return OpNot(self._make_op_condition(child), None) + elif node.kind == self.Kind.PARENTHESES: + child, = node.children + return self._make_op_condition(child) + elif node.kind == self.Kind.FUNCTION: + function_node, *arguments = node.children + function_name = function_node.value + arguments = [self._make_operand(arg) for arg in arguments] + return FUNC_CLASS[function_name](*arguments) + elif node.kind == self.Kind.BETWEEN: + query, low, high = node.children + return FuncBetween( + self._make_operand(query), + self._make_operand(low), + self._make_operand(high)) + elif node.kind == self.Kind.IN: + query, *possible_values = node.children + query = self._make_operand(query) + possible_values = [self._make_operand(v) for v in possible_values] + return FuncIn(query, *possible_values) + elif node.kind == self.Kind.COMPARISON: + lhs, comparator, rhs = node.children + return OP_CLASS[comparator.value]( + self._make_operand(lhs), + self._make_operand(rhs)) + else: + raise ValueError("Unknown expression node kind %r" % node.kind) - def __init__(self, lhs, rhs): - self.lhs = lhs - self.rhs = rhs + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise ValueError(message + " " + " ".join([t.text for t in nodes])) + + +class Operand(object): + def expr(self, item): + raise NotImplementedError + + def get_type(self, item): + raise NotImplementedError + + +class AttributePath(Operand): + def __init__(self, path): + """Initialize the AttributePath. + + Parameters + ---------- + path: list of int/str - def _lhs(self, item): """ - :type item: moto.dynamodb2.models.Item - """ - lhs = self.lhs - if isinstance(self.lhs, (Op, Func)): - lhs = self.lhs.expr(item) - elif isinstance(self.lhs, six.string_types): - try: - lhs = item.attrs[self.lhs].cast_value - except Exception: - pass + assert len(path) >= 1 + self.path = path - return lhs - - def _rhs(self, item): - rhs = self.rhs - if isinstance(self.rhs, (Op, Func)): - rhs = self.rhs.expr(item) - elif isinstance(self.rhs, six.string_types): - try: - rhs = item.attrs[self.rhs].cast_value - except Exception: - pass - return rhs + def _get_attr(self, item): + base = self.path[0] + if base not in item.attrs: + return None + attr = item.attrs[base] + for name in self.path[1:]: + attr = attr.child_attr(name) + if attr is None: + return None + return attr def expr(self, item): - return True + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.cast_value + + def get_type(self, item): + attr = self._get_attr(item) + if attr is None: + return None + else: + return attr.type def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return self.path -class Func(object): - """ - Base class for a FilterExpression function - """ - FUNC = 'Unknown' +class AttributeValue(Operand): + def __init__(self, value): + """Initialize the AttributePath. + + Parameters + ---------- + value: dict + e.g. {'N': '1.234'} + + """ + self.type = list(value.keys())[0] + if 'N' in value: + self.value = float(value['N']) + elif 'BOOL' in value: + self.value = value['BOOL'] + elif 'S' in value: + self.value = value['S'] + elif 'NS' in value: + self.value = tuple(value['NS']) + elif 'SS' in value: + self.value = tuple(value['SS']) + elif 'L' in value: + self.value = tuple(value['L']) + else: + # TODO: Handle all attribute types + raise NotImplementedError() def expr(self, item): - return True + return self.value + + def get_type(self, item): + return self.type def __repr__(self): - return 'Func(...)'.format(self.FUNC) + return repr(self.value) + + +class OpDefault(Op): + OP = 'NONE' + + def expr(self, item): + """If no condition is specified, always True.""" + return True class OpNot(Op): OP = 'NOT' def expr(self, item): - lhs = self._lhs(item) - + lhs = self.lhs.expr(item) return not lhs def __str__(self): @@ -345,8 +870,8 @@ class OpAnd(Op): OP = 'AND' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs and rhs @@ -354,8 +879,8 @@ class OpLessThan(Op): OP = '<' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs < rhs @@ -363,8 +888,8 @@ class OpGreaterThan(Op): OP = '>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs > rhs @@ -372,8 +897,8 @@ class OpEqual(Op): OP = '=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs == rhs @@ -381,8 +906,8 @@ class OpNotEqual(Op): OP = '<>' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs != rhs @@ -390,8 +915,8 @@ class OpLessThanOrEqual(Op): OP = '<=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs <= rhs @@ -399,8 +924,8 @@ class OpGreaterThanOrEqual(Op): OP = '>=' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs >= rhs @@ -408,8 +933,8 @@ class OpOr(Op): OP = 'OR' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs or rhs @@ -417,19 +942,38 @@ class OpIn(Op): OP = 'IN' def expr(self, item): - lhs = self._lhs(item) - rhs = self._rhs(item) + lhs = self.lhs.expr(item) + rhs = self.rhs.expr(item) return lhs in rhs +class Func(object): + """ + Base class for a FilterExpression function + """ + FUNC = 'Unknown' + + def __init__(self, *arguments): + self.arguments = arguments + + def expr(self, item): + raise NotImplementedError + + def __repr__(self): + return '{0}({1})'.format( + self.FUNC, + " ".join([repr(arg) for arg in self.arguments])) + + class FuncAttrExists(Func): FUNC = 'attribute_exists' def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr in item.attrs + return self.attr.get_type(item) is not None class FuncAttrNotExists(Func): @@ -437,9 +981,10 @@ class FuncAttrNotExists(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - return self.attr not in item.attrs + return self.attr.get_type(item) is None class FuncAttrType(Func): @@ -448,9 +993,10 @@ class FuncAttrType(Func): def __init__(self, attribute, _type): self.attr = attribute self.type = _type + super().__init__(attribute, _type) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == self.type + return self.attr.get_type(item) == self.type.expr(item) class FuncBeginsWith(Func): @@ -459,9 +1005,14 @@ class FuncBeginsWith(Func): def __init__(self, attribute, substr): self.attr = attribute self.substr = substr + super().__init__(attribute, substr) def expr(self, item): - return self.attr in item.attrs and item.attrs[self.attr].type == 'S' and item.attrs[self.attr].value.startswith(self.substr) + if self.attr.get_type(item) != 'S': + return False + if self.substr.get_type(item) != 'S': + return False + return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): @@ -470,13 +1021,11 @@ class FuncContains(Func): def __init__(self, attribute, operand): self.attr = attribute self.operand = operand + super().__init__(attribute, operand) def expr(self, item): - if self.attr not in item.attrs: - return False - - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'BS', 'L', 'M'): - return self.operand in item.attrs[self.attr].value + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L', 'M'): + return self.operand.expr(item) in self.attr.expr(item) return False @@ -485,29 +1034,44 @@ class FuncSize(Func): def __init__(self, attribute): self.attr = attribute + super().__init__(attribute) def expr(self, item): - if self.attr not in item.attrs: + if self.attr.get_type(item) is None: raise ValueError('Invalid attribute name {0}'.format(self.attr)) - if item.attrs[self.attr].type in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): - return len(item.attrs[self.attr].value) + if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + return len(self.attr.expr(item)) raise ValueError('Invalid filter expression') class FuncBetween(Func): - FUNC = 'between' + FUNC = 'BETWEEN' def __init__(self, attribute, start, end): self.attr = attribute self.start = start self.end = end + super().__init__(attribute, start, end) def expr(self, item): - if self.attr not in item.attrs: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) + return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) - return self.start <= item.attrs[self.attr].cast_value <= self.end + +class FuncIn(Func): + FUNC = 'IN' + + def __init__(self, attribute, *possible_values): + self.attr = attribute + self.possible_values = possible_values + super().__init__(attribute, *possible_values) + + def expr(self, item): + for possible_value in self.possible_values: + if self.attr.expr(item) == possible_value.expr(item): + return True + + return False OP_CLASS = { diff --git a/moto/dynamodb2/condition.py b/moto/dynamodb2/condition.py new file mode 100644 index 00000000..b50678e2 --- /dev/null +++ b/moto/dynamodb2/condition.py @@ -0,0 +1,617 @@ +import re +import json +import enum +from collections import deque +from collections import namedtuple + + +class Kind(enum.Enum): + """Defines types of nodes in the syntax tree.""" + + # Condition nodes + # --------------- + OR = enum.auto() + AND = enum.auto() + NOT = enum.auto() + PARENTHESES = enum.auto() + FUNCTION = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMPARISON = enum.auto() + + # Operand nodes + # ------------- + EXPRESSION_ATTRIBUTE_VALUE = enum.auto() + PATH = enum.auto() + + # Literal nodes + # -------------- + LITERAL = enum.auto() + + +class Nonterminal(enum.Enum): + """Defines nonterminals for defining productions.""" + CONDITION = enum.auto() + OPERAND = enum.auto() + COMPARATOR = enum.auto() + FUNCTION_NAME = enum.auto() + IDENTIFIER = enum.auto() + AND = enum.auto() + OR = enum.auto() + NOT = enum.auto() + BETWEEN = enum.auto() + IN = enum.auto() + COMMA = enum.auto() + LEFT_PAREN = enum.auto() + RIGHT_PAREN = enum.auto() + WHITESPACE = enum.auto() + + +Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + + +class ConditionExpressionParser: + def __init__(self, condition_expression, expression_attribute_names, + expression_attribute_values): + self.condition_expression = condition_expression + self.expression_attribute_names = expression_attribute_names + self.expression_attribute_values = expression_attribute_values + + def parse(self): + """Returns a syntax tree for the expression. + + The tree, and all of the nodes in the tree are a tuple of + - kind: str + - children/value: + list of nodes for parent nodes + value for leaf nodes + + Raises AssertionError if the condition expression is invalid + Raises KeyError if expression attribute names/values are invalid + + Here are the types of nodes that can be returned. + The types of child nodes are denoted with a colon (:). + An arbitrary number of children is denoted with ... + + Condition: + ('OR', [lhs : Condition, rhs : Condition]) + ('AND', [lhs: Condition, rhs: Condition]) + ('NOT', [argument: Condition]) + ('PARENTHESES', [argument: Condition]) + ('FUNCTION', [('LITERAL', function_name: str), argument: Operand, ...]) + ('BETWEEN', [query: Operand, low: Operand, high: Operand]) + ('IN', [query: Operand, possible_value: Operand, ...]) + ('COMPARISON', [lhs: Operand, ('LITERAL', comparator: str), rhs: Operand]) + + Operand: + ('EXPRESSION_ATTRIBUTE_VALUE', value: dict, e.g. {'S': 'foobar'}) + ('PATH', [('LITERAL', path_element: str), ...]) + NOTE: Expression attribute names will be expanded + + Literal: + ('LITERAL', value: str) + + """ + if not self.condition_expression: + return None + nodes = self._lex_condition_expression() + nodes = self._parse_paths(nodes) + self._print_debug(nodes) + nodes = self._apply_comparator(nodes) + self._print_debug(nodes) + nodes = self._apply_in(nodes) + self._print_debug(nodes) + nodes = self._apply_between(nodes) + self._print_debug(nodes) + nodes = self._apply_functions(nodes) + self._print_debug(nodes) + nodes = self._apply_parens_and_booleans(nodes) + self._print_debug(nodes) + node = nodes[0] + return self._make_node_tree(node) + + def _lex_condition_expression(self): + nodes = deque() + remaining_expression = self.condition_expression + while remaining_expression: + node, remaining_expression = \ + self._lex_one_node(remaining_expression) + if node.nonterminal == Nonterminal.WHITESPACE: + continue + nodes.append(node) + return nodes + + def _lex_one_node(self, remaining_expression): + + attribute_regex = '(:|#)?[A-z0-9\-_]+' + patterns = [( + Nonterminal.WHITESPACE, re.compile('^ +') + ), ( + Nonterminal.COMPARATOR, re.compile( + '^(' + '=|' + '<>|' + '<|' + '<=|' + '>|' + '>=)'), + ), ( + Nonterminal.OPERAND, re.compile( + '^' + + attribute_regex + '(\.' + attribute_regex + ')*') + ), ( + Nonterminal.COMMA, re.compile('^,') + ), ( + Nonterminal.LEFT_PAREN, re.compile('^\(') + ), ( + Nonterminal.RIGHT_PAREN, re.compile('^\)') + )] + + for nonterminal, pattern in patterns: + match = pattern.match(remaining_expression) + if match: + match_text = match.group() + break + else: + raise AssertionError("Cannot parse condition starting at: " + + remaining_expression) + + value = match_text + node = Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=match_text, + value=match_text, + children=[]) + + remaining_expression = remaining_expression[len(match_text):] + + return node, remaining_expression + + def _parse_paths(self, nodes): + output = deque() + + while nodes: + node = nodes.popleft() + + if node.nonterminal == Nonterminal.OPERAND: + path = node.value.split('.') + children = [ + self._parse_path_element(name) + for name in path] + if len(children) == 1: + child = children[0] + if child.nonterminal != Nonterminal.IDENTIFIER: + output.append(child) + continue + else: + for child in children: + self._assert( + child.nonterminal == Nonterminal.IDENTIFIER, + "Cannot use %s in path" % child.text, [node]) + output.append(Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.PATH, + text=node.text, + value=None, + children=children)) + else: + output.append(node) + return output + + def _parse_path_element(self, name): + reserved = { + 'AND': Nonterminal.AND, + 'OR': Nonterminal.OR, + 'IN': Nonterminal.IN, + 'BETWEEN': Nonterminal.BETWEEN, + 'NOT': Nonterminal.NOT, + } + + functions = { + 'attribute_exists', + 'attribute_not_exists', + 'attribute_type', + 'begins_with', + 'contains', + 'size', + } + + + if name in reserved: + nonterminal = reserved[name] + return Node( + nonterminal=nonterminal, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name in functions: + return Node( + nonterminal=Nonterminal.FUNCTION_NAME, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + elif name.startswith(':'): + return Node( + nonterminal=Nonterminal.OPERAND, + kind=Kind.EXPRESSION_ATTRIBUTE_VALUE, + text=name, + value=self._lookup_expression_attribute_value(name), + children=[]) + elif name.startswith('#'): + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=self._lookup_expression_attribute_name(name), + children=[]) + else: + return Node( + nonterminal=Nonterminal.IDENTIFIER, + kind=Kind.LITERAL, + text=name, + value=name, + children=[]) + + def _lookup_expression_attribute_value(self, name): + return self.expression_attribute_values[name] + + def _lookup_expression_attribute_name(self, name): + return self.expression_attribute_names[name] + + # NOTE: The following constructions are ordered from high precedence to low precedence + # according to + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Precedence + # + # = <> < <= > >= + # IN + # BETWEEN + # attribute_exists attribute_not_exists begins_with contains + # Parentheses + # NOT + # AND + # OR + # + # The grammar is taken from + # https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.OperatorsAndFunctions.html#Expressions.OperatorsAndFunctions.Syntax + # + # condition-expression ::= + # operand comparator operand + # operand BETWEEN operand AND operand + # operand IN ( operand (',' operand (, ...) )) + # function + # condition AND condition + # condition OR condition + # NOT condition + # ( condition ) + # + # comparator ::= + # = + # <> + # < + # <= + # > + # >= + # + # function ::= + # attribute_exists (path) + # attribute_not_exists (path) + # attribute_type (path, type) + # begins_with (path, substr) + # contains (path, operand) + # size (path) + + def _matches(self, nodes, production): + """Check if the nodes start with the given production. + + Parameters + ---------- + nodes: list of Node + production: list of str + The name of a Nonterminal, or '*' for anything + + """ + if len(nodes) < len(production): + return False + for i in range(len(production)): + if production[i] == '*': + continue + expected = getattr(Nonterminal, production[i]) + if nodes[i].nonterminal != expected: + return False + return True + + def _apply_comparator(self, nodes): + """Apply condition := operand comparator operand.""" + output = deque() + + while nodes: + if self._matches(nodes, ['*', 'COMPARATOR']): + self._assert( + self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), + "Bad comparison", list(nodes)[:3]) + lhs = nodes.popleft() + comparator = nodes.popleft() + rhs = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.COMPARISON, + text=" ".join([ + lhs.text, + comparator.text, + rhs.text]), + value=None, + children=[lhs, comparator, rhs])) + else: + output.append(nodes.popleft()) + return output + + def _apply_in(self, nodes): + """Apply condition := operand IN ( operand , ... ).""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'IN']): + self._assert( + self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), + "Bad IN expression", list(nodes)[:3]) + lhs = nodes.popleft() + in_node = nodes.popleft() + left_paren = nodes.popleft() + all_children = [lhs, in_node, left_paren] + rhs = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + rhs.append(operand) + break # Close + else: + self._assert( + False, + "Bad IN expression starting at", nodes) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs)) + else: + output.append(nodes.popleft()) + return output + + def _apply_between(self, nodes): + """Apply condition := operand BETWEEN operand AND operand.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'BETWEEN']): + self._assert( + self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', + 'AND', 'OPERAND']), + "Bad BETWEEN expression", list(nodes)[:5]) + lhs = nodes.popleft() + between_node = nodes.popleft() + low = nodes.popleft() + and_node = nodes.popleft() + high = nodes.popleft() + all_children = [lhs, between_node, low, and_node, high] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high])) + else: + output.append(nodes.popleft()) + return output + + def _apply_functions(self, nodes): + """Apply condition := function_name (operand , ...).""" + output = deque() + expected_argument_kind_map = { + 'attribute_exists': [{Kind.PATH}], + 'attribute_not_exists': [{Kind.PATH}], + 'attribute_type': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'begins_with': [{Kind.PATH}, {Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'contains': [{Kind.PATH}, {Kind.PATH, Kind.EXPRESSION_ATTRIBUTE_VALUE}], + 'size': [{Kind.PATH}], + } + while nodes: + if self._matches(nodes, ['FUNCTION_NAME']): + self._assert( + self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', + 'OPERAND', '*']), + "Bad function expression at", list(nodes)[:4]) + function_name = nodes.popleft() + left_paren = nodes.popleft() + all_children = [function_name, left_paren] + arguments = [] + while True: + if self._matches(nodes, ['OPERAND', 'COMMA']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + operand = nodes.popleft() + separator = nodes.popleft() + all_children += [operand, separator] + arguments.append(operand) + break # Close paren + else: + self._assert( + False, + "Bad function expression", all_children + list(nodes)[:2]) + expected_kinds = expected_argument_kind_map[function_name.value] + self._assert( + len(arguments) == len(expected_kinds), + "Wrong number of arguments in", all_children) + for i in range(len(expected_kinds)): + self._assert( + arguments[i].kind in expected_kinds[i], + "Wrong type for argument %d in" % i, all_children) + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments)) + else: + output.append(nodes.popleft()) + return output + + def _apply_parens_and_booleans(self, nodes, left_paren=None): + """Apply condition := ( condition ) and booleans.""" + output = deque() + while nodes: + if self._matches(nodes, ['LEFT_PAREN']): + parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) + self._assert( + len(parsed) >= 1, + "Failed to close parentheses at", nodes) + parens = parsed.popleft() + self._assert( + parens.kind == Kind.PARENTHESES, + "Failed to close parentheses at", nodes) + output.append(parens) + nodes = parsed + elif self._matches(nodes, ['RIGHT_PAREN']): + self._assert( + left_paren is not None, + "Unmatched ) at", nodes) + close_paren = nodes.popleft() + children = self._apply_booleans(output) + all_children = [left_paren, *children, close_paren] + return deque([ + Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ), *nodes]) + else: + output.append(nodes.popleft()) + + self._assert( + left_paren is None, + "Unmatched ( at", list(output)) + return self._apply_booleans(output) + + def _apply_booleans(self, nodes): + """Apply and, or, and not constructions.""" + nodes = self._apply_not(nodes) + nodes = self._apply_and(nodes) + nodes = self._apply_or(nodes) + # The expression should reduce to a single condition + self._assert( + len(nodes) == 1, + "Unexpected expression at", list(nodes)[1:]) + self._assert( + nodes[0].nonterminal == Nonterminal.CONDITION, + "Incomplete condition", nodes) + return nodes + + def _apply_not(self, nodes): + """Apply condition := NOT condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['NOT']): + self._assert( + self._matches(nodes, ['NOT', 'CONDITION']), + "Bad NOT expression", list(nodes)[:2]) + not_node = nodes.popleft() + child = nodes.popleft() + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.NOT, + text=" ".join([not_node['text'], value['text']]), + value=None, + children=[child])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_and(self, nodes): + """Apply condition := condition AND condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'AND']): + self._assert( + self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), + "Bad AND expression", list(nodes)[:3]) + lhs = nodes.popleft() + and_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, and_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _apply_or(self, nodes): + """Apply condition := condition OR condition.""" + output = deque() + while nodes: + if self._matches(nodes, ['*', 'OR']): + self._assert( + self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), + "Bad OR expression", list(nodes)[:3]) + lhs = nodes.popleft() + or_node = nodes.popleft() + rhs = nodes.popleft() + all_children = [lhs, or_node, rhs] + output.append(Node( + nonterminal=Nonterminal.CONDITION, + kind=Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs])) + else: + output.append(nodes.popleft()) + + return output + + def _make_node_tree(self, node): + if len(node.children) > 0: + return ( + node.kind.name, + [ + self._make_node_tree(child) + for child in node.children + ]) + else: + return (node.kind.name, node.value) + + def _print_debug(self, nodes): + print('ROOT') + for node in nodes: + self._print_node_recursive(node, depth=1) + + def _print_node_recursive(self, node, depth=0): + if len(node.children) > 0: + print(' ' * depth, node.nonterminal, node.kind) + for child in node.children: + self._print_node_recursive(child, depth=depth + 1) + else: + print(' ' * depth, node.nonterminal, node.kind, node.value) + + + + def _assert(self, condition, message, nodes): + if not condition: + raise AssertionError(message + " " + " ".join([t.text for t in nodes])) diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 6bcde41b..300479e9 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -68,10 +68,34 @@ class DynamoType(object): except ValueError: return float(self.value) elif self.is_set(): - return set(self.value) + sub_type = self.type[0] + return set([DynamoType({sub_type: v}).cast_value + for v in self.value]) + elif self.is_list(): + return [DynamoType(v).cast_value for v in self.value] + elif self.is_map(): + return dict([ + (k, DynamoType(v).cast_value) + for k, v in self.value.items()]) else: return self.value + def child_attr(self, key): + """ + Get Map or List children by key. str for Map, int for List. + + Returns DynamoType or None. + """ + if isinstance(key, str) and self.is_map() and key in self.value: + return DynamoType(self.value[key]) + + if isinstance(key, int) and self.is_list(): + idx = key + if idx >= 0 and idx < len(self.value): + return DynamoType(self.value[idx]) + + return None + def to_json(self): return {self.type: self.value} @@ -89,6 +113,12 @@ class DynamoType(object): def is_set(self): return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + def is_list(self): + return self.type == 'L' + + def is_map(self): + return self.type == 'M' + def same_type(self, other): return self.type == other.type @@ -954,10 +984,7 @@ class DynamoDBBackend(BaseBackend): range_values = [DynamoType(range_value) for range_value in range_value_dicts] - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.query(hash_key, range_comparison, range_values, limit, exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) @@ -972,10 +999,8 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - if filter_expression is not None: - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) - else: - filter_expression = Op(None, None) # Will always eval to true + + filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name) diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index 77846de0..932139ee 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -676,44 +676,47 @@ def test_filter_expression(): filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 7}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': 5}, ':v1': {'N': 8}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': 5}, ':v1': {'N': 10}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': 8}, ':v1': {'N': 5}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN :v0', {}, {':v0': {'NS': [7, 8, 9]}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { + ':v0': {'N': '7'}, + ':v1': {'N': '8'}, + ':v2': {'N': '9'}}) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, N)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, Some)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, test1)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -754,14 +757,26 @@ def test_query_filter(): TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) client.put_item( TableName='test1', Item={ 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} + 'app': {'S': 'app2'}, + 'nested': {'M': { + 'version': {'S': 'version2'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, } ) @@ -783,6 +798,18 @@ def test_query_filter(): ) assert response['Count'] == 2 + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.version').contains('version') + ) + assert response['Count'] == 2 + + response = table.query( + KeyConditionExpression=Key('client').eq('client1'), + FilterExpression=Attr('nested.contents[0]').eq('value1') + ) + assert response['Count'] == 2 + @mock_dynamodb2 def test_scan_filter(): @@ -1061,7 +1088,7 @@ def test_delete_item(): with assert_raises(ClientError) as ex: table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_NEW') - + # Test deletion and returning old value response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') response['Attributes'].should.contain('client') @@ -1364,7 +1391,7 @@ def test_put_return_attributes(): ReturnValues='NONE' ) assert 'Attributes' not in r - + r = dynamodb.put_item( TableName='moto-test', Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, @@ -1381,7 +1408,7 @@ def test_put_return_attributes(): ex.exception.response['Error']['Code'].should.equal('ValidationException') ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') - + @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): @@ -1489,7 +1516,7 @@ def test_dynamodb_streams_1(): 'StreamViewType': 'NEW_AND_OLD_IMAGES' } ) - + assert 'StreamSpecification' in resp['TableDescription'] assert resp['TableDescription']['StreamSpecification'] == { 'StreamEnabled': True, @@ -1497,11 +1524,11 @@ def test_dynamodb_streams_1(): } assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + resp = conn.delete_table(TableName='test-streams') assert 'StreamSpecification' in resp['TableDescription'] - + @mock_dynamodb2 def test_dynamodb_streams_2(): @@ -1532,7 +1559,7 @@ def test_dynamodb_streams_2(): assert 'LatestStreamLabel' in resp['TableDescription'] assert 'LatestStreamArn' in resp['TableDescription'] - + @mock_dynamodb2 def test_condition_expressions(): client = boto3.client('dynamodb', region_name='us-east-1')