Refactor DynamoDB update expressions (#2497)
* Refactor DynamoDB.update to use recursive method for nested updates * Simplify DynamoDB.update_item logic
This commit is contained in:
parent
9e4860ccd8
commit
64cf1fc2c9
2 changed files with 147 additions and 148 deletions
|
|
@ -34,14 +34,76 @@ def bytesize(val):
|
|||
return len(str(val).encode('utf-8'))
|
||||
|
||||
|
||||
def attribute_is_list(attr):
|
||||
"""
|
||||
Checks if attribute denotes a list, and returns the regular expression if so
|
||||
:param attr: attr or attr[index]
|
||||
:return: attr, re or None
|
||||
"""
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', attr)
|
||||
if list_index_update:
|
||||
attr = list_index_update.group(1)
|
||||
return attr, list_index_update.group(2) if list_index_update else None
|
||||
|
||||
|
||||
class DynamoType(object):
|
||||
"""
|
||||
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
|
||||
"""
|
||||
|
||||
def __init__(self, type_as_dict):
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if type(type_as_dict) == DynamoType:
|
||||
self.type = type_as_dict.type
|
||||
self.value = type_as_dict.value
|
||||
else:
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if self.is_list():
|
||||
self.value = [DynamoType(val) for val in self.value]
|
||||
elif self.is_map():
|
||||
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
|
||||
|
||||
def set(self, key, new_value, index=None):
|
||||
if index:
|
||||
index = int(index)
|
||||
if type(self.value) is not list:
|
||||
raise InvalidUpdateExpression
|
||||
if index >= len(self.value):
|
||||
self.value.append(new_value)
|
||||
# {'L': [DynamoType, ..]} ==> DynamoType.set()
|
||||
self.value[min(index, len(self.value) - 1)].set(key, new_value)
|
||||
else:
|
||||
attr = (key or '').split('.').pop(0)
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
if not key:
|
||||
# {'S': value} ==> {'S': new_value}
|
||||
self.value = new_value.value
|
||||
else:
|
||||
if attr not in self.value: # nonexistingattribute
|
||||
type_of_new_attr = 'M' if '.' in key else new_value.type
|
||||
self.value[attr] = DynamoType({type_of_new_attr: {}})
|
||||
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
|
||||
self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index)
|
||||
|
||||
def delete(self, key, index=None):
|
||||
if index:
|
||||
if not key:
|
||||
if int(index) < len(self.value):
|
||||
del self.value[int(index)]
|
||||
elif '.' in key:
|
||||
self.value[int(index)].delete('.'.join(key.split('.')[1:]))
|
||||
else:
|
||||
self.value[int(index)].delete(key)
|
||||
else:
|
||||
attr = key.split('.')[0]
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
|
||||
if list_index:
|
||||
self.value[attr].delete('.'.join(key.split('.')[1:]), list_index)
|
||||
elif '.' in key:
|
||||
self.value[attr].delete('.'.join(key.split('.')[1:]))
|
||||
else:
|
||||
self.value.pop(key)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.value))
|
||||
|
|
@ -98,7 +160,7 @@ class DynamoType(object):
|
|||
|
||||
if isinstance(key, int) and self.is_list():
|
||||
idx = key
|
||||
if idx >= 0 and idx < len(self.value):
|
||||
if 0 <= idx < len(self.value):
|
||||
return DynamoType(self.value[idx])
|
||||
|
||||
return None
|
||||
|
|
@ -110,7 +172,7 @@ class DynamoType(object):
|
|||
sub_type = self.type[0]
|
||||
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
|
||||
elif self.is_list():
|
||||
value_size = sum([DynamoType(v).size() for v in self.value])
|
||||
value_size = sum([v.size() for v in self.value])
|
||||
elif self.is_map():
|
||||
value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()])
|
||||
elif type(self.value) == bool:
|
||||
|
|
@ -162,22 +224,6 @@ class LimitedSizeDict(dict):
|
|||
raise ItemSizeTooLarge
|
||||
super(LimitedSizeDict, self).__setitem__(key, value)
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
if args:
|
||||
if len(args) > 1:
|
||||
raise TypeError("update expected at most 1 arguments, "
|
||||
"got %d" % len(args))
|
||||
other = dict(args[0])
|
||||
for key in other:
|
||||
self[key] = other[key]
|
||||
for key in kwargs:
|
||||
self[key] = kwargs[key]
|
||||
|
||||
def setdefault(self, key, value=None):
|
||||
if key not in self:
|
||||
self[key] = value
|
||||
return self[key]
|
||||
|
||||
|
||||
class Item(BaseModel):
|
||||
|
||||
|
|
@ -236,72 +282,26 @@ class Item(BaseModel):
|
|||
|
||||
if action == "REMOVE":
|
||||
key = value
|
||||
attr, list_index = attribute_is_list(key.split('.')[0])
|
||||
if '.' not in key:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
|
||||
if list_index_update:
|
||||
# We need to remove an item from a list (REMOVE listattr[0])
|
||||
key_attr = self.attrs[list_index_update.group(1)]
|
||||
list_index = int(list_index_update.group(2))
|
||||
if key_attr.is_list():
|
||||
if len(key_attr.value) > list_index:
|
||||
del key_attr.value[list_index]
|
||||
if list_index:
|
||||
new_list = DynamoType(self.attrs[attr])
|
||||
new_list.delete(None, list_index)
|
||||
self.attrs[attr] = new_list
|
||||
else:
|
||||
self.attrs.pop(value, None)
|
||||
else:
|
||||
# Handle nested dict updates
|
||||
key_parts = key.split('.')
|
||||
attr = key_parts.pop(0)
|
||||
if attr not in self.attrs:
|
||||
raise ValueError
|
||||
|
||||
last_val = self.attrs[attr].value
|
||||
for key_part in key_parts[:-1]:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
|
||||
if list_index_update:
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
# Hack but it'll do, traverses into a dict
|
||||
last_val_type = list(last_val.keys())
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val = last_val['M']
|
||||
|
||||
if key_part not in last_val:
|
||||
last_val[key_part] = {'M': {}}
|
||||
|
||||
last_val = last_val[key_part]
|
||||
if list_index_update:
|
||||
last_val = last_val['L'][int(list_index_update.group(2))]
|
||||
|
||||
last_val_type = list(last_val.keys())
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_parts[-1])
|
||||
if list_index_update:
|
||||
# We need to remove an item from a list (REMOVE attr.listattr[0])
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
list_to_update = last_val[key_part]['L']
|
||||
index_to_remove = int(list_index_update.group(2))
|
||||
if index_to_remove < len(list_to_update):
|
||||
del list_to_update[index_to_remove]
|
||||
else:
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val['M'].pop(key_parts[-1], None)
|
||||
else:
|
||||
last_val.pop(key_parts[-1], None)
|
||||
self.attrs[attr].delete('.'.join(key.split('.')[1:]))
|
||||
elif action == 'SET':
|
||||
key, value = value.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# If not exists, changes value to a default if needed, else its the same as it was
|
||||
if value.startswith('if_not_exists'):
|
||||
# Function signature
|
||||
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
|
||||
if not match:
|
||||
raise TypeError
|
||||
|
||||
path, value = match.groups()
|
||||
|
||||
# If it already exists, get its value so we dont overwrite it
|
||||
if path in self.attrs:
|
||||
value = self.attrs[path]
|
||||
# check whether key is a list
|
||||
attr, list_index = attribute_is_list(key.split('.')[0])
|
||||
# If value not exists, changes value to a default if needed, else its the same as it was
|
||||
value = self._get_default(value)
|
||||
|
||||
if type(value) != DynamoType:
|
||||
if value in expression_attribute_values:
|
||||
|
|
@ -311,55 +311,12 @@ class Item(BaseModel):
|
|||
else:
|
||||
dyn_value = value
|
||||
|
||||
if '.' not in key:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key)
|
||||
if list_index_update:
|
||||
key_attr = self.attrs[list_index_update.group(1)]
|
||||
list_index = int(list_index_update.group(2))
|
||||
if key_attr.is_list():
|
||||
if len(key_attr.value) > list_index:
|
||||
key_attr.value[list_index] = expression_attribute_values[value]
|
||||
else:
|
||||
key_attr.value.append(expression_attribute_values[value])
|
||||
else:
|
||||
raise InvalidUpdateExpression
|
||||
else:
|
||||
self.attrs[key] = dyn_value
|
||||
if '.' in key and attr not in self.attrs:
|
||||
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
|
||||
elif attr not in self.attrs:
|
||||
self.attrs[attr] = dyn_value # set new top-level attribute
|
||||
else:
|
||||
# Handle nested dict updates
|
||||
key_parts = key.split('.')
|
||||
attr = key_parts.pop(0)
|
||||
if attr not in self.attrs:
|
||||
raise ValueError
|
||||
last_val = self.attrs[attr].value
|
||||
for key_part in key_parts:
|
||||
list_index_update = re.match('(.+)\\[([0-9]+)\\]', key_part)
|
||||
if list_index_update:
|
||||
key_part = list_index_update.group(1) # listattr[1] ==> listattr
|
||||
# Hack but it'll do, traverses into a dict
|
||||
last_val_type = list(last_val.keys())
|
||||
if last_val_type and last_val_type[0] == 'M':
|
||||
last_val = last_val['M']
|
||||
|
||||
if key_part not in last_val:
|
||||
last_val[key_part] = {'M': {}}
|
||||
last_val = last_val[key_part]
|
||||
|
||||
current_type = list(last_val.keys())[0]
|
||||
if list_index_update:
|
||||
# We need to add an item to a list
|
||||
list_index = int(list_index_update.group(2))
|
||||
if len(last_val['L']) > list_index:
|
||||
last_val['L'][list_index] = expression_attribute_values[value]
|
||||
else:
|
||||
last_val['L'].append(expression_attribute_values[value])
|
||||
else:
|
||||
# We have reference to a nested object but we cant just assign to it
|
||||
if current_type == dyn_value.type:
|
||||
last_val[current_type] = dyn_value.value
|
||||
else:
|
||||
last_val[dyn_value.type] = dyn_value.value
|
||||
del last_val[current_type]
|
||||
self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively
|
||||
|
||||
elif action == 'ADD':
|
||||
key, value = value.split(" ", 1)
|
||||
|
|
@ -413,6 +370,20 @@ class Item(BaseModel):
|
|||
else:
|
||||
raise NotImplementedError('{} update action not yet supported'.format(action))
|
||||
|
||||
def _get_default(self, value):
|
||||
if value.startswith('if_not_exists'):
|
||||
# Function signature
|
||||
match = re.match(r'.*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*', value)
|
||||
if not match:
|
||||
raise TypeError
|
||||
|
||||
path, value = match.groups()
|
||||
|
||||
# If it already exists, get its value so we dont overwrite it
|
||||
if path in self.attrs:
|
||||
value = self.attrs[path]
|
||||
return value
|
||||
|
||||
def update_with_attribute_updates(self, attribute_updates):
|
||||
for attribute_name, update_action in attribute_updates.items():
|
||||
action = update_action['Action']
|
||||
|
|
@ -810,7 +781,6 @@ class Table(BaseModel):
|
|||
else:
|
||||
possible_results = [item for item in list(self.all_items()) if isinstance(
|
||||
item, Item) and item.hash_key == hash_key]
|
||||
|
||||
if range_comparison:
|
||||
if index_name and not index_range_key:
|
||||
raise ValueError(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue