fix keys to use types

This commit is contained in:
Steve Pulec 2013-03-15 00:45:12 -04:00
commit 930e4c9762
5 changed files with 105 additions and 68 deletions

View file

@ -1,24 +1,64 @@
import datetime
from collections import defaultdict, OrderedDict
import datetime
import json
from moto.core import BaseBackend
from .comparisons import get_comparison_func
from .utils import unix_time
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj):
if hasattr(obj, 'to_json'):
return obj.to_json()
def dynamo_json_dump(dynamo_object):
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
class DynamoType(object):
def __init__(self, type_as_dict):
self.type = type_as_dict.keys()[0]
self.value = type_as_dict.values()[0]
def __hash__(self):
return hash((self.type, self.value))
def __eq__(self, other):
return (
self.type == other.type and
self.value == other.value
)
def __repr__(self):
return "DynamoType: {}".format(self.to_json())
def to_json(self):
return {self.type: self.value}
class Item(object):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key
self.hash_key_type = hash_key_type
self.range_key = range_key
self.range_key_type = range_key_type
self.attrs = attrs
@property
def describe(self):
self.attrs = {}
for key, value in attrs.iteritems():
self.attrs[key] = DynamoType(value)
def __repr__(self):
return "Item: {}".format(self.to_json())
def to_json(self):
attributes = {}
for attribute_key, attribute in self.attrs.iteritems():
attributes[attribute_key] = attribute.value
return {
"Attributes": self.attrs
"Attributes": attributes
}
def describe_attrs(self, attributes):
@ -90,11 +130,12 @@ class Table(object):
return True
def put_item(self, item_attrs):
hash_value = item_attrs.get(self.hash_key_attr).values()[0]
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
if self.range_key_attr:
range_value = item_attrs.get(self.range_key_attr).values()[0]
range_value = DynamoType(item_attrs.get(self.range_key_attr))
else:
range_value = None
item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs)
if range_value:
@ -112,7 +153,7 @@ class Table(object):
except KeyError:
return None
def query(self, hash_key, range_comparison, range_values):
def query(self, hash_key, range_comparison, range_objs):
results = []
last_page = True # Once pagination is implemented, change this
@ -120,7 +161,8 @@ class Table(object):
if range_comparison:
comparison_func = get_comparison_func(range_comparison)
for result in possible_results:
if comparison_func(result.range_key, *range_values):
range_values = [obj.value for obj in range_objs]
if comparison_func(result.range_key.value, *range_values):
results.append(result)
else:
# If we're not filtering on range key, return all values
@ -143,14 +185,14 @@ class Table(object):
for result in self.all_items():
scanned_count += 1
passes_all_conditions = True
for attribute_name, (comparison_operator, comparison_values) in filters.iteritems():
for attribute_name, (comparison_operator, comparison_objs) in filters.iteritems():
comparison_func = get_comparison_func(comparison_operator)
attribute = result.attrs.get(attribute_name)
if attribute:
# Attribute found
attribute_value = attribute.values()[0]
if not comparison_func(attribute_value, *comparison_values):
comparison_values = [obj.value for obj in comparison_objs]
if not comparison_func(attribute.value, *comparison_values):
passes_all_conditions = False
break
elif comparison_operator == 'NULL':
@ -202,18 +244,24 @@ class DynamoDBBackend(BaseBackend):
return table.put_item(item_attrs)
def get_item(self, table_name, hash_key, range_key):
def get_item(self, table_name, hash_key_dict, range_key_dict):
table = self.tables.get(table_name)
if not table:
return None
hash_key = DynamoType(hash_key_dict)
range_key = DynamoType(range_key_dict) if range_key_dict else None
return table.get_item(hash_key, range_key)
def query(self, table_name, hash_key, range_comparison, range_values):
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
table = self.tables.get(table_name)
if not table:
return None, None
hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values)
def scan(self, table_name, filters):
@ -221,13 +269,21 @@ class DynamoDBBackend(BaseBackend):
if not table:
return None, None, None
return table.scan(filters)
scan_filters = {}
for key, (comparison_operator, comparison_values) in filters.iteritems():
dynamo_types = [DynamoType(value) for value in comparison_values]
scan_filters[key] = (comparison_operator, dynamo_types)
def delete_item(self, table_name, hash_key, range_key):
return table.scan(scan_filters)
def delete_item(self, table_name, hash_key_dict, range_key_dict):
table = self.tables.get(table_name)
if not table:
return None
hash_key = DynamoType(hash_key_dict)
range_key = DynamoType(range_key_dict) if range_key_dict else None
return table.delete_item(hash_key, range_key)