Merge remote-tracking branch 'spulec/master'

This commit is contained in:
Peter Gorniak 2017-06-02 13:38:36 -07:00
commit 91657a537b
417 changed files with 35876 additions and 7482 deletions

View file

@ -1,3 +1,6 @@
from __future__ import unicode_literals
from .models import dynamodb_backend2
dynamodb_backends2 = {"global": dynamodb_backend2}
mock_dynamodb2 = dynamodb_backend2.decorator
mock_dynamodb2_deprecated = dynamodb_backend2.deprecated_decorator

View file

@ -1,12 +1,12 @@
from __future__ import unicode_literals
# TODO add tests for all of these
EQ_FUNCTION = lambda item_value, test_value: item_value == test_value
NE_FUNCTION = lambda item_value, test_value: item_value != test_value
LE_FUNCTION = lambda item_value, test_value: item_value <= test_value
LT_FUNCTION = lambda item_value, test_value: item_value < test_value
GE_FUNCTION = lambda item_value, test_value: item_value >= test_value
GT_FUNCTION = lambda item_value, test_value: item_value > test_value
EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa
NE_FUNCTION = lambda item_value, test_value: item_value != test_value # flake8: noqa
LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # flake8: noqa
LT_FUNCTION = lambda item_value, test_value: item_value < test_value # flake8: noqa
GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # flake8: noqa
GT_FUNCTION = lambda item_value, test_value: item_value > test_value # flake8: noqa
COMPARISON_FUNCS = {
'EQ': EQ_FUNCTION,

View file

@ -6,12 +6,13 @@ import json
import re
from moto.compat import OrderedDict
from moto.core import BaseBackend
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj):
if hasattr(obj, 'to_json'):
return obj.to_json()
@ -57,7 +58,10 @@ class DynamoType(object):
@property
def cast_value(self):
if self.type == 'N':
return int(self.value)
try:
return int(self.value)
except ValueError:
return float(self.value)
else:
return self.value
@ -73,7 +77,8 @@ class DynamoType(object):
return comparison_func(self.cast_value, *range_values)
class Item(object):
class Item(BaseModel):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key
self.hash_key_type = hash_key_type
@ -137,7 +142,7 @@ class Item(object):
def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items():
action = update_action['Action']
if action == 'DELETE' and not 'Value' in update_action:
if action == 'DELETE' and 'Value' not in update_action:
if attribute_name in self.attrs:
del self.attrs[attribute_name]
continue
@ -157,17 +162,19 @@ class Item(object):
self.attrs[attribute_name] = DynamoType({"S": new_value})
elif action == 'ADD':
if set(update_action['Value'].keys()) == set(['N']):
existing = self.attrs.get(attribute_name, DynamoType({"N": '0'}))
existing = self.attrs.get(
attribute_name, DynamoType({"N": '0'}))
self.attrs[attribute_name] = DynamoType({"N": str(
decimal.Decimal(existing.value) +
decimal.Decimal(new_value)
decimal.Decimal(existing.value) +
decimal.Decimal(new_value)
)})
else:
# TODO: implement other data types
raise NotImplementedError('ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
raise NotImplementedError(
'ADD not supported for %s' % ', '.join(update_action['Value'].keys()))
class Table(object):
class Table(BaseModel):
def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None, global_indexes=None):
self.name = table_name
@ -185,7 +192,8 @@ class Table(object):
self.range_key_attr = elem["AttributeName"]
self.range_key_type = elem["KeyType"]
if throughput is None:
self.throughput = {'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10}
self.throughput = {
'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10}
else:
self.throughput = throughput
self.throughput["NumberOfDecreasesToday"] = 0
@ -193,6 +201,11 @@ class Table(object):
self.global_indexes = global_indexes if global_indexes else []
self.created_at = datetime.datetime.utcnow()
self.items = defaultdict(dict)
self.table_arn = self._generate_arn(table_name)
self.tags = []
def _generate_arn(self, name):
return 'arn:aws:dynamodb:us-east-1:123456789011:table/' + name
def describe(self, base_key='TableDescription'):
results = {
@ -202,11 +215,12 @@ class Table(object):
'TableSizeBytes': 0,
'TableName': self.name,
'TableStatus': 'ACTIVE',
'TableArn': self.table_arn,
'KeySchema': self.schema,
'ItemCount': len(self),
'CreationDateTime': unix_time(self.created_at),
'GlobalSecondaryIndexes': [index for index in self.global_indexes],
'LocalSecondaryIndexes': [index for index in self.indexes]
'LocalSecondaryIndexes': [index for index in self.indexes],
}
}
return results
@ -249,14 +263,16 @@ class Table(object):
else:
range_value = None
item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs)
item = Item(hash_value, self.hash_key_type, range_value,
self.range_key_type, item_attrs)
if not overwrite:
if expected is None:
expected = {}
lookup_range_value = range_value
else:
expected_range_value = expected.get(self.range_key_attr, {}).get("Value")
expected_range_value = expected.get(
self.range_key_attr, {}).get("Value")
if(expected_range_value is None):
lookup_range_value = range_value
else:
@ -280,8 +296,10 @@ class Table(object):
elif 'Value' in val and DynamoType(val['Value']).value != current_attr[key].value:
raise ValueError("The conditional request failed")
elif 'ComparisonOperator' in val:
comparison_func = get_comparison_func(val['ComparisonOperator'])
dynamo_types = [DynamoType(ele) for ele in val["AttributeValueList"]]
comparison_func = get_comparison_func(
val['ComparisonOperator'])
dynamo_types = [DynamoType(ele) for ele in val[
"AttributeValueList"]]
for t in dynamo_types:
if not comparison_func(current_attr[key].value, t.value):
raise ValueError('The conditional request failed')
@ -303,7 +321,8 @@ class Table(object):
def get_item(self, hash_key, range_key=None):
if self.has_range_key and not range_key:
raise ValueError("Table has a range key, but no range key was passed into get_item")
raise ValueError(
"Table has a range key, but no range key was passed into get_item")
try:
if range_key:
return self.items[hash_key][range_key]
@ -337,9 +356,11 @@ class Table(object):
index = indexes_by_name[index_name]
try:
index_hash_key = [key for key in index['KeySchema'] if key['KeyType'] == 'HASH'][0]
index_hash_key = [key for key in index[
'KeySchema'] if key['KeyType'] == 'HASH'][0]
except IndexError:
raise ValueError('Missing Hash Key. KeySchema: %s' % index['KeySchema'])
raise ValueError('Missing Hash Key. KeySchema: %s' %
index['KeySchema'])
possible_results = []
for item in self.all_items():
@ -349,17 +370,20 @@ class Table(object):
if item_hash_key and item_hash_key == hash_key:
possible_results.append(item)
else:
possible_results = [item for item in list(self.all_items()) if isinstance(item, Item) and item.hash_key == hash_key]
possible_results = [item for item in list(self.all_items()) if isinstance(
item, Item) and item.hash_key == hash_key]
if index_name:
try:
index_range_key = [key for key in index['KeySchema'] if key['KeyType'] == 'RANGE'][0]
index_range_key = [key for key in index[
'KeySchema'] if key['KeyType'] == 'RANGE'][0]
except IndexError:
index_range_key = None
if range_comparison:
if index_name and not index_range_key:
raise ValueError('Range Key comparison but no range key found for index: %s' % index_name)
raise ValueError(
'Range Key comparison but no range key found for index: %s' % index_name)
elif index_name:
for result in possible_results:
@ -373,19 +397,21 @@ class Table(object):
if filter_kwargs:
for result in possible_results:
for field, value in filter_kwargs.items():
dynamo_types = [DynamoType(ele) for ele in value["AttributeValueList"]]
dynamo_types = [DynamoType(ele) for ele in value[
"AttributeValueList"]]
if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types):
results.append(result)
if not range_comparison and not filter_kwargs:
# If we're not filtering on range key or on an index return all values
# If we're not filtering on range key or on an index return all
# values
results = possible_results
if index_name:
if index_range_key:
results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value
if item.attrs.get(index_range_key['AttributeName']) else None)
if item.attrs.get(index_range_key['AttributeName']) else None)
else:
results.sort(key=lambda item: item.range_key)
@ -425,7 +451,8 @@ class Table(object):
# Comparison is NULL and we don't have the attribute
continue
else:
# No attribute found and comparison is no NULL. This item fails
# No attribute found and comparison is no NULL. This item
# fails
passes_all_conditions = False
break
@ -458,7 +485,6 @@ class Table(object):
return results, last_evaluated_key
def lookup(self, *args, **kwargs):
if not self.schema:
self.describe()
@ -485,6 +511,18 @@ class DynamoDBBackend(BaseBackend):
def delete_table(self, name):
return self.tables.pop(name, None)
def tag_resource(self, table_arn, tags):
for table in self.tables:
if self.tables[table].table_arn == table_arn:
self.tables[table].tags.extend(tags)
def list_tags_of_resource(self, table_arn):
required_table = None
for table in self.tables:
if self.tables[table].table_arn == table_arn:
required_table = self.tables[table]
return required_table.tags
def update_table_throughput(self, name, throughput):
table = self.tables[name]
table.throughput = throughput
@ -515,7 +553,8 @@ class DynamoDBBackend(BaseBackend):
if gsi_to_create:
if gsi_to_create['IndexName'] in gsis_by_name:
raise ValueError('Global Secondary Index already exists: %s' % gsi_to_create['IndexName'])
raise ValueError(
'Global Secondary Index already exists: %s' % gsi_to_create['IndexName'])
gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create
@ -553,9 +592,11 @@ class DynamoDBBackend(BaseBackend):
def get_keys_value(self, table, keys):
if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys):
raise ValueError("Table has a range key, but no range key was passed into get_item")
raise ValueError(
"Table has a range key, but no range key was passed into get_item")
hash_key = DynamoType(keys[table.hash_key_attr])
range_key = DynamoType(keys[table.range_key_attr]) if table.has_range_key else None
range_key = DynamoType(
keys[table.range_key_attr]) if table.has_range_key else None
return hash_key, range_key
def get_table(self, table_name):
@ -575,7 +616,8 @@ class DynamoDBBackend(BaseBackend):
return None, None
hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values, limit,
exclusive_start_key, scan_index_forward, index_name, **filter_kwargs)
@ -596,7 +638,8 @@ class DynamoDBBackend(BaseBackend):
table = self.get_table(table_name)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
# Covers cases where table has hash and range keys, ``key`` param will be a dict
# Covers cases where table has hash and range keys, ``key`` param
# will be a dict
hash_value = DynamoType(key[table.hash_key_attr])
range_value = DynamoType(key[table.range_key_attr])
elif table.hash_key_attr in key:
@ -627,13 +670,14 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
if update_expression:
item.update(update_expression, expression_attribute_names, expression_attribute_values)
item.update(update_expression, expression_attribute_names,
expression_attribute_values)
else:
item.update_with_attribute_updates(attribute_updates)
return item
def delete_item(self, table_name, keys):
table = self.tables.get(table_name)
table = self.get_table(table_name)
if not table:
return None
hash_key, range_key = self.get_keys_value(table, keys)

