fix keys to use types
This commit is contained in:
parent
11c1a2a4c1
commit
930e4c9762
5 changed files with 105 additions and 68 deletions
|
|
@ -1,24 +1,64 @@
|
|||
import datetime
|
||||
|
||||
from collections import defaultdict, OrderedDict
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from moto.core import BaseBackend
|
||||
from .comparisons import get_comparison_func
|
||||
from .utils import unix_time
|
||||
|
||||
|
||||
class DynamoJsonEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if hasattr(obj, 'to_json'):
|
||||
return obj.to_json()
|
||||
|
||||
|
||||
def dynamo_json_dump(dynamo_object):
|
||||
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
|
||||
|
||||
|
||||
class DynamoType(object):
|
||||
def __init__(self, type_as_dict):
|
||||
self.type = type_as_dict.keys()[0]
|
||||
self.value = type_as_dict.values()[0]
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.value))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.type == other.type and
|
||||
self.value == other.value
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "DynamoType: {}".format(self.to_json())
|
||||
|
||||
def to_json(self):
|
||||
return {self.type: self.value}
|
||||
|
||||
|
||||
class Item(object):
|
||||
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
|
||||
self.hash_key = hash_key
|
||||
self.hash_key_type = hash_key_type
|
||||
self.range_key = range_key
|
||||
self.range_key_type = range_key_type
|
||||
self.attrs = attrs
|
||||
|
||||
@property
|
||||
def describe(self):
|
||||
self.attrs = {}
|
||||
for key, value in attrs.iteritems():
|
||||
self.attrs[key] = DynamoType(value)
|
||||
|
||||
def __repr__(self):
|
||||
return "Item: {}".format(self.to_json())
|
||||
|
||||
def to_json(self):
|
||||
attributes = {}
|
||||
for attribute_key, attribute in self.attrs.iteritems():
|
||||
attributes[attribute_key] = attribute.value
|
||||
|
||||
return {
|
||||
"Attributes": self.attrs
|
||||
"Attributes": attributes
|
||||
}
|
||||
|
||||
def describe_attrs(self, attributes):
|
||||
|
|
@ -90,11 +130,12 @@ class Table(object):
|
|||
return True
|
||||
|
||||
def put_item(self, item_attrs):
|
||||
hash_value = item_attrs.get(self.hash_key_attr).values()[0]
|
||||
hash_value = DynamoType(item_attrs.get(self.hash_key_attr))
|
||||
if self.range_key_attr:
|
||||
range_value = item_attrs.get(self.range_key_attr).values()[0]
|
||||
range_value = DynamoType(item_attrs.get(self.range_key_attr))
|
||||
else:
|
||||
range_value = None
|
||||
|
||||
item = Item(hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs)
|
||||
|
||||
if range_value:
|
||||
|
|
@ -112,7 +153,7 @@ class Table(object):
|
|||
except KeyError:
|
||||
return None
|
||||
|
||||
def query(self, hash_key, range_comparison, range_values):
|
||||
def query(self, hash_key, range_comparison, range_objs):
|
||||
results = []
|
||||
last_page = True # Once pagination is implemented, change this
|
||||
|
||||
|
|
@ -120,7 +161,8 @@ class Table(object):
|
|||
if range_comparison:
|
||||
comparison_func = get_comparison_func(range_comparison)
|
||||
for result in possible_results:
|
||||
if comparison_func(result.range_key, *range_values):
|
||||
range_values = [obj.value for obj in range_objs]
|
||||
if comparison_func(result.range_key.value, *range_values):
|
||||
results.append(result)
|
||||
else:
|
||||
# If we're not filtering on range key, return all values
|
||||
|
|
@ -143,14 +185,14 @@ class Table(object):
|
|||
for result in self.all_items():
|
||||
scanned_count += 1
|
||||
passes_all_conditions = True
|
||||
for attribute_name, (comparison_operator, comparison_values) in filters.iteritems():
|
||||
for attribute_name, (comparison_operator, comparison_objs) in filters.iteritems():
|
||||
comparison_func = get_comparison_func(comparison_operator)
|
||||
|
||||
attribute = result.attrs.get(attribute_name)
|
||||
if attribute:
|
||||
# Attribute found
|
||||
attribute_value = attribute.values()[0]
|
||||
if not comparison_func(attribute_value, *comparison_values):
|
||||
comparison_values = [obj.value for obj in comparison_objs]
|
||||
if not comparison_func(attribute.value, *comparison_values):
|
||||
passes_all_conditions = False
|
||||
break
|
||||
elif comparison_operator == 'NULL':
|
||||
|
|
@ -202,18 +244,24 @@ class DynamoDBBackend(BaseBackend):
|
|||
|
||||
return table.put_item(item_attrs)
|
||||
|
||||
def get_item(self, table_name, hash_key, range_key):
|
||||
def get_item(self, table_name, hash_key_dict, range_key_dict):
|
||||
table = self.tables.get(table_name)
|
||||
if not table:
|
||||
return None
|
||||
|
||||
hash_key = DynamoType(hash_key_dict)
|
||||
range_key = DynamoType(range_key_dict) if range_key_dict else None
|
||||
|
||||
return table.get_item(hash_key, range_key)
|
||||
|
||||
def query(self, table_name, hash_key, range_comparison, range_values):
|
||||
def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts):
|
||||
table = self.tables.get(table_name)
|
||||
if not table:
|
||||
return None, None
|
||||
|
||||
hash_key = DynamoType(hash_key_dict)
|
||||
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
|
||||
|
||||
return table.query(hash_key, range_comparison, range_values)
|
||||
|
||||
def scan(self, table_name, filters):
|
||||
|
|
@ -221,13 +269,21 @@ class DynamoDBBackend(BaseBackend):
|
|||
if not table:
|
||||
return None, None, None
|
||||
|
||||
return table.scan(filters)
|
||||
scan_filters = {}
|
||||
for key, (comparison_operator, comparison_values) in filters.iteritems():
|
||||
dynamo_types = [DynamoType(value) for value in comparison_values]
|
||||
scan_filters[key] = (comparison_operator, dynamo_types)
|
||||
|
||||
def delete_item(self, table_name, hash_key, range_key):
|
||||
return table.scan(scan_filters)
|
||||
|
||||
def delete_item(self, table_name, hash_key_dict, range_key_dict):
|
||||
table = self.tables.get(table_name)
|
||||
if not table:
|
||||
return None
|
||||
|
||||
hash_key = DynamoType(hash_key_dict)
|
||||
range_key = DynamoType(range_key_dict) if range_key_dict else None
|
||||
|
||||
return table.delete_item(hash_key, range_key)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue