diff --git a/.coveragerc b/.coveragerc index c93685a3..25d85b80 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,7 +3,7 @@ exclude_lines = if __name__ == .__main__.: raise NotImplemented. - + def __repr__ [run] include = moto/* diff --git a/moto/dynamodb/comparisons.py b/moto/dynamodb/comparisons.py index 1a0db237..58fa43c4 100644 --- a/moto/dynamodb/comparisons.py +++ b/moto/dynamodb/comparisons.py @@ -6,8 +6,8 @@ COMPARISON_FUNCS = { 'LT': lambda item_value, test_value: item_value < test_value, 'GE': lambda item_value, test_value: item_value >= test_value, 'GT': lambda item_value, test_value: item_value > test_value, - 'NULL': lambda item_value, test_value: item_value is None, - 'NOT_NULL': lambda item_value, test_value: item_value is not None, + 'NULL': lambda item_value: item_value is None, + 'NOT_NULL': lambda item_value: item_value is not None, 'CONTAINS': lambda item_value, test_value: test_value in item_value, 'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, 'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), diff --git a/moto/dynamodb/models.py b/moto/dynamodb/models.py index 63c01545..1f9dc4b2 100644 --- a/moto/dynamodb/models.py +++ b/moto/dynamodb/models.py @@ -1,24 +1,64 @@ -import datetime - from collections import defaultdict, OrderedDict +import datetime +import json from moto.core import BaseBackend from .comparisons import get_comparison_func from .utils import unix_time +class DynamoJsonEncoder(json.JSONEncoder): + def default(self, obj): + if hasattr(obj, 'to_json'): + return obj.to_json() + + +def dynamo_json_dump(dynamo_object): + return json.dumps(dynamo_object, cls=DynamoJsonEncoder) + + +class DynamoType(object): + def __init__(self, type_as_dict): + self.type = type_as_dict.keys()[0] + self.value = type_as_dict.values()[0] + + def __hash__(self): + return hash((self.type, self.value)) + + def __eq__(self, other): + return ( + self.type == other.type and + self.value == other.value + ) + + def __repr__(self): + return "DynamoType: {}".format(self.to_json()) + + def to_json(self): + return {self.type: self.value} + + class Item(object): def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): self.hash_key = hash_key self.hash_key_type = hash_key_type self.range_key = range_key self.range_key_type = range_key_type - self.attrs = attrs - @property - def describe(self): + self.attrs = {} + for key, value in attrs.iteritems(): + self.attrs[key] = DynamoType(value) + + def __repr__(self): + return "Item: {}".format(self.to_json()) + + def to_json(self): + attributes = {} + for attribute_key, attribute in self.attrs.iteritems(): + attributes[attribute_key] = attribute.value + return { - "Attributes": self.attrs + "Attributes": attributes } def describe_attrs(self, attributes): @@ -90,11 +130,12 @@ class Table(object): return True def put_item(self, item_attrs): - hash_value = item_attrs.get(self.hash_key_attr).values()[0] + hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.range_key_attr: - range_value = item_attrs.get(self.range_key_attr).values()[0] + range_value = DynamoType(item_attrs.get(self.range_key_attr)) else: range_value = None + item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs) if range_value: @@ -112,7 +153,7 @@ class Table(object): except KeyError: return None - def query(self, hash_key, range_comparison, range_values): + def query(self, hash_key, range_comparison, range_objs): results = [] last_page = True # Once pagination is implemented, change this @@ -120,7 +161,8 @@ class Table(object): if range_comparison: comparison_func = get_comparison_func(range_comparison) for result in possible_results: - if comparison_func(result.range_key, *range_values): + range_values = [obj.value for obj in range_objs] + if comparison_func(result.range_key.value, *range_values): results.append(result) else: # If we're not filtering on range key, return all values @@ -143,14 +185,14 @@ class Table(object): for result in self.all_items(): scanned_count += 1 passes_all_conditions = True - for attribute_name, (comparison_operator, comparison_values) in filters.iteritems(): + for attribute_name, (comparison_operator, comparison_objs) in filters.iteritems(): comparison_func = get_comparison_func(comparison_operator) attribute = result.attrs.get(attribute_name) if attribute: # Attribute found - attribute_value = attribute.values()[0] - if not comparison_func(attribute_value, *comparison_values): + comparison_values = [obj.value for obj in comparison_objs] + if not comparison_func(attribute.value, *comparison_values): passes_all_conditions = False break elif comparison_operator == 'NULL': @@ -202,18 +244,24 @@ class DynamoDBBackend(BaseBackend): return table.put_item(item_attrs) - def get_item(self, table_name, hash_key, range_key): + def get_item(self, table_name, hash_key_dict, range_key_dict): table = self.tables.get(table_name) if not table: return None + hash_key = DynamoType(hash_key_dict) + range_key = DynamoType(range_key_dict) if range_key_dict else None + return table.get_item(hash_key, range_key) - def query(self, table_name, hash_key, range_comparison, range_values): + def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts): table = self.tables.get(table_name) if not table: return None, None + hash_key = DynamoType(hash_key_dict) + range_values = [DynamoType(range_value) for range_value in range_value_dicts] + return table.query(hash_key, range_comparison, range_values) def scan(self, table_name, filters): @@ -221,13 +269,21 @@ class DynamoDBBackend(BaseBackend): if not table: return None, None, None - return table.scan(filters) + scan_filters = {} + for key, (comparison_operator, comparison_values) in filters.iteritems(): + dynamo_types = [DynamoType(value) for value in comparison_values] + scan_filters[key] = (comparison_operator, dynamo_types) - def delete_item(self, table_name, hash_key, range_key): + return table.scan(scan_filters) + + def delete_item(self, table_name, hash_key_dict, range_key_dict): table = self.tables.get(table_name) if not table: return None + hash_key = DynamoType(hash_key_dict) + range_key = DynamoType(range_key_dict) if range_key_dict else None + return table.delete_item(hash_key, range_key) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index d5129082..1b9f1683 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -1,8 +1,7 @@ import json from moto.core.utils import headers_to_dict -from .models import dynamodb_backend -from .utils import value_from_dynamo_type, values_from_dynamo_types +from .models import dynamodb_backend, dynamo_json_dump class DynamoHandler(object): @@ -23,7 +22,7 @@ class DynamoHandler(object): return match.split(".")[1] def error(self, type_, status=400): - return json.dumps({'__type': type_}), dict(status=400) + return dynamo_json_dump({'__type': type_}), dict(status=400) def dispatch(self): method = self.get_method_name(self.headers) @@ -47,7 +46,7 @@ class DynamoHandler(object): response = {"TableNames": tables} if limit and len(all_tables) > start + limit: response["LastEvaluatedTableName"] = tables[-1] - return json.dumps(response) + return dynamo_json_dump(response) def CreateTable(self, uri, body, headers): name = body['TableName'] @@ -74,13 +73,13 @@ class DynamoHandler(object): read_capacity=int(read_units), write_capacity=int(write_units), ) - return json.dumps(table.describe) + return dynamo_json_dump(table.describe) def DeleteTable(self, uri, body, headers): name = body['TableName'] table = dynamodb_backend.delete_table(name) if table: - return json.dumps(table.describe) + return dynamo_json_dump(table.describe) else: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) @@ -91,7 +90,7 @@ class DynamoHandler(object): new_read_units = throughput["ReadCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"] table = dynamodb_backend.update_table_throughput(name, new_read_units, new_write_units) - return json.dumps(table.describe) + return dynamo_json_dump(table.describe) def DescribeTable(self, uri, body, headers): name = body['TableName'] @@ -100,16 +99,16 @@ class DynamoHandler(object): except KeyError: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) - return json.dumps(table.describe) + return dynamo_json_dump(table.describe) def PutItem(self, uri, body, headers): name = body['TableName'] item = body['Item'] result = dynamodb_backend.put_item(name, item) if result: - item_dict = result.describe + item_dict = result.to_json() item_dict['ConsumedCapacityUnits'] = 1 - return json.dumps(item_dict) + return dynamo_json_dump(item_dict) else: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) @@ -127,8 +126,8 @@ class DynamoHandler(object): dynamodb_backend.put_item(table_name, item) elif request_type == 'DeleteRequest': key = request['Key'] - hash_key = value_from_dynamo_type(key['HashKeyElement']) - range_key = value_from_dynamo_type(key.get('RangeKeyElement')) + hash_key = key['HashKeyElement'] + range_key = key.get('RangeKeyElement') item = dynamodb_backend.delete_item(table_name, hash_key, range_key) response = { @@ -143,19 +142,19 @@ class DynamoHandler(object): "UnprocessedItems": {} } - return json.dumps(response) + return dynamo_json_dump(response) def GetItem(self, uri, body, headers): name = body['TableName'] key = body['Key'] - hash_key = key['HashKeyElement'].values()[0] - range_key = value_from_dynamo_type(key.get('RangeKeyElement')) + hash_key = key['HashKeyElement'] + range_key = key.get('RangeKeyElement') attrs_to_get = body.get('AttributesToGet') item = dynamodb_backend.get_item(name, hash_key, range_key) if item: item_dict = item.describe_attrs(attrs_to_get) item_dict['ConsumedCapacityUnits'] = 0.5 - return json.dumps(item_dict) + return dynamo_json_dump(item_dict) else: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) @@ -174,24 +173,25 @@ class DynamoHandler(object): keys = table_request['Keys'] attributes_to_get = table_request.get('AttributesToGet') for key in keys: - hash_key = value_from_dynamo_type(key["HashKeyElement"]) - range_key = value_from_dynamo_type(key.get("RangeKeyElement")) + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") item = dynamodb_backend.get_item(table_name, hash_key, range_key) if item: item_describe = item.describe_attrs(attributes_to_get) items.append(item_describe) results["Responses"][table_name] = {"Items": items, "ConsumedCapacityUnits": 1} - return json.dumps(results) + return dynamo_json_dump(results) def Query(self, uri, body, headers): name = body['TableName'] - hash_key = body['HashKeyValue'].values()[0] + hash_key = body['HashKeyValue'] range_condition = body.get('RangeKeyCondition') if range_condition: range_comparison = range_condition['ComparisonOperator'] - range_values = values_from_dynamo_types(range_condition['AttributeValueList']) + range_values = range_condition['AttributeValueList'] else: - range_comparison = range_values = None + range_comparison = None + range_values = [] items, last_page = dynamodb_backend.query(name, hash_key, range_comparison, range_values) @@ -211,7 +211,7 @@ class DynamoHandler(object): # "HashKeyElement": items[-1].hash_key, # "RangeKeyElement": items[-1].range_key, # } - return json.dumps(result) + return dynamo_json_dump(result) def Scan(self, uri, body, headers): name = body['TableName'] @@ -221,10 +221,7 @@ class DynamoHandler(object): for attribute_name, scan_filter in scan_filters.iteritems(): # Keys are attribute names. Values are tuples of (comparison, comparison_value) comparison_operator = scan_filter["ComparisonOperator"] - if scan_filter.get("AttributeValueList"): - comparison_values = values_from_dynamo_types(scan_filter.get("AttributeValueList")) - else: - comparison_values = [None] + comparison_values = scan_filter.get("AttributeValueList", []) filters[attribute_name] = (comparison_operator, comparison_values) items, scanned_count, last_page = dynamodb_backend.scan(name, filters) @@ -246,22 +243,22 @@ class DynamoHandler(object): # "HashKeyElement": items[-1].hash_key, # "RangeKeyElement": items[-1].range_key, # } - return json.dumps(result) + return dynamo_json_dump(result) def DeleteItem(self, uri, body, headers): name = body['TableName'] key = body['Key'] - hash_key = value_from_dynamo_type(key['HashKeyElement']) - range_key = value_from_dynamo_type(key.get('RangeKeyElement')) + hash_key = key['HashKeyElement'] + range_key = key.get('RangeKeyElement') return_values = body.get('ReturnValues', '') item = dynamodb_backend.delete_item(name, hash_key, range_key) if item: if return_values == 'ALL_OLD': - item_dict = item.describe + item_dict = item.to_json() else: item_dict = {'Attributes': []} item_dict['ConsumedCapacityUnits'] = 0.5 - return json.dumps(item_dict) + return dynamo_json_dump(item_dict) else: er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' return self.error(er) diff --git a/moto/dynamodb/utils.py b/moto/dynamodb/utils.py index 8b2adc81..e4787d10 100644 --- a/moto/dynamodb/utils.py +++ b/moto/dynamodb/utils.py @@ -5,19 +5,3 @@ def unix_time(dt): epoch = datetime.datetime.utcfromtimestamp(0) delta = dt - epoch return delta.total_seconds() - - -def value_from_dynamo_type(dynamo_type): - """ - Dynamo return attributes like {"S": "AttributeValue1"}. - This function takes that value and returns "AttributeValue1". - - # TODO eventually this should be smarted to actually read the type of - the attribute - """ - if dynamo_type: - return dynamo_type.values()[0] - - -def values_from_dynamo_types(dynamo_types): - return [value_from_dynamo_type(dynamo_type) for dynamo_type in dynamo_types]