View file

@ -52,7 +52,7 @@ class DynamoHandler(BaseResponse):
return status, self.response_headers, dynamo_json_dump({'__type': type_})
def call_action(self):
body = self.body.decode('utf-8')
body = self.body
if 'GetSessionToken' in body:
return 200, self.response_headers, sts_handler()
@ -73,7 +73,7 @@ class DynamoHandler(BaseResponse):
def list_tables(self):
body = self.body
limit = body.get('Limit')
limit = body.get('Limit', 100)
if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName")
start = list(dynamodb_backend2.tables.keys()).index(last) + 1
@ -104,11 +104,11 @@ class DynamoHandler(BaseResponse):
local_secondary_indexes = body.get("LocalSecondaryIndexes", [])
table = dynamodb_backend2.create_table(table_name,
schema=key_schema,
throughput=throughput,
attr=attr,
global_indexes=global_indexes,
indexes=local_secondary_indexes)
schema=key_schema,
throughput=throughput,
attr=attr,
global_indexes=global_indexes,
indexes=local_secondary_indexes)
if table is not None:
return dynamo_json_dump(table.describe())
else:
@ -124,10 +124,40 @@ class DynamoHandler(BaseResponse):
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
def tag_resource(self):
tags = self.body['Tags']
table_arn = self.body['ResourceArn']
dynamodb_backend2.tag_resource(table_arn, tags)
return json.dumps({})
def list_tags_of_resource(self):
try:
table_arn = self.body['ResourceArn']
all_tags = dynamodb_backend2.list_tags_of_resource(table_arn)
all_tag_keys = [tag['Key'] for tag in all_tags]
marker = self.body.get('NextToken')
if marker:
start = all_tag_keys.index(marker) + 1
else:
start = 0
max_items = 10 # there is no default, but using 10 to make testing easier
tags_resp = all_tags[start:start + max_items]
next_marker = None
if len(all_tags) > start + max_items:
next_marker = tags_resp[-1]['Key']
if next_marker:
return json.dumps({'Tags': tags_resp,
'NextToken': next_marker})
return json.dumps({'Tags': tags_resp})
except AttributeError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
return self.error(er)
def update_table(self):
name = self.body['TableName']
if 'GlobalSecondaryIndexUpdates' in self.body:
table = dynamodb_backend2.update_table_global_indexes(name, self.body['GlobalSecondaryIndexUpdates'])
table = dynamodb_backend2.update_table_global_indexes(
name, self.body['GlobalSecondaryIndexUpdates'])
if 'ProvisionedThroughput' in self.body:
throughput = self.body["ProvisionedThroughput"]
table = dynamodb_backend2.update_table_throughput(name, throughput)
@ -151,17 +181,20 @@ class DynamoHandler(BaseResponse):
else:
expected = None
# Attempt to parse simple ConditionExpressions into an Expected expression
# Attempt to parse simple ConditionExpressions into an Expected
# expression
if not expected:
condition_expression = self.body.get('ConditionExpression')
if condition_expression and 'OR' not in condition_expression:
cond_items = [c.strip() for c in condition_expression.split('AND')]
cond_items = [c.strip()
for c in condition_expression.split('AND')]
if cond_items:
expected = {}
overwrite = False
exists_re = re.compile('^attribute_exists\((.*)\)$')
not_exists_re = re.compile('^attribute_not_exists\((.*)\)$')
not_exists_re = re.compile(
'^attribute_not_exists\((.*)\)$')
for cond in cond_items:
exists_m = exists_re.match(cond)
@ -172,7 +205,8 @@ class DynamoHandler(BaseResponse):
expected[not_exists_m.group(1)] = {'Exists': False}
try:
result = dynamodb_backend2.put_item(name, item, expected, overwrite)
result = dynamodb_backend2.put_item(
name, item, expected, overwrite)
except Exception:
er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException'
return self.error(er)
@ -249,7 +283,8 @@ class DynamoHandler(BaseResponse):
item = dynamodb_backend2.get_item(table_name, key)
if item:
item_describe = item.describe_attrs(attributes_to_get)
results["Responses"][table_name].append(item_describe["Item"])
results["Responses"][table_name].append(
item_describe["Item"])
results["ConsumedCapacity"].append({
"CapacityUnits": len(keys),
@ -268,8 +303,10 @@ class DynamoHandler(BaseResponse):
table = dynamodb_backend2.get_table(name)
index_name = self.body.get('IndexName')
if index_name:
all_indexes = (table.global_indexes or []) + (table.indexes or [])
indexes_by_name = dict((i['IndexName'], i) for i in all_indexes)
all_indexes = (table.global_indexes or []) + \
(table.indexes or [])
indexes_by_name = dict((i['IndexName'], i)
for i in all_indexes)
if index_name not in indexes_by_name:
raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % (
index_name, name, ', '.join(indexes_by_name.keys())
@ -279,7 +316,8 @@ class DynamoHandler(BaseResponse):
else:
index = table.schema
reverse_attribute_lookup = dict((v, k) for k, v in six.iteritems(self.body['ExpressionAttributeNames']))
reverse_attribute_lookup = dict((v, k) for k, v in
six.iteritems(self.body['ExpressionAttributeNames']))
if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1)
@ -308,7 +346,8 @@ class DynamoHandler(BaseResponse):
value_alias_map[range_key_expression_components[1]],
]
else:
range_values = [value_alias_map[range_key_expression_components[2]]]
range_values = [value_alias_map[
range_key_expression_components[2]]]
else:
hash_key_expression = key_condition_expression
range_comparison = None
@ -319,15 +358,18 @@ class DynamoHandler(BaseResponse):
else:
# 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}}
key_conditions = self.body.get('KeyConditions')
query_filters = self.body.get("QueryFilter")
if key_conditions:
hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(name, key_conditions.keys())
hash_key_name, range_key_name = dynamodb_backend2.get_table_keys_name(
name, key_conditions.keys())
for key, value in key_conditions.items():
if key not in (hash_key_name, range_key_name):
filter_kwargs[key] = value
if hash_key_name is None:
er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
return self.error(er)
hash_key = key_conditions[hash_key_name]['AttributeValueList'][0]
hash_key = key_conditions[hash_key_name][
'AttributeValueList'][0]
if len(key_conditions) == 1:
range_comparison = None
range_values = []
@ -338,11 +380,15 @@ class DynamoHandler(BaseResponse):
else:
range_condition = key_conditions.get(range_key_name)
if range_condition:
range_comparison = range_condition['ComparisonOperator']
range_values = range_condition['AttributeValueList']
range_comparison = range_condition[
'ComparisonOperator']
range_values = range_condition[
'AttributeValueList']
else:
range_comparison = None
range_values = []
if query_filters:
filter_kwargs.update(query_filters)
index_name = self.body.get('IndexName')
exclusive_start_key = self.body.get('ExclusiveStartKey')
limit = self.body.get("Limit")
@ -373,7 +419,8 @@ class DynamoHandler(BaseResponse):
filters = {}
scan_filters = self.body.get('ScanFilter', {})
for attribute_name, scan_filter in scan_filters.items():
# Keys are attribute names. Values are tuples of (comparison, comparison_value)
# Keys are attribute names. Values are tuples of (comparison,
# comparison_value)
comparison_operator = scan_filter["ComparisonOperator"]
comparison_values = scan_filter.get("AttributeValueList", [])
filters[attribute_name] = (comparison_operator, comparison_values)
@ -403,33 +450,38 @@ class DynamoHandler(BaseResponse):
name = self.body['TableName']
keys = self.body['Key']
return_values = self.body.get('ReturnValues', '')
item = dynamodb_backend2.delete_item(name, keys)
if item:
if return_values == 'ALL_OLD':
item_dict = item.to_json()
else:
item_dict = {'Attributes': {}}
item_dict['ConsumedCapacityUnits'] = 0.5
return dynamo_json_dump(item_dict)
else:
table = dynamodb_backend2.get_table(name)
if not table:
er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException'
return self.error(er)
item = dynamodb_backend2.delete_item(name, keys)
if item and return_values == 'ALL_OLD':
item_dict = item.to_json()
else:
item_dict = {'Attributes': {}}
item_dict['ConsumedCapacityUnits'] = 0.5
return dynamo_json_dump(item_dict)
def update_item(self):
name = self.body['TableName']
key = self.body['Key']
update_expression = self.body.get('UpdateExpression')
attribute_updates = self.body.get('AttributeUpdates')
expression_attribute_names = self.body.get('ExpressionAttributeNames', {})
expression_attribute_values = self.body.get('ExpressionAttributeValues', {})
expression_attribute_names = self.body.get(
'ExpressionAttributeNames', {})
expression_attribute_values = self.body.get(
'ExpressionAttributeValues', {})
existing_item = dynamodb_backend2.get_item(name, key)
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
update_expression = re.sub('\s*([=\+-])\s*', '\\1', update_expression)
update_expression = re.sub(
'\s*([=\+-])\s*', '\\1', update_expression)
item = dynamodb_backend2.update_item(name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values)
item = dynamodb_backend2.update_item(
name, key, update_expression, attribute_updates, expression_attribute_names, expression_attribute_values)
item_dict = item.to_json()
item_dict['ConsumedCapacityUnits'] = 0.5