Merge pull request #2266 from garrettheel/feat/dynamodb-expressions

Improve DynamoDB condition expression support
This commit is contained in:
Steve Pulec 2019-07-09 18:22:55 -05:00 committed by GitHub
commit 6a13d54616
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 1172 additions and 452 deletions

File diff suppressed because it is too large Load diff

View file

@ -6,13 +6,16 @@ import decimal
import json
import re
import uuid
import six
import boto3
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError
from .comparisons import get_comparison_func, get_filter_expression, Op
from .comparisons import get_comparison_func
from .comparisons import get_filter_expression
from .comparisons import get_expected
from .exceptions import InvalidIndexNameError
@ -68,10 +71,34 @@ class DynamoType(object):
except ValueError:
return float(self.value)
elif self.is_set():
return set(self.value)
sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value
for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([
(k, DynamoType(v).cast_value)
for k, v in self.value.items()])
else:
return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map() and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if idx >= 0 and idx < len(self.value):
return DynamoType(self.value[idx])
return None
def to_json(self):
return {self.type: self.value}
@ -89,6 +116,12 @@ class DynamoType(object):
def is_set(self):
return self.type == 'SS' or self.type == 'NS' or self.type == 'BS'
def is_list(self):
return self.type == 'L'
def is_map(self):
return self.type == 'M'
def same_type(self, other):
return self.type == other.type
@ -504,7 +537,9 @@ class Table(BaseModel):
keys.append(range_key)
return keys
def put_item(self, item_attrs, expected=None, overwrite=False):
def put_item(self, item_attrs, expected=None, condition_expression=None,
expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.has_range_key:
range_value = DynamoType(item_attrs.get(self.range_key_attr))
@ -527,29 +562,15 @@ class Table(BaseModel):
self.range_key_type, item_attrs)
if not overwrite:
if current is None:
current_attr = {}
elif hasattr(current, 'attrs'):
current_attr = current.attrs
else:
current_attr = current
if not get_expected(expected).expr(current):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(current):
raise ValueError('The conditional request failed')
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in current_attr:
raise ValueError("The conditional request failed")
elif key not in current_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not current_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if range_value:
self.items[hash_value][range_value] = item
else:
@ -902,11 +923,15 @@ class DynamoDBBackend(BaseBackend):
table.global_indexes = list(gsis_by_name.values())
return table
def put_item(self, table_name, item_attrs, expected=None, overwrite=False):
def put_item(self, table_name, item_attrs, expected=None,
condition_expression=None, expression_attribute_names=None,
expression_attribute_values=None, overwrite=False):
table = self.tables.get(table_name)
if not table:
return None
return table.put_item(item_attrs, expected, overwrite)
return table.put_item(item_attrs, expected, condition_expression,
expression_attribute_names,
expression_attribute_values, overwrite)
def get_table_keys_name(self, table_name, keys):
"""
@ -962,10 +987,7 @@ class DynamoDBBackend(BaseBackend):
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs)
@ -980,17 +1002,14 @@ class DynamoDBBackend(BaseBackend):
dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types)
if filter_expression is not None:
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
else:
filter_expression = Op(None, None) # Will always eval to true
filter_expression = get_filter_expression(filter_expression, expr_names, expr_values)
projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')])
return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression)
def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected=None):
expression_attribute_values, expected=None, condition_expression=None):
table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
@ -1009,32 +1028,17 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
if item is None:
item_attr = {}
elif hasattr(item, 'attrs'):
item_attr = item.attrs
else:
item_attr = item
if not expected:
expected = {}
for key, val in expected.items():
if 'Exists' in val and val['Exists'] is False \
or 'ComparisonOperator' in val and val['ComparisonOperator'] == 'NULL':
if key in item_attr:
raise ValueError("The conditional request failed")
elif key not in item_attr:
raise ValueError("The conditional request failed")
elif 'Value' in val and DynamoType(val['Value']).value != item_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
dynamo_types = [
DynamoType(ele) for ele in
val.get("AttributeValueList", [])
]
if not item_attr[key].compare(val['ComparisonOperator'], dynamo_types):
raise ValueError('The conditional request failed')
if not get_expected(expected).expr(item):
raise ValueError('The conditional request failed')
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values)
if not condition_op.expr(item):
raise ValueError('The conditional request failed')
# Update does not fail on new items, so create one
if item is None:

View file

@ -32,67 +32,6 @@ def get_empty_str_error():
))
def condition_expression_to_expected(condition_expression, expression_attribute_names, expression_attribute_values):
"""
Limited condition expression syntax parsing.
Supports Global Negation ex: NOT(inner expressions).
Supports simple AND conditions ex: cond_a AND cond_b and cond_c.
Atomic expressions supported are attribute_exists(key), attribute_not_exists(key) and #key = :value.
"""
expected = {}
if condition_expression and 'OR' not in condition_expression:
reverse_re = re.compile('^NOT\s*\((.*)\)$')
reverse_m = reverse_re.match(condition_expression.strip())
reverse = False
if reverse_m:
reverse = True
condition_expression = reverse_m.group(1)
cond_items = [c.strip() for c in condition_expression.split('AND')]
if cond_items:
exists_re = re.compile('^attribute_exists\s*\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\s*\((.*)\)$')
equals_re = re.compile('^(#?\w+)\s*=\s*(\:?\w+)')
for cond in cond_items:
exists_m = exists_re.match(cond)
not_exists_m = not_exists_re.match(cond)
equals_m = equals_re.match(cond)
if exists_m:
attribute_name = expression_attribute_names_lookup(exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': True if not reverse else False}
elif not_exists_m:
attribute_name = expression_attribute_names_lookup(not_exists_m.group(1), expression_attribute_names)
expected[attribute_name] = {'Exists': False if not reverse else True}
elif equals_m:
attribute_name = expression_attribute_names_lookup(equals_m.group(1), expression_attribute_names)
attribute_value = expression_attribute_values_lookup(equals_m.group(2), expression_attribute_values)
expected[attribute_name] = {
'AttributeValueList': [attribute_value],
'ComparisonOperator': 'EQ' if not reverse else 'NEQ'}
return expected
def expression_attribute_names_lookup(attribute_name, expression_attribute_names):
if attribute_name.startswith('#') and attribute_name in expression_attribute_names:
return expression_attribute_names[attribute_name]
else:
return attribute_name
def expression_attribute_values_lookup(attribute_value, expression_attribute_values):
if isinstance(attribute_value, six.string_types) and \
attribute_value.startswith(':') and\
attribute_value in expression_attribute_values:
return expression_attribute_values[attribute_value]
else:
return attribute_value
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
@ -288,18 +227,18 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
if expected:
overwrite = False
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
if condition_expression:
overwrite = False
try:
result = self.dynamodb_backend.put_item(name, item, expected, overwrite)
result = self.dynamodb_backend.put_item(
name, item, expected, condition_expression,
expression_attribute_names, expression_attribute_values,
overwrite)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er, 'A condition specified in the operation could not be evaluated.')
@ -653,13 +592,9 @@ class DynamoHandler(BaseResponse):
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expected = condition_expression_to_expected(condition_expression,
expression_attribute_names,
expression_attribute_values)
condition_expression = self.body.get('ConditionExpression')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
@ -670,7 +605,7 @@ class DynamoHandler(BaseResponse):
try:
item = self.dynamodb_backend.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names,
expression_attribute_values, expected
expression_attribute_values, expected, condition_expression
)
except ValueError:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'