Merge https://github.com/spulec/moto into support_optin_regions

This commit is contained in:
Matthew Gladney 2020-04-27 18:27:40 -04:00
commit ff1beea280
60 changed files with 4652 additions and 515 deletions

View file

@ -21,6 +21,7 @@ from .datasync import mock_datasync # noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
from .dynamodbstreams import mock_dynamodbstreams # noqa
from .elasticbeanstalk import mock_elasticbeanstalk # noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
from .ec2_instance_connect import mock_ec2_instance_connect # noqa
from .ecr import mock_ecr, mock_ecr_deprecated # noqa

View file

@ -461,6 +461,7 @@ class RestAPI(BaseModel):
self.description = description
self.create_date = int(time.time())
self.api_key_source = kwargs.get("api_key_source") or "HEADER"
self.policy = kwargs.get("policy") or None
self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
"types": ["EDGE"]
}
@ -485,6 +486,7 @@ class RestAPI(BaseModel):
"apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration,
"tags": self.tags,
"policy": self.policy,
}
def add_child(self, path, parent_id=None):
@ -713,6 +715,7 @@ class APIGatewayBackend(BaseBackend):
api_key_source=None,
endpoint_configuration=None,
tags=None,
policy=None,
):
api_id = create_id()
rest_api = RestAPI(
@ -723,6 +726,7 @@ class APIGatewayBackend(BaseBackend):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
self.apis[api_id] = rest_api
return rest_api

View file

@ -59,6 +59,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source = self._get_param("apiKeySource")
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
policy = self._get_param("policy")
# Param validation
if api_key_source and api_key_source not in API_KEY_SOURCES:
@ -94,6 +95,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
return 200, {}, json.dumps(rest_api.to_dict())

View file

@ -1006,11 +1006,11 @@ class LambdaBackend(BaseBackend):
return True
return False
def add_policy_statement(self, function_name, raw):
def add_permission(self, function_name, raw):
fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
def remove_permission(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)

View file

@ -146,7 +146,7 @@ class LambdaResponse(BaseResponse):
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
statement = self.body
self.lambda_backend.add_policy_statement(function_name, statement)
self.lambda_backend.add_permission(function_name, statement)
return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
@ -166,9 +166,7 @@ class LambdaResponse(BaseResponse):
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
)
self.lambda_backend.remove_permission(function_name, statement_id, revision)
return 204, {}, "{}"
else:
return 404, {}, "{}"

View file

@ -23,6 +23,7 @@ from moto.ec2 import ec2_backends
from moto.ec2_instance_connect import ec2_instance_connect_backends
from moto.ecr import ecr_backends
from moto.ecs import ecs_backends
from moto.elasticbeanstalk import eb_backends
from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.emr import emr_backends
@ -77,6 +78,7 @@ BACKENDS = {
"ec2_instance_connect": ec2_instance_connect_backends,
"ecr": ecr_backends,
"ecs": ecs_backends,
"elasticbeanstalk": eb_backends,
"elb": elb_backends,
"elbv2": elbv2_backends,
"events": events_backends,

View file

@ -239,8 +239,11 @@ class FakeStack(BaseModel):
self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE")
self.status = "CREATE_COMPLETE"
if create_change_set:
self.status = "REVIEW_IN_PROGRESS"
else:
self.create_resources()
self._add_stack_event("CREATE_COMPLETE")
self.creation_time = datetime.utcnow()
def _create_resource_map(self):
@ -253,7 +256,7 @@ class FakeStack(BaseModel):
self.template_dict,
self.cross_stack_resources,
)
resource_map.create()
resource_map.load()
return resource_map
def _create_output_map(self):
@ -326,6 +329,10 @@ class FakeStack(BaseModel):
def exports(self):
return self.output_map.exports
def create_resources(self):
self.resource_map.create()
self.status = "CREATE_COMPLETE"
def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
@ -640,6 +647,7 @@ class CloudFormationBackend(BaseBackend):
else:
stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event("UPDATE_COMPLETE")
stack.create_resources()
return True
def describe_stacks(self, name_or_stack_id):

View file

@ -531,14 +531,16 @@ class ResourceMap(collections_abc.Mapping):
for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name]
def create(self):
def load(self):
self.load_mapping()
self.transform_mapping()
self.load_parameters()
self.load_conditions()
def create(self):
# Since this is a lazy map, to create every object we just need to
# iterate through self.
# Assumes that self.load() has been called before
self.tags.update(
{
"aws:cloudformation:stack-name": self.get("AWS::StackName"),

View file

@ -22,6 +22,14 @@ class Dimension(object):
self.name = name
self.value = value
def __eq__(self, item):
if isinstance(item, Dimension):
return self.name == item.name and self.value == item.value
return False
def __ne__(self, item): # Only needed on Py2; Py3 defines it implicitly
return self != item
def daterange(start, stop, step=timedelta(days=1), inclusive=False):
"""
@ -124,6 +132,17 @@ class MetricDatum(BaseModel):
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
def filter(self, namespace, name, dimensions):
if namespace and namespace != self.namespace:
return False
if name and name != self.name:
return False
if dimensions and any(
Dimension(d["Name"], d["Value"]) not in self.dimensions for d in dimensions
):
return False
return True
class Dashboard(BaseModel):
def __init__(self, name, body):
@ -202,6 +221,15 @@ class CloudWatchBackend(BaseBackend):
self.metric_data = []
self.paged_metric_data = {}
@property
# Retrieve a list of all OOTB metrics that are provided by metrics providers
# Computed on the fly
def aws_metric_data(self):
md = []
for name, service in metric_providers.items():
md.extend(service.get_cloudwatch_metrics())
return md
def put_metric_alarm(
self,
name,
@ -295,6 +323,43 @@ class CloudWatchBackend(BaseBackend):
)
)
def get_metric_data(self, queries, start_time, end_time):
period_data = [
md for md in self.metric_data if start_time <= md.timestamp <= end_time
]
results = []
for query in queries:
query_ns = query["metric_stat._metric._namespace"]
query_name = query["metric_stat._metric._metric_name"]
query_data = [
md
for md in period_data
if md.namespace == query_ns and md.name == query_name
]
metric_values = [m.value for m in query_data]
result_vals = []
stat = query["metric_stat._stat"]
if len(metric_values) > 0:
if stat == "Average":
result_vals.append(sum(metric_values) / len(metric_values))
elif stat == "Minimum":
result_vals.append(min(metric_values))
elif stat == "Maximum":
result_vals.append(max(metric_values))
elif stat == "Sum":
result_vals.append(sum(metric_values))
label = query["metric_stat._metric._metric_name"] + " " + stat
results.append(
{
"id": query["id"],
"label": label,
"vals": result_vals,
"timestamps": [datetime.now() for _ in result_vals],
}
)
return results
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
@ -334,7 +399,7 @@ class CloudWatchBackend(BaseBackend):
return data
def get_all_metrics(self):
return self.metric_data
return self.metric_data + self.aws_metric_data
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
@ -386,7 +451,7 @@ class CloudWatchBackend(BaseBackend):
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
def list_metrics(self, next_token, namespace, metric_name, dimensions):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
@ -397,15 +462,16 @@ class CloudWatchBackend(BaseBackend):
del self.paged_metric_data[next_token] # Cant reuse same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
metrics = self.get_filtered_metrics(metric_name, namespace, dimensions)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
def get_filtered_metrics(self, metric_name, namespace, dimensions):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
metrics = [
md
for md in metrics
if md.filter(namespace=namespace, name=metric_name, dimensions=dimensions)
]
return metrics
def _get_paginated(self, metrics):
@ -445,3 +511,8 @@ for region in Session().get_available_regions(
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
cloudwatch_backends[region] = CloudWatchBackend()
# List of services that provide OOTB CW metrics
# See the S3Backend constructor for an example
# TODO: We might have to separate this out per region for non-global services
metric_providers = {}

View file

@ -92,6 +92,18 @@ class CloudWatchResponse(BaseResponse):
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render()
@amzn_request_id
def get_metric_data(self):
start = dtparse(self._get_param("StartTime"))
end = dtparse(self._get_param("EndTime"))
queries = self._get_list_prefix("MetricDataQueries.member")
results = self.cloudwatch_backend.get_metric_data(
start_time=start, end_time=end, queries=queries
)
template = self.response_template(GET_METRIC_DATA_TEMPLATE)
return template.render(results=results)
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param("Namespace")
@ -124,9 +136,10 @@ class CloudWatchResponse(BaseResponse):
def list_metrics(self):
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
dimensions = self._get_multi_param("Dimensions.member")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
next_token, namespace, metric_name, dimensions
)
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics, next_token=next_token)
@ -285,6 +298,35 @@ PUT_METRIC_DATA_TEMPLATE = """<PutMetricDataResponse xmlns="http://monitoring.am
</ResponseMetadata>
</PutMetricDataResponse>"""
GET_METRIC_DATA_TEMPLATE = """<GetMetricDataResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>
{{ request_id }}
</RequestId>
</ResponseMetadata>
<GetMetricDataResult>
<MetricDataResults>
{% for result in results %}
<member>
<Id>{{ result.id }}</Id>
<Label>{{ result.label }}</Label>
<StatusCode>Complete</StatusCode>
<Timestamps>
{% for val in result.timestamps %}
<member>{{ val }}</member>
{% endfor %}
</Timestamps>
<Values>
{% for val in result.vals %}
<member>{{ val }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</MetricDataResults>
</GetMetricDataResult>
</GetMetricDataResponse>"""
GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>

View file

@ -1,5 +1,5 @@
from moto.core.utils import get_random_hex
from uuid import uuid4
def get_random_identity_id(region):
return "{0}:{1}".format(region, get_random_hex(length=19))
return "{0}:{1}".format(region, uuid4())

View file

@ -328,3 +328,25 @@ def py2_strip_unicode_keys(blob):
blob = new_set
return blob
def tags_from_query_string(
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value"
):
response_values = {}
for key, value in querystring_dict.items():
if key.startswith(prefix) and key.endswith(key_suffix):
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
tag_key = querystring_dict.get(
"{prefix}.{index}.{key_suffix}".format(
prefix=prefix, index=tag_index, key_suffix=key_suffix,
)
)[0]
tag_value_key = "{prefix}.{index}.{value_suffix}".format(
prefix=prefix, index=tag_index, value_suffix=value_suffix,
)
if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
else:
response_values[tag_key] = None
return response_values

View file

@ -39,6 +39,17 @@ class AttributeDoesNotExist(MockValidationException):
super(AttributeDoesNotExist, self).__init__(self.attr_does_not_exist_msg)
class ProvidedKeyDoesNotExist(MockValidationException):
provided_key_does_not_exist_msg = (
"The provided key element does not match the schema"
)
def __init__(self):
super(ProvidedKeyDoesNotExist, self).__init__(
self.provided_key_does_not_exist_msg
)
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
@ -131,3 +142,10 @@ class IncorrectOperandType(InvalidUpdateExpression):
super(IncorrectOperandType, self).__init__(
self.inv_operand_msg.format(f=operator_or_function, t=operand_type)
)
class IncorrectDataType(MockValidationException):
inc_data_type_msg = "An operand in the update expression has an incorrect data type"
def __init__(self):
super(IncorrectDataType, self).__init__(self.inc_data_type_msg)

View file

@ -8,7 +8,6 @@ import re
import uuid
from boto3 import Session
from botocore.exceptions import ParamValidationError
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
@ -20,8 +19,9 @@ from moto.dynamodb2.exceptions import (
ItemSizeTooLarge,
ItemSizeToUpdateTooLarge,
)
from moto.dynamodb2.models.utilities import bytesize, attribute_is_list
from moto.dynamodb2.models.utilities import bytesize
from moto.dynamodb2.models.dynamo_type import DynamoType
from moto.dynamodb2.parsing.executors import UpdateExpressionExecutor
from moto.dynamodb2.parsing.expressions import UpdateExpressionParser
from moto.dynamodb2.parsing.validators import UpdateExpressionValidator
@ -71,9 +71,23 @@ class Item(BaseModel):
for key, value in attrs.items():
self.attrs[key] = DynamoType(value)
def __eq__(self, other):
return all(
[
self.hash_key == other.hash_key,
self.hash_key_type == other.hash_key_type,
self.range_key == other.range_key,
self.range_key_type == other.range_key_type,
self.attrs == other.attrs,
]
)
def __repr__(self):
return "Item: {0}".format(self.to_json())
def size(self):
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
def to_json(self):
attributes = {}
for attribute_key, attribute in self.attrs.items():
@ -91,192 +105,6 @@ class Item(BaseModel):
included = self.attrs
return {"Item": included}
def update(
self, update_expression, expression_attribute_names, expression_attribute_values
):
# Update subexpressions are identifiable by the operator keyword, so split on that and
# get rid of the empty leading string.
parts = [
p
for p in re.split(
r"\b(SET|REMOVE|ADD|DELETE)\b", update_expression, flags=re.I
)
if p
]
# make sure that we correctly found only operator/value pairs
assert (
len(parts) % 2 == 0
), "Mismatched operators and values in update expression: '{}'".format(
update_expression
)
for action, valstr in zip(parts[:-1:2], parts[1::2]):
action = action.upper()
# "Should" retain arguments in side (...)
values = re.split(r",(?![^(]*\))", valstr)
for value in values:
# A Real value
value = value.lstrip(":").rstrip(",").strip()
for k, v in expression_attribute_names.items():
value = re.sub(r"{0}\b".format(k), v, value)
if action == "REMOVE":
key = value
attr, list_index = attribute_is_list(key.split(".")[0])
if "." not in key:
if list_index:
new_list = DynamoType(self.attrs[attr])
new_list.delete(None, list_index)
self.attrs[attr] = new_list
else:
self.attrs.pop(value, None)
else:
# Handle nested dict updates
self.attrs[attr].delete(".".join(key.split(".")[1:]))
elif action == "SET":
key, value = value.split("=", 1)
key = key.strip()
value = value.strip()
# check whether key is a list
attr, list_index = attribute_is_list(key.split(".")[0])
# If value not exists, changes value to a default if needed, else its the same as it was
value = self._get_default(value)
# If operation == list_append, get the original value and append it
value = self._get_appended_list(value, expression_attribute_values)
if type(value) != DynamoType:
if value in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
dyn_value = DynamoType({"S": value})
else:
dyn_value = value
if "." in key and attr not in self.attrs:
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
elif attr not in self.attrs:
try:
self.attrs[attr] = dyn_value # set new top-level attribute
except ItemSizeTooLarge:
raise ItemSizeToUpdateTooLarge()
else:
self.attrs[attr].set(
".".join(key.split(".")[1:]), dyn_value, list_index
) # set value recursively
elif action == "ADD":
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
# Handle adding numbers - value gets added to existing value,
# or added to 0 if it doesn't exist yet
if dyn_value.is_number():
existing = self.attrs.get(key, DynamoType({"N": "0"}))
if not existing.same_type(dyn_value):
raise TypeError()
self.attrs[key] = DynamoType(
{
"N": str(
decimal.Decimal(existing.value)
+ decimal.Decimal(dyn_value.value)
)
}
)
# Handle adding sets - value is added to the set, or set is
# created with only this value if it doesn't exist yet
# New value must be of same set type as previous value
elif dyn_value.is_set():
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.attrs:
self.attrs[key_head] = DynamoType({dyn_value.type: {}})
existing = self.attrs.get(key_head)
existing = existing.get(key_tail)
if existing.value and not existing.same_type(dyn_value):
raise TypeError()
new_set = set(existing.value or []).union(dyn_value.value)
existing.set(
key=None,
new_value=DynamoType({dyn_value.type: list(new_set)}),
)
else: # Number and Sets are the only supported types for ADD
raise TypeError
elif action == "DELETE":
key, value = value.split(" ", 1)
key = key.strip()
value_str = value.strip()
if value_str in expression_attribute_values:
dyn_value = DynamoType(expression_attribute_values[value])
else:
raise TypeError
if not dyn_value.is_set():
raise TypeError
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
existing = self.attrs.get(key_head)
existing = existing.get(key_tail)
if existing:
if not existing.same_type(dyn_value):
raise TypeError
new_set = set(existing.value).difference(dyn_value.value)
existing.set(
key=None,
new_value=DynamoType({existing.type: list(new_set)}),
)
else:
raise NotImplementedError(
"{} update action not yet supported".format(action)
)
def _get_appended_list(self, value, expression_attribute_values):
if type(value) != DynamoType:
list_append_re = re.match("list_append\\((.+),(.+)\\)", value)
if list_append_re:
new_value = expression_attribute_values[list_append_re.group(2).strip()]
old_list_key = list_append_re.group(1)
# old_key could be a function itself (if_not_exists)
if old_list_key.startswith("if_not_exists"):
old_list = self._get_default(old_list_key)
if not isinstance(old_list, DynamoType):
old_list = DynamoType(expression_attribute_values[old_list])
else:
old_list = self.attrs[old_list_key.split(".")[0]]
if "." in old_list_key:
# Value is nested inside a map - find the appropriate child attr
old_list = old_list.child_attr(
".".join(old_list_key.split(".")[1:])
)
if not old_list.is_list():
raise ParamValidationError
old_list.value.extend([DynamoType(v) for v in new_value["L"]])
value = old_list
return value
def _get_default(self, value):
if value.startswith("if_not_exists"):
# Function signature
match = re.match(
r".*if_not_exists\s*\((?P<path>.+),\s*(?P<default>.+)\).*", value
)
if not match:
raise TypeError
path, value = match.groups()
# If it already exists, get its value so we dont overwrite it
if path in self.attrs:
value = self.attrs[path]
return value
def update_with_attribute_updates(self, attribute_updates):
for attribute_name, update_action in attribute_updates.items():
action = update_action["Action"]
@ -921,6 +749,14 @@ class Table(BaseModel):
break
last_evaluated_key = None
size_limit = 1000000 # DynamoDB has a 1MB size limit
item_size = sum(res.size() for res in results)
if item_size > size_limit:
item_size = idx = 0
while item_size + results[idx].size() < size_limit:
item_size += results[idx].size()
idx += 1
limit = min(limit, idx) if limit else idx
if limit and len(results) > limit:
results = results[:limit]
last_evaluated_key = {self.hash_key_attr: results[-1].hash_key}
@ -1198,9 +1034,9 @@ class DynamoDBBackend(BaseBackend):
table_name,
key,
update_expression,
attribute_updates,
expression_attribute_names,
expression_attribute_values,
attribute_updates=None,
expected=None,
condition_expression=None,
):
@ -1255,17 +1091,18 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
if update_expression:
UpdateExpressionValidator(
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
item=item,
).validate()
item.update(
update_expression,
expression_attribute_names,
expression_attribute_values,
)
try:
UpdateExpressionExecutor(
validated_ast, item, expression_attribute_names
).execute()
except ItemSizeTooLarge:
raise ItemSizeToUpdateTooLarge()
else:
item.update_with_attribute_updates(attribute_updates)
if table.stream_shard is not None:
@ -1321,6 +1158,94 @@ class DynamoDBBackend(BaseBackend):
return table.ttl
def transact_write_items(self, transact_items):
# Create a backup in case any of the transactions fail
original_table_state = copy.deepcopy(self.tables)
try:
for item in transact_items:
if "ConditionCheck" in item:
item = item["ConditionCheck"]
key = item["Key"]
table_name = item["TableName"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
current = self.get_item(table_name, key)
condition_op = get_filter_expression(
condition_expression,
expression_attribute_names,
expression_attribute_values,
)
if not condition_op.expr(current):
raise ValueError("The conditional request failed")
elif "Put" in item:
item = item["Put"]
attrs = item["Item"]
table_name = item["TableName"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.put_item(
table_name,
attrs,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
elif "Delete" in item:
item = item["Delete"]
key = item["Key"]
table_name = item["TableName"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.delete_item(
table_name,
key,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
elif "Update" in item:
item = item["Update"]
key = item["Key"]
table_name = item["TableName"]
update_expression = item["UpdateExpression"]
condition_expression = item.get("ConditionExpression", None)
expression_attribute_names = item.get(
"ExpressionAttributeNames", None
)
expression_attribute_values = item.get(
"ExpressionAttributeValues", None
)
self.update_item(
table_name,
key,
update_expression=update_expression,
condition_expression=condition_expression,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
)
else:
raise ValueError
except: # noqa: E722 Do not use bare except
# Rollback to the original state, and reraise the error
self.tables = original_table_state
raise
dynamodb_backends = {}
for region in Session().get_available_regions("dynamodb"):

View file

@ -1,10 +1,53 @@
import six
from moto.dynamodb2.comparisons import get_comparison_func
from moto.dynamodb2.exceptions import InvalidUpdateExpression
from moto.dynamodb2.exceptions import InvalidUpdateExpression, IncorrectDataType
from moto.dynamodb2.models.utilities import attribute_is_list, bytesize
class DDBType(object):
"""
Official documentation at https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_AttributeValue.html
"""
BINARY_SET = "BS"
NUMBER_SET = "NS"
STRING_SET = "SS"
STRING = "S"
NUMBER = "N"
MAP = "M"
LIST = "L"
BOOLEAN = "BOOL"
BINARY = "B"
NULL = "NULL"
class DDBTypeConversion(object):
_human_type_mapping = {
val: key.replace("_", " ")
for key, val in DDBType.__dict__.items()
if key.upper() == key
}
@classmethod
def get_human_type(cls, abbreviated_type):
"""
Args:
abbreviated_type(str): An attribute of DDBType
Returns:
str: The human readable form of the DDBType.
"""
try:
human_type_str = cls._human_type_mapping[abbreviated_type]
except KeyError:
raise ValueError(
"Invalid abbreviated_type {at}".format(at=abbreviated_type)
)
return human_type_str
class DynamoType(object):
"""
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
@ -50,13 +93,22 @@ class DynamoType(object):
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = "M" if "." in key else new_value.type
type_of_new_attr = DDBType.MAP if "." in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set(
".".join(key.split(".")[1:]), new_value, list_index
)
def __contains__(self, item):
if self.type == DDBType.STRING:
return False
try:
self.__getitem__(item)
return True
except KeyError:
return False
def delete(self, key, index=None):
if index:
if not key:
@ -126,27 +178,35 @@ class DynamoType(object):
def __add__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == "N":
return DynamoType({"N": "{v}".format(v=int(self.value) + int(other.value))})
if self.is_number():
self_value = float(self.value) if "." in self.value else int(self.value)
other_value = float(other.value) if "." in other.value else int(other.value)
return DynamoType(
{DDBType.NUMBER: "{v}".format(v=self_value + other_value)}
)
else:
raise TypeError("Sum only supported for Numbers.")
raise IncorrectDataType()
def __sub__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == "N":
return DynamoType({"N": "{v}".format(v=int(self.value) - int(other.value))})
if self.type == DDBType.NUMBER:
self_value = float(self.value) if "." in self.value else int(self.value)
other_value = float(other.value) if "." in other.value else int(other.value)
return DynamoType(
{DDBType.NUMBER: "{v}".format(v=self_value - other_value)}
)
else:
raise TypeError("Sum only supported for Numbers.")
def __getitem__(self, item):
if isinstance(item, six.string_types):
# If our DynamoType is a map it should be subscriptable with a key
if self.type == "M":
if self.type == DDBType.MAP:
return self.value[item]
elif isinstance(item, int):
# If our DynamoType is a list is should be subscriptable with an index
if self.type == "L":
if self.type == DDBType.LIST:
return self.value[item]
raise TypeError(
"This DynamoType {dt} is not subscriptable by a {it}".format(
@ -154,6 +214,20 @@ class DynamoType(object):
)
)
def __setitem__(self, key, value):
if isinstance(key, int):
if self.is_list():
if key >= len(self.value):
# DynamoDB doesn't care you are out of box just add it to the end.
self.value.append(value)
else:
self.value[key] = value
elif isinstance(key, six.string_types):
if self.is_map():
self.value[key] = value
else:
raise NotImplementedError("No set_item for {t}".format(t=type(key)))
@property
def cast_value(self):
if self.is_number():
@ -222,16 +296,22 @@ class DynamoType(object):
return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == "N"
return self.type == DDBType.NUMBER
def is_set(self):
return self.type == "SS" or self.type == "NS" or self.type == "BS"
return self.type in (DDBType.STRING_SET, DDBType.NUMBER_SET, DDBType.BINARY_SET)
def is_list(self):
return self.type == "L"
return self.type == DDBType.LIST
def is_map(self):
return self.type == "M"
return self.type == DDBType.MAP
def same_type(self, other):
return self.type == other.type
def pop(self, key, *args, **kwargs):
if self.is_map() or self.is_list():
self.value.pop(key, *args, **kwargs)
else:
raise TypeError("pop not supported for DynamoType {t}".format(t=self.type))

View file

@ -0,0 +1,262 @@
from abc import abstractmethod
from moto.dynamodb2.exceptions import IncorrectOperandType, IncorrectDataType
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.models.dynamo_type import DDBTypeConversion, DDBType
from moto.dynamodb2.parsing.ast_nodes import (
UpdateExpressionSetAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
UpdateExpressionAddAction,
UpdateExpressionPath,
DDBTypedValue,
ExpressionAttribute,
ExpressionSelector,
ExpressionAttributeName,
)
from moto.dynamodb2.parsing.validators import ExpressionPathResolver
class NodeExecutor(object):
def __init__(self, ast_node, expression_attribute_names):
self.node = ast_node
self.expression_attribute_names = expression_attribute_names
@abstractmethod
def execute(self, item):
pass
def get_item_part_for_path_nodes(self, item, path_nodes):
"""
For a list of path nodes travers the item by following the path_nodes
Args:
item(Item):
path_nodes(list):
Returns:
"""
if len(path_nodes) == 0:
return item.attrs
else:
return ExpressionPathResolver(
self.expression_attribute_names
).resolve_expression_path_nodes_to_dynamo_type(item, path_nodes)
def get_item_before_end_of_path(self, item):
"""
Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
that element will need to be modified by the action.
Args:
item(Item):
Returns:
DynamoType or dict: The path to be set
"""
return self.get_item_part_for_path_nodes(
item, self.get_path_expression_nodes()[:-1]
)
def get_item_at_end_of_path(self, item):
"""
For a DELETE the path points at the stringset so we need to evaluate the full path.
Args:
item(Item):
Returns:
DynamoType or dict: The path to be set
"""
return self.get_item_part_for_path_nodes(item, self.get_path_expression_nodes())
# Get the part ot the item where the item will perform the action. For most actions this should be the parent. As
# that element will need to be modified by the action.
get_item_part_in_which_to_perform_action = get_item_before_end_of_path
def get_path_expression_nodes(self):
update_expression_path = self.node.children[0]
assert isinstance(update_expression_path, UpdateExpressionPath)
return update_expression_path.children
def get_element_to_action(self):
return self.get_path_expression_nodes()[-1]
def get_action_value(self):
"""
Returns:
DynamoType: The value to be set
"""
ddb_typed_value = self.node.children[1]
assert isinstance(ddb_typed_value, DDBTypedValue)
dynamo_type_value = ddb_typed_value.children[0]
assert isinstance(dynamo_type_value, DynamoType)
return dynamo_type_value
class SetExecutor(NodeExecutor):
def execute(self, item):
self.set(
item_part_to_modify_with_set=self.get_item_part_in_which_to_perform_action(
item
),
element_to_set=self.get_element_to_action(),
value_to_set=self.get_action_value(),
expression_attribute_names=self.expression_attribute_names,
)
@classmethod
def set(
cls,
item_part_to_modify_with_set,
element_to_set,
value_to_set,
expression_attribute_names,
):
if isinstance(element_to_set, ExpressionAttribute):
attribute_name = element_to_set.get_attribute_name()
item_part_to_modify_with_set[attribute_name] = value_to_set
elif isinstance(element_to_set, ExpressionSelector):
index = element_to_set.get_index()
item_part_to_modify_with_set[index] = value_to_set
elif isinstance(element_to_set, ExpressionAttributeName):
attribute_name = expression_attribute_names[
element_to_set.get_attribute_name_placeholder()
]
item_part_to_modify_with_set[attribute_name] = value_to_set
else:
raise NotImplementedError(
"Moto does not support setting {t} yet".format(t=type(element_to_set))
)
class DeleteExecutor(NodeExecutor):
operator = "operator: DELETE"
def execute(self, item):
string_set_to_remove = self.get_action_value()
assert isinstance(string_set_to_remove, DynamoType)
if not string_set_to_remove.is_set():
raise IncorrectOperandType(
self.operator,
DDBTypeConversion.get_human_type(string_set_to_remove.type),
)
string_set = self.get_item_at_end_of_path(item)
assert isinstance(string_set, DynamoType)
if string_set.type != string_set_to_remove.type:
raise IncorrectDataType()
# String set is currently implemented as a list
string_set_list = string_set.value
stringset_to_remove_list = string_set_to_remove.value
for value in stringset_to_remove_list:
try:
string_set_list.remove(value)
except (KeyError, ValueError):
# DynamoDB does not mind if value is not present
pass
class RemoveExecutor(NodeExecutor):
def execute(self, item):
element_to_remove = self.get_element_to_action()
if isinstance(element_to_remove, ExpressionAttribute):
attribute_name = element_to_remove.get_attribute_name()
self.get_item_part_in_which_to_perform_action(item).pop(
attribute_name, None
)
elif isinstance(element_to_remove, ExpressionAttributeName):
attribute_name = self.expression_attribute_names[
element_to_remove.get_attribute_name_placeholder()
]
self.get_item_part_in_which_to_perform_action(item).pop(
attribute_name, None
)
elif isinstance(element_to_remove, ExpressionSelector):
index = element_to_remove.get_index()
try:
self.get_item_part_in_which_to_perform_action(item).pop(index)
except IndexError:
# DynamoDB does not care that index is out of bounds, it will just do nothing.
pass
else:
raise NotImplementedError(
"Moto does not support setting {t} yet".format(
t=type(element_to_remove)
)
)
class AddExecutor(NodeExecutor):
def execute(self, item):
value_to_add = self.get_action_value()
if isinstance(value_to_add, DynamoType):
if value_to_add.is_set():
current_string_set = self.get_item_at_end_of_path(item)
assert isinstance(current_string_set, DynamoType)
if not current_string_set.type == value_to_add.type:
raise IncorrectDataType()
# Sets are implemented as list
for value in value_to_add.value:
if value in current_string_set.value:
continue
else:
current_string_set.value.append(value)
elif value_to_add.type == DDBType.NUMBER:
existing_value = self.get_item_at_end_of_path(item)
assert isinstance(existing_value, DynamoType)
if not existing_value.type == DDBType.NUMBER:
raise IncorrectDataType()
new_value = existing_value + value_to_add
SetExecutor.set(
item_part_to_modify_with_set=self.get_item_before_end_of_path(item),
element_to_set=self.get_element_to_action(),
value_to_set=new_value,
expression_attribute_names=self.expression_attribute_names,
)
else:
raise IncorrectDataType()
class UpdateExpressionExecutor(object):
execution_map = {
UpdateExpressionSetAction: SetExecutor,
UpdateExpressionAddAction: AddExecutor,
UpdateExpressionRemoveAction: RemoveExecutor,
UpdateExpressionDeleteAction: DeleteExecutor,
}
def __init__(self, update_ast, item, expression_attribute_names):
self.update_ast = update_ast
self.item = item
self.expression_attribute_names = expression_attribute_names
def execute(self, node=None):
"""
As explained in moto.dynamodb2.parsing.expressions.NestableExpressionParserMixin._create_node the order of nodes
in the AST can be translated of the order of statements in the expression. As such we can start at the root node
and process the nodes 1-by-1. If no specific execution for the node type is defined we can execute the children
in order since it will be a container node that is expandable and left child will be first in the statement.
Args:
node(Node):
Returns:
None
"""
if node is None:
node = self.update_ast
node_executor = self.get_specific_execution(node)
if node_executor is None:
for node in node.children:
self.execute(node)
else:
node_executor(node, self.expression_attribute_names).execute(self.item)
def get_specific_execution(self, node):
for node_class in self.execution_map:
if isinstance(node, node_class):
return self.execution_map[node_class]
return None

View file

@ -11,6 +11,7 @@ from moto.dynamodb2.exceptions import (
ExpressionAttributeNameNotDefined,
IncorrectOperandType,
InvalidUpdateExpressionInvalidDocumentPath,
ProvidedKeyDoesNotExist,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
@ -56,6 +57,76 @@ class ExpressionAttributeValueProcessor(DepthFirstTraverser):
return DDBTypedValue(DynamoType(target))
class ExpressionPathResolver(object):
def __init__(self, expression_attribute_names):
self.expression_attribute_names = expression_attribute_names
@classmethod
def raise_exception_if_keyword(cls, attribute):
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute)
def resolve_expression_path(self, item, update_expression_path):
assert isinstance(update_expression_path, UpdateExpressionPath)
return self.resolve_expression_path_nodes(item, update_expression_path.children)
def resolve_expression_path_nodes(self, item, update_expression_path_nodes):
target = item.attrs
for child in update_expression_path_nodes:
# First replace placeholder with attribute_name
attr_name = None
if isinstance(child, ExpressionAttributeName):
attr_placeholder = child.get_attribute_name_placeholder()
try:
attr_name = self.expression_attribute_names[attr_placeholder]
except KeyError:
raise ExpressionAttributeNameNotDefined(attr_placeholder)
elif isinstance(child, ExpressionAttribute):
attr_name = child.get_attribute_name()
self.raise_exception_if_keyword(attr_name)
if attr_name is not None:
# Resolv attribute_name
try:
target = target[attr_name]
except (KeyError, TypeError):
if child == update_expression_path_nodes[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
if isinstance(child, ExpressionPathDescender):
continue
elif isinstance(child, ExpressionSelector):
index = child.get_index()
if target.is_list():
try:
target = target[index]
except IndexError:
# When a list goes out of bounds when assigning that is no problem when at the assignment
# side. It will just append to the list.
if child == update_expression_path_nodes[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
raise InvalidUpdateExpressionInvalidDocumentPath
else:
raise NotImplementedError(
"Path resolution for {t}".format(t=type(child))
)
if not isinstance(target, DynamoType):
print(target)
return DDBTypedValue(target)
def resolve_expression_path_nodes_to_dynamo_type(
self, item, update_expression_path_nodes
):
node = self.resolve_expression_path_nodes(item, update_expression_path_nodes)
if isinstance(node, NoneExistingPath):
raise ProvidedKeyDoesNotExist()
assert isinstance(node, DDBTypedValue)
return node.get_value()
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
def _processing_map(self):
return {
@ -107,55 +178,9 @@ class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
return node
def resolve_expression_path(self, node):
assert isinstance(node, UpdateExpressionPath)
target = deepcopy(self.item.attrs)
for child in node.children:
# First replace placeholder with attribute_name
attr_name = None
if isinstance(child, ExpressionAttributeName):
attr_placeholder = child.get_attribute_name_placeholder()
try:
attr_name = self.expression_attribute_names[attr_placeholder]
except KeyError:
raise ExpressionAttributeNameNotDefined(attr_placeholder)
elif isinstance(child, ExpressionAttribute):
attr_name = child.get_attribute_name()
self.raise_exception_if_keyword(attr_name)
if attr_name is not None:
# Resolv attribute_name
try:
target = target[attr_name]
except (KeyError, TypeError):
if child == node.children[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
if isinstance(child, ExpressionPathDescender):
continue
elif isinstance(child, ExpressionSelector):
index = child.get_index()
if target.is_list():
try:
target = target[index]
except IndexError:
# When a list goes out of bounds when assigning that is no problem when at the assignment
# side. It will just append to the list.
if child == node.children[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
raise InvalidUpdateExpressionInvalidDocumentPath
else:
raise NotImplementedError(
"Path resolution for {t}".format(t=type(child))
)
return DDBTypedValue(DynamoType(target))
@classmethod
def raise_exception_if_keyword(cls, attribute):
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute)
return ExpressionPathResolver(
self.expression_attribute_names
).resolve_expression_path(self.item, node)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
@ -183,7 +208,9 @@ class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
assert isinstance(result, (DDBTypedValue, NoneExistingPath))
return result
elif function_name == "list_append":
first_arg = self.get_list_from_ddb_typed_value(first_arg, function_name)
first_arg = deepcopy(
self.get_list_from_ddb_typed_value(first_arg, function_name)
)
second_arg = self.get_list_from_ddb_typed_value(second_arg, function_name)
for list_element in second_arg.value:
first_arg.value.append(list_element)

View file

@ -470,8 +470,10 @@ class DynamoHandler(BaseResponse):
for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {}))
)
if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1)
if " and " in key_condition_expression.lower():
expressions = re.split(
" AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE
)
index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0]
hash_key_var = reverse_attribute_lookup.get(
@ -760,12 +762,12 @@ class DynamoHandler(BaseResponse):
item = self.dynamodb_backend.update_item(
name,
key,
update_expression,
attribute_updates,
expression_attribute_names,
expression_attribute_values,
expected,
condition_expression,
update_expression=update_expression,
attribute_updates=attribute_updates,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
expected=expected,
condition_expression=condition_expression,
)
except MockValidationException as mve:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
@ -922,3 +924,15 @@ class DynamoHandler(BaseResponse):
result.update({"ConsumedCapacity": [v for v in consumed_capacity.values()]})
return dynamo_json_dump(result)
def transact_write_items(self):
transact_items = self.body["TransactItems"]
try:
self.dynamodb_backend.transact_write_items(transact_items)
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
return self.error(
er, "A condition specified in the operation could not be evaluated."
)
response = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}
return dynamo_json_dump(response)

View file

@ -557,6 +557,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
# worst case we'll get IP address exaustion... rarely
pass
def add_block_device(self, size, device_path):
volume = self.ec2_backend.create_volume(size, self.region_name)
self.ec2_backend.attach_volume(volume.id, self.id, device_path)
def setup_defaults(self):
# Default have an instance with root volume should you not wish to
# override with attach volume cmd.
@ -564,9 +568,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.ec2_backend.attach_volume(volume.id, self.id, "/dev/sda1")
def teardown_defaults(self):
volume_id = self.block_device_mapping["/dev/sda1"].volume_id
self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1")
self.ec2_backend.delete_volume(volume_id)
if "/dev/sda1" in self.block_device_mapping:
volume_id = self.block_device_mapping["/dev/sda1"].volume_id
self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1")
self.ec2_backend.delete_volume(volume_id)
@property
def get_block_device_mapping(self):
@ -621,6 +626,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
subnet_id=properties.get("SubnetId"),
key_name=properties.get("KeyName"),
private_ip=properties.get("PrivateIpAddress"),
block_device_mappings=properties.get("BlockDeviceMappings", {}),
)
instance = reservation.instances[0]
for tag in properties.get("Tags", []):
@ -880,7 +886,14 @@ class InstanceBackend(object):
)
new_reservation.instances.append(new_instance)
new_instance.add_tags(instance_tags)
new_instance.setup_defaults()
if "block_device_mappings" in kwargs:
for block_device in kwargs["block_device_mappings"]:
new_instance.add_block_device(
block_device["Ebs"]["VolumeSize"], block_device["DeviceName"]
)
else:
new_instance.setup_defaults()
return new_reservation
def start_instances(self, instance_ids):

View file

@ -52,7 +52,7 @@ class InstanceResponse(BaseResponse):
private_ip = self._get_param("PrivateIpAddress")
associate_public_ip = self._get_param("AssociatePublicIpAddress")
key_name = self._get_param("KeyName")
ebs_optimized = self._get_param("EbsOptimized")
ebs_optimized = self._get_param("EbsOptimized") or False
instance_initiated_shutdown_behavior = self._get_param(
"InstanceInitiatedShutdownBehavior"
)

View file

@ -2,7 +2,8 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.models import validate_resource_ids
from moto.ec2.utils import tags_from_query_string, filters_from_querystring
from moto.ec2.utils import filters_from_querystring
from moto.core.utils import tags_from_query_string
class TagResponse(BaseResponse):

View file

@ -196,22 +196,6 @@ def split_route_id(route_id):
return values[0], values[1]
def tags_from_query_string(querystring_dict):
prefix = "Tag"
suffix = "Key"
response_values = {}
for key, value in querystring_dict.items():
if key.startswith(prefix) and key.endswith(suffix):
tag_index = key.replace(prefix + ".", "").replace("." + suffix, "")
tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0]
tag_value_key = "Tag.{0}.Value".format(tag_index)
if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
else:
response_values[tag_key] = None
return response_values
def dhcp_configuration_from_querystring(querystring, option="DhcpConfiguration"):
"""
turn:

View file

@ -0,0 +1,4 @@
from .models import eb_backends
from moto.core.models import base_decorator
mock_elasticbeanstalk = base_decorator(eb_backends)

View file

@ -0,0 +1,15 @@
from moto.core.exceptions import RESTError
class InvalidParameterValueError(RESTError):
def __init__(self, message):
super(InvalidParameterValueError, self).__init__(
"InvalidParameterValue", message
)
class ResourceNotFoundException(RESTError):
def __init__(self, message):
super(ResourceNotFoundException, self).__init__(
"ResourceNotFoundException", message
)

View file

@ -0,0 +1,152 @@
import weakref
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from .exceptions import InvalidParameterValueError, ResourceNotFoundException
class FakeEnvironment(BaseModel):
def __init__(
self, application, environment_name, solution_stack_name, tags,
):
self.application = weakref.proxy(
application
) # weakref to break circular dependencies
self.environment_name = environment_name
self.solution_stack_name = solution_stack_name
self.tags = tags
@property
def application_name(self):
return self.application.application_name
@property
def environment_arn(self):
return (
"arn:aws:elasticbeanstalk:{region}:{account_id}:"
"environment/{application_name}/{environment_name}".format(
region=self.region,
account_id="123456789012",
application_name=self.application_name,
environment_name=self.environment_name,
)
)
@property
def platform_arn(self):
return "TODO" # TODO
@property
def region(self):
return self.application.region
class FakeApplication(BaseModel):
def __init__(self, backend, application_name):
self.backend = weakref.proxy(backend) # weakref to break cycles
self.application_name = application_name
self.environments = dict()
def create_environment(
self, environment_name, solution_stack_name, tags,
):
if environment_name in self.environments:
raise InvalidParameterValueError
env = FakeEnvironment(
application=self,
environment_name=environment_name,
solution_stack_name=solution_stack_name,
tags=tags,
)
self.environments[environment_name] = env
return env
@property
def region(self):
return self.backend.region
class EBBackend(BaseBackend):
def __init__(self, region):
self.region = region
self.applications = dict()
def reset(self):
# preserve region
region = self.region
self._reset_model_refs()
self.__dict__ = {}
self.__init__(region)
def create_application(self, application_name):
if application_name in self.applications:
raise InvalidParameterValueError(
"Application {} already exists.".format(application_name)
)
new_app = FakeApplication(backend=self, application_name=application_name,)
self.applications[application_name] = new_app
return new_app
def create_environment(self, app, environment_name, stack_name, tags):
return app.create_environment(
environment_name=environment_name,
solution_stack_name=stack_name,
tags=tags,
)
def describe_environments(self):
envs = []
for app in self.applications.values():
for env in app.environments.values():
envs.append(env)
return envs
def list_available_solution_stacks(self):
# Implemented in response.py
pass
def update_tags_for_resource(self, resource_arn, tags_to_add, tags_to_remove):
try:
res = self._find_environment_by_arn(resource_arn)
except KeyError:
raise ResourceNotFoundException(
"Resource not found for ARN '{}'.".format(resource_arn)
)
for key, value in tags_to_add.items():
res.tags[key] = value
for key in tags_to_remove:
del res.tags[key]
def list_tags_for_resource(self, resource_arn):
try:
res = self._find_environment_by_arn(resource_arn)
except KeyError:
raise ResourceNotFoundException(
"Resource not found for ARN '{}'.".format(resource_arn)
)
return res.tags
def _find_environment_by_arn(self, arn):
for app in self.applications.keys():
for env in self.applications[app].environments.values():
if env.environment_arn == arn:
return env
raise KeyError()
eb_backends = {}
for region in Session().get_available_regions("elasticbeanstalk"):
eb_backends[region] = EBBackend(region)
for region in Session().get_available_regions(
"elasticbeanstalk", partition_name="aws-us-gov"
):
eb_backends[region] = EBBackend(region)
for region in Session().get_available_regions(
"elasticbeanstalk", partition_name="aws-cn"
):
eb_backends[region] = EBBackend(region)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,11 @@
from __future__ import unicode_literals
from .responses import EBResponse
url_bases = [
r"https?://elasticbeanstalk.(?P<region>[a-zA-Z0-9\-_]+).amazonaws.com",
]
url_paths = {
"{0}/$": EBResponse.dispatch,
}

View file

@ -10,9 +10,10 @@ from six.moves.urllib.parse import urlparse
from moto.core.responses import AWSServiceSpec
from moto.core.responses import BaseResponse
from moto.core.responses import xml_to_json_response
from moto.core.utils import tags_from_query_string
from .exceptions import EmrError
from .models import emr_backends
from .utils import steps_from_query_string, tags_from_query_string
from .utils import steps_from_query_string
def generate_boto3_response(operation):
@ -91,7 +92,7 @@ class ElasticMapReduceResponse(BaseResponse):
@generate_boto3_response("AddTags")
def add_tags(self):
cluster_id = self._get_param("ResourceId")
tags = tags_from_query_string(self.querystring)
tags = tags_from_query_string(self.querystring, prefix="Tags")
self.backend.add_tags(cluster_id, tags)
template = self.response_template(ADD_TAGS_TEMPLATE)
return template.render()

View file

@ -22,22 +22,6 @@ def random_instance_group_id(size=13):
return "i-{0}".format(random_id())
def tags_from_query_string(querystring_dict):
prefix = "Tags"
suffix = "Key"
response_values = {}
for key, value in querystring_dict.items():
if key.startswith(prefix) and key.endswith(suffix):
tag_index = key.replace(prefix + ".", "").replace("." + suffix, "")
tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0]
tag_value_key = "Tags.{0}.Value".format(tag_index)
if tag_value_key in querystring_dict:
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
else:
response_values[tag_key] = None
return response_values
def steps_from_query_string(querystring_dict):
steps = []
for step in querystring_dict:

View file

@ -145,10 +145,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Do S3, resource type s3
if not resource_type_filters or "s3" in resource_type_filters:
for bucket in self.s3_backend.buckets.values():
tags = []
for tag in bucket.tags.tag_set.tags:
tags.append({"Key": tag.key, "Value": tag.value})
tags = self.s3_backend.tagger.list_tags_for_resource(bucket.arn)["Tags"]
if not tags or not tag_filter(
tags
): # Skip if no tags, or invalid filter
@ -362,8 +359,9 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Do S3, resource type s3
for bucket in self.s3_backend.buckets.values():
for tag in bucket.tags.tag_set.tags:
yield tag.key
tags = self.s3_backend.tagger.get_tag_dict_for_resource(bucket.arn)
for key, _ in tags.items():
yield key
# EC2 tags
def get_ec2_keys(res_id):
@ -414,9 +412,10 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# Do S3, resource type s3
for bucket in self.s3_backend.buckets.values():
for tag in bucket.tags.tag_set.tags:
if tag.key == tag_key:
yield tag.value
tags = self.s3_backend.tagger.get_tag_dict_for_resource(bucket.arn)
for key, value in tags.items():
if key == tag_key:
yield value
# EC2 tags
def get_ec2_values(res_id):

View file

@ -22,6 +22,8 @@ import six
from bisect import insort
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime
from moto.cloudwatch.models import MetricDatum
from moto.utilities.tagging_service import TaggingService
from .exceptions import (
BucketAlreadyExists,
MissingBucket,
@ -34,7 +36,6 @@ from .exceptions import (
MalformedXML,
InvalidStorageClass,
InvalidTargetBucketForLogging,
DuplicateTagKeys,
CrossLocationLoggingProhibitted,
NoSuchPublicAccessBlockConfiguration,
InvalidPublicAccessBlockConfiguration,
@ -94,6 +95,7 @@ class FakeKey(BaseModel):
version_id=0,
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE,
multipart=None,
bucket_name=None,
):
self.name = name
self.last_modified = datetime.datetime.utcnow()
@ -105,8 +107,8 @@ class FakeKey(BaseModel):
self._etag = etag
self._version_id = version_id
self._is_versioned = is_versioned
self._tagging = FakeTagging()
self.multipart = multipart
self.bucket_name = bucket_name
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
self._max_buffer_size = max_buffer_size
@ -126,6 +128,13 @@ class FakeKey(BaseModel):
self.lock.release()
return r
@property
def arn(self):
# S3 Objects don't have an ARN, but we do need something unique when creating tags against this resource
return "arn:aws:s3:::{}/{}/{}".format(
self.bucket_name, self.name, self.version_id
)
@value.setter
def value(self, new_value):
self._value_buffer.seek(0)
@ -152,9 +161,6 @@ class FakeKey(BaseModel):
self._metadata = {}
self._metadata.update(metadata)
def set_tagging(self, tagging):
self._tagging = tagging
def set_storage_class(self, storage):
if storage is not None and storage not in STORAGE_CLASS:
raise InvalidStorageClass(storage=storage)
@ -210,10 +216,6 @@ class FakeKey(BaseModel):
def metadata(self):
return self._metadata
@property
def tagging(self):
return self._tagging
@property
def response_dict(self):
res = {
@ -471,26 +473,10 @@ def get_canned_acl(acl):
return FakeAcl(grants=grants)
class FakeTagging(BaseModel):
def __init__(self, tag_set=None):
self.tag_set = tag_set or FakeTagSet()
class FakeTagSet(BaseModel):
def __init__(self, tags=None):
self.tags = tags or []
class FakeTag(BaseModel):
def __init__(self, key, value=None):
self.key = key
self.value = value
class LifecycleFilter(BaseModel):
def __init__(self, prefix=None, tag=None, and_filter=None):
self.prefix = prefix
self.tag = tag
(self.tag_key, self.tag_value) = tag if tag else (None, None)
self.and_filter = and_filter
def to_config_dict(self):
@ -499,11 +485,11 @@ class LifecycleFilter(BaseModel):
"predicate": {"type": "LifecyclePrefixPredicate", "prefix": self.prefix}
}
elif self.tag:
elif self.tag_key:
return {
"predicate": {
"type": "LifecycleTagPredicate",
"tag": {"key": self.tag.key, "value": self.tag.value},
"tag": {"key": self.tag_key, "value": self.tag_value},
}
}
@ -527,12 +513,9 @@ class LifecycleAndFilter(BaseModel):
if self.prefix is not None:
data.append({"type": "LifecyclePrefixPredicate", "prefix": self.prefix})
for tag in self.tags:
for key, value in self.tags.items():
data.append(
{
"type": "LifecycleTagPredicate",
"tag": {"key": tag.key, "value": tag.value},
}
{"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value},}
)
return data
@ -787,7 +770,6 @@ class FakeBucket(BaseModel):
self.policy = None
self.website_configuration = None
self.acl = get_canned_acl("private")
self.tags = FakeTagging()
self.cors = []
self.logging = {}
self.notification_configuration = None
@ -879,7 +861,7 @@ class FakeBucket(BaseModel):
and_filter = None
if rule["Filter"].get("And"):
filters += 1
and_tags = []
and_tags = {}
if rule["Filter"]["And"].get("Tag"):
if not isinstance(rule["Filter"]["And"]["Tag"], list):
rule["Filter"]["And"]["Tag"] = [
@ -887,7 +869,7 @@ class FakeBucket(BaseModel):
]
for t in rule["Filter"]["And"]["Tag"]:
and_tags.append(FakeTag(t["Key"], t.get("Value", "")))
and_tags[t["Key"]] = t.get("Value", "")
try:
and_prefix = (
@ -901,7 +883,7 @@ class FakeBucket(BaseModel):
filter_tag = None
if rule["Filter"].get("Tag"):
filters += 1
filter_tag = FakeTag(
filter_tag = (
rule["Filter"]["Tag"]["Key"],
rule["Filter"]["Tag"].get("Value", ""),
)
@ -988,16 +970,6 @@ class FakeBucket(BaseModel):
def delete_cors(self):
self.cors = []
def set_tags(self, tagging):
self.tags = tagging
def delete_tags(self):
self.tags = FakeTagging()
@property
def tagging(self):
return self.tags
def set_logging(self, logging_config, bucket_backend):
if not logging_config:
self.logging = {}
@ -1085,6 +1057,10 @@ class FakeBucket(BaseModel):
def set_acl(self, acl):
self.acl = acl
@property
def arn(self):
return "arn:aws:s3:::{}".format(self.name)
@property
def physical_resource_id(self):
return self.name
@ -1110,7 +1086,7 @@ class FakeBucket(BaseModel):
int(time.mktime(self.creation_date.timetuple()))
), # PY2 and 3 compatible
"configurationItemMD5Hash": "",
"arn": "arn:aws:s3:::{}".format(self.name),
"arn": self.arn,
"resourceType": "AWS::S3::Bucket",
"resourceId": self.name,
"resourceName": self.name,
@ -1119,7 +1095,7 @@ class FakeBucket(BaseModel):
"resourceCreationTime": str(self.creation_date),
"relatedEvents": [],
"relationships": [],
"tags": {tag.key: tag.value for tag in self.tagging.tag_set.tags},
"tags": s3_backend.tagger.get_tag_dict_for_resource(self.arn),
"configuration": {
"name": self.name,
"owner": {"id": OWNER},
@ -1181,6 +1157,42 @@ class S3Backend(BaseBackend):
def __init__(self):
self.buckets = {}
self.account_public_access_block = None
self.tagger = TaggingService()
# TODO: This is broken! DO NOT IMPORT MUTABLE DATA TYPES FROM OTHER AREAS -- THIS BREAKS UNMOCKING!
# WRAP WITH A GETTER/SETTER FUNCTION
# Register this class as a CloudWatch Metric Provider
# Must provide a method 'get_cloudwatch_metrics' that will return a list of metrics, based on the data available
# metric_providers["S3"] = self
def get_cloudwatch_metrics(self):
metrics = []
for name, bucket in self.buckets.items():
metrics.append(
MetricDatum(
namespace="AWS/S3",
name="BucketSizeBytes",
value=bucket.keys.item_size(),
dimensions=[
{"Name": "StorageType", "Value": "StandardStorage"},
{"Name": "BucketName", "Value": name},
],
timestamp=datetime.datetime.now(),
)
)
metrics.append(
MetricDatum(
namespace="AWS/S3",
name="NumberOfObjects",
value=len(bucket.keys),
dimensions=[
{"Name": "StorageType", "Value": "AllStorageTypes"},
{"Name": "BucketName", "Value": name},
],
timestamp=datetime.datetime.now(),
)
)
return metrics
def create_bucket(self, bucket_name, region_name):
if bucket_name in self.buckets:
@ -1350,23 +1362,32 @@ class S3Backend(BaseBackend):
else:
return None
def set_key_tagging(self, bucket_name, key_name, tagging, version_id=None):
key = self.get_key(bucket_name, key_name, version_id)
def get_key_tags(self, key):
return self.tagger.list_tags_for_resource(key.arn)
def set_key_tags(self, key, tags, key_name=None):
if key is None:
raise MissingKey(key_name)
key.set_tagging(tagging)
self.tagger.delete_all_tags_for_resource(key.arn)
self.tagger.tag_resource(
key.arn, [{"Key": key, "Value": value} for key, value in tags.items()],
)
return key
def put_bucket_tagging(self, bucket_name, tagging):
tag_keys = [tag.key for tag in tagging.tag_set.tags]
if len(tag_keys) != len(set(tag_keys)):
raise DuplicateTagKeys()
def get_bucket_tags(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.set_tags(tagging)
return self.tagger.list_tags_for_resource(bucket.arn)
def put_bucket_tags(self, bucket_name, tags):
bucket = self.get_bucket(bucket_name)
self.tagger.delete_all_tags_for_resource(bucket.arn)
self.tagger.tag_resource(
bucket.arn, [{"Key": key, "Value": value} for key, value in tags.items()],
)
def delete_bucket_tagging(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.delete_tags()
self.tagger.delete_all_tags_for_resource(bucket.arn)
def put_bucket_cors(self, bucket_name, cors_rules):
bucket = self.get_bucket(bucket_name)
@ -1574,6 +1595,7 @@ class S3Backend(BaseBackend):
key = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id)
new_key = key.copy(dest_key_name, dest_bucket.is_versioned)
self.tagger.copy_tags(key.arn, new_key.arn)
if storage is not None:
new_key.set_storage_class(storage)

View file

@ -24,6 +24,7 @@ from moto.s3bucket_path.utils import (
from .exceptions import (
BucketAlreadyExists,
DuplicateTagKeys,
S3ClientError,
MissingBucket,
MissingKey,
@ -43,9 +44,6 @@ from .models import (
FakeGrant,
FakeAcl,
FakeKey,
FakeTagging,
FakeTagSet,
FakeTag,
)
from .utils import (
bucket_name_from_url,
@ -134,7 +132,8 @@ ACTION_MAP = {
def parse_key_name(pth):
return pth.lstrip("/")
# strip the first '/' left by urlparse
return pth[1:] if pth.startswith("/") else pth
def is_delete_keys(request, path, bucket_name):
@ -378,13 +377,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return template.render(obj=bucket)
elif "tagging" in querystring:
bucket = self.backend.get_bucket(bucket_name)
tags = self.backend.get_bucket_tags(bucket_name)["Tags"]
# "Special Error" if no tags:
if len(bucket.tagging.tag_set.tags) == 0:
if len(tags) == 0:
template = self.response_template(S3_NO_BUCKET_TAGGING)
return 404, {}, template.render(bucket_name=bucket_name)
template = self.response_template(S3_BUCKET_TAGGING_RESPONSE)
return template.render(bucket=bucket)
template = self.response_template(S3_OBJECT_TAGGING_RESPONSE)
return template.render(tags=tags)
elif "logging" in querystring:
bucket = self.backend.get_bucket(bucket_name)
if not bucket.logging:
@ -652,7 +651,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return ""
elif "tagging" in querystring:
tagging = self._bucket_tagging_from_xml(body)
self.backend.put_bucket_tagging(bucket_name, tagging)
self.backend.put_bucket_tags(bucket_name, tagging)
return ""
elif "website" in querystring:
self.backend.set_bucket_website_configuration(bucket_name, body)
@ -839,27 +838,35 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _bucket_response_delete_keys(self, request, body, bucket_name):
template = self.response_template(S3_DELETE_KEYS_RESPONSE)
body_dict = xmltodict.parse(body)
keys = minidom.parseString(body).getElementsByTagName("Key")
deleted_names = []
error_names = []
if len(keys) == 0:
objects = body_dict["Delete"].get("Object", [])
if not isinstance(objects, list):
# We expect a list of objects, but when there is a single <Object> node xmltodict does not
# return a list.
objects = [objects]
if len(objects) == 0:
raise MalformedXML()
for k in keys:
key_name = k.firstChild.nodeValue
deleted_objects = []
error_names = []
for object_ in objects:
key_name = object_["Key"]
version_id = object_.get("VersionId", None)
success = self.backend.delete_key(
bucket_name, undo_clean_key_name(key_name)
bucket_name, undo_clean_key_name(key_name), version_id=version_id
)
if success:
deleted_names.append(key_name)
deleted_objects.append((key_name, version_id))
else:
error_names.append(key_name)
return (
200,
{},
template.render(deleted=deleted_names, delete_errors=error_names),
template.render(deleted=deleted_objects, delete_errors=error_names),
)
def _handle_range_header(self, request, headers, response_content):
@ -1098,8 +1105,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return 200, response_headers, template.render(obj=key)
if "tagging" in query:
tags = self.backend.get_key_tags(key)["Tags"]
template = self.response_template(S3_OBJECT_TAGGING_RESPONSE)
return 200, response_headers, template.render(obj=key)
return 200, response_headers, template.render(tags=tags)
response_headers.update(key.metadata)
response_headers.update(key.response_dict)
@ -1171,8 +1179,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
version_id = query["versionId"][0]
else:
version_id = None
key = self.backend.get_key(bucket_name, key_name, version_id=version_id)
tagging = self._tagging_from_xml(body)
self.backend.set_key_tagging(bucket_name, key_name, tagging, version_id)
self.backend.set_key_tags(key, tagging, key_name)
return 200, response_headers, ""
if "x-amz-copy-source" in request.headers:
@ -1213,7 +1222,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
tdirective = request.headers.get("x-amz-tagging-directive")
if tdirective == "REPLACE":
tagging = self._tagging_from_headers(request.headers)
new_key.set_tagging(tagging)
self.backend.set_key_tags(new_key, tagging)
template = self.response_template(S3_OBJECT_COPY_RESPONSE)
response_headers.update(new_key.response_dict)
return 200, response_headers, template.render(key=new_key)
@ -1237,7 +1246,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
new_key.website_redirect_location = request.headers.get(
"x-amz-website-redirect-location"
)
new_key.set_tagging(tagging)
self.backend.set_key_tags(new_key, tagging)
response_headers.update(new_key.response_dict)
return 200, response_headers, ""
@ -1365,55 +1374,45 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return None
def _tagging_from_headers(self, headers):
tags = {}
if headers.get("x-amz-tagging"):
parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True)
tags = []
for tag in parsed_header.items():
tags.append(FakeTag(tag[0], tag[1][0]))
tag_set = FakeTagSet(tags)
tagging = FakeTagging(tag_set)
return tagging
else:
return FakeTagging()
tags[tag[0]] = tag[1][0]
return tags
def _tagging_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml, force_list={"Tag": True})
tags = []
tags = {}
for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]:
tags.append(FakeTag(tag["Key"], tag["Value"]))
tags[tag["Key"]] = tag["Value"]
tag_set = FakeTagSet(tags)
tagging = FakeTagging(tag_set)
return tagging
return tags
def _bucket_tagging_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
tags = []
tags = {}
# Optional if no tags are being sent:
if parsed_xml["Tagging"].get("TagSet"):
# If there is only 1 tag, then it's not a list:
if not isinstance(parsed_xml["Tagging"]["TagSet"]["Tag"], list):
tags.append(
FakeTag(
parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"],
parsed_xml["Tagging"]["TagSet"]["Tag"]["Value"],
)
)
tags[parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"]] = parsed_xml[
"Tagging"
]["TagSet"]["Tag"]["Value"]
else:
for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]:
tags.append(FakeTag(tag["Key"], tag["Value"]))
if tag["Key"] in tags:
raise DuplicateTagKeys()
tags[tag["Key"]] = tag["Value"]
# Verify that "aws:" is not in the tags. If so, then this is a problem:
for tag in tags:
if tag.key.startswith("aws:"):
for key, _ in tags.items():
if key.startswith("aws:"):
raise NoSystemTags()
tag_set = FakeTagSet(tags)
tagging = FakeTagging(tag_set)
return tagging
return tags
def _cors_from_xml(self, xml):
parsed_xml = xmltodict.parse(xml)
@ -1733,10 +1732,10 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """<?xml version="1.0" encoding="UTF-8"?>
{% if rule.filter.prefix != None %}
<Prefix>{{ rule.filter.prefix }}</Prefix>
{% endif %}
{% if rule.filter.tag %}
{% if rule.filter.tag_key %}
<Tag>
<Key>{{ rule.filter.tag.key }}</Key>
<Value>{{ rule.filter.tag.value }}</Value>
<Key>{{ rule.filter.tag_key }}</Key>
<Value>{{ rule.filter.tag_value }}</Value>
</Tag>
{% endif %}
{% if rule.filter.and_filter %}
@ -1744,10 +1743,10 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """<?xml version="1.0" encoding="UTF-8"?>
{% if rule.filter.and_filter.prefix != None %}
<Prefix>{{ rule.filter.and_filter.prefix }}</Prefix>
{% endif %}
{% for tag in rule.filter.and_filter.tags %}
{% for key, value in rule.filter.and_filter.tags.items() %}
<Tag>
<Key>{{ tag.key }}</Key>
<Value>{{ tag.value }}</Value>
<Key>{{ key }}</Key>
<Value>{{ value }}</Value>
</Tag>
{% endfor %}
</And>
@ -1861,9 +1860,10 @@ S3_BUCKET_GET_VERSIONS = """<?xml version="1.0" encoding="UTF-8"?>
S3_DELETE_KEYS_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01">
{% for k in deleted %}
{% for k, v in deleted %}
<Deleted>
<Key>{{k}}</Key>
{% if v %}<VersionId>{{v}}</VersionId>{% endif %}
</Deleted>
{% endfor %}
{% for k in delete_errors %}
@ -1908,22 +1908,10 @@ S3_OBJECT_TAGGING_RESPONSE = """\
<?xml version="1.0" encoding="UTF-8"?>
<Tagging xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<TagSet>
{% for tag in obj.tagging.tag_set.tags %}
{% for tag in tags %}
<Tag>
<Key>{{ tag.key }}</Key>
<Value>{{ tag.value }}</Value>
</Tag>
{% endfor %}
</TagSet>
</Tagging>"""
S3_BUCKET_TAGGING_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Tagging>
<TagSet>
{% for tag in bucket.tagging.tag_set.tags %}
<Tag>
<Key>{{ tag.key }}</Key>
<Value>{{ tag.value }}</Value>
<Key>{{ tag.Key }}</Key>
<Value>{{ tag.Value }}</Value>
</Tag>
{% endfor %}
</TagSet>

View file

@ -15,5 +15,5 @@ url_paths = {
# path-based bucket + key
"{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_or_control_response,
# subdomain bucket + key with empty first part of path
"{0}//(?P<key_name>.*)$": S3ResponseInstance.key_or_control_response,
"{0}/(?P<key_name>/.*)$": S3ResponseInstance.key_or_control_response,
}

View file

@ -146,6 +146,12 @@ class _VersionedKeyStore(dict):
for key in self:
yield key, self.getlist(key)
def item_size(self):
size = 0
for val in self.values():
size += sys.getsizeof(val)
return size
items = iteritems = _iteritems
lists = iterlists = _iterlists
values = itervalues = _itervalues

View file

@ -121,8 +121,16 @@ class SecretsManagerBackend(BaseBackend):
"You can't perform this operation on the secret because it was marked for deletion."
)
secret = self.secrets[secret_id]
tags = secret["tags"]
description = secret["description"]
version_id = self._add_secret(
secret_id, secret_string=secret_string, secret_binary=secret_binary
secret_id,
secret_string=secret_string,
secret_binary=secret_binary,
description=description,
tags=tags,
)
response = json.dumps(
@ -136,7 +144,13 @@ class SecretsManagerBackend(BaseBackend):
return response
def create_secret(
self, name, secret_string=None, secret_binary=None, tags=[], **kwargs
self,
name,
secret_string=None,
secret_binary=None,
description=None,
tags=[],
**kwargs
):
# error if secret exists
@ -146,7 +160,11 @@ class SecretsManagerBackend(BaseBackend):
)
version_id = self._add_secret(
name, secret_string=secret_string, secret_binary=secret_binary, tags=tags
name,
secret_string=secret_string,
secret_binary=secret_binary,
description=description,
tags=tags,
)
response = json.dumps(
@ -164,6 +182,7 @@ class SecretsManagerBackend(BaseBackend):
secret_id,
secret_string=None,
secret_binary=None,
description=None,
tags=[],
version_id=None,
version_stages=None,
@ -216,13 +235,27 @@ class SecretsManagerBackend(BaseBackend):
secret["rotation_lambda_arn"] = ""
secret["auto_rotate_after_days"] = 0
secret["tags"] = tags
secret["description"] = description
return version_id
def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages):
if secret_id in self.secrets.keys():
secret = self.secrets[secret_id]
tags = secret["tags"]
description = secret["description"]
else:
tags = []
description = ""
version_id = self._add_secret(
secret_id, secret_string, secret_binary, version_stages=version_stages
secret_id,
secret_string,
secret_binary,
description=description,
tags=tags,
version_stages=version_stages,
)
response = json.dumps(
@ -246,7 +279,7 @@ class SecretsManagerBackend(BaseBackend):
{
"ARN": secret_arn(self.region, secret["secret_id"]),
"Name": secret["name"],
"Description": "",
"Description": secret.get("description", ""),
"KmsKeyId": "",
"RotationEnabled": secret["rotation_enabled"],
"RotationLambdaARN": secret["rotation_lambda_arn"],
@ -310,6 +343,7 @@ class SecretsManagerBackend(BaseBackend):
self._add_secret(
secret_id,
old_secret_version["secret_string"],
secret["description"],
secret["tags"],
version_id=new_version_id,
version_stages=["AWSCURRENT"],
@ -416,7 +450,7 @@ class SecretsManagerBackend(BaseBackend):
{
"ARN": secret_arn(self.region, secret["secret_id"]),
"DeletedDate": secret.get("deleted_date", None),
"Description": "",
"Description": secret.get("description", ""),
"KmsKeyId": "",
"LastAccessedDate": None,
"LastChangedDate": None,

View file

@ -21,11 +21,13 @@ class SecretsManagerResponse(BaseResponse):
name = self._get_param("Name")
secret_string = self._get_param("SecretString")
secret_binary = self._get_param("SecretBinary")
description = self._get_param("Description", if_none="")
tags = self._get_param("Tags", if_none=[])
return secretsmanager_backends[self.region].create_secret(
name=name,
secret_string=secret_string,
secret_binary=secret_binary,
description=description,
tags=tags,
)

View file

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import argparse
import io
import json
import re
import sys
@ -29,6 +30,7 @@ UNSIGNED_REQUESTS = {
"AWSCognitoIdentityService": ("cognito-identity", "us-east-1"),
"AWSCognitoIdentityProviderService": ("cognito-idp", "us-east-1"),
}
UNSIGNED_ACTIONS = {"AssumeRoleWithSAML": ("sts", "us-east-1")}
class DomainDispatcherApplication(object):
@ -77,9 +79,13 @@ class DomainDispatcherApplication(object):
else:
# Unsigned request
target = environ.get("HTTP_X_AMZ_TARGET")
action = self.get_action_from_body(environ)
if target:
service, _ = target.split(".", 1)
service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION)
elif action and action in UNSIGNED_ACTIONS:
# See if we can match the Action to a known service
service, region = UNSIGNED_ACTIONS.get(action)
else:
# S3 is the last resort when the target is also unknown
service, region = DEFAULT_SERVICE_REGION
@ -130,6 +136,26 @@ class DomainDispatcherApplication(object):
self.app_instances[backend] = app
return app
def get_action_from_body(self, environ):
body = None
try:
# AWS requests use querystrings as the body (Action=x&Data=y&...)
simple_form = environ["CONTENT_TYPE"].startswith(
"application/x-www-form-urlencoded"
)
request_body_size = int(environ["CONTENT_LENGTH"])
if simple_form and request_body_size:
body = environ["wsgi.input"].read(request_body_size).decode("utf-8")
body_dict = dict(x.split("=") for x in body.split("&"))
return body_dict["Action"]
except (KeyError, ValueError):
pass
finally:
if body:
# We've consumed the body = need to reset it
environ["wsgi.input"] = io.StringIO(body)
return None
def __call__(self, environ, start_response):
backend_app = self.get_application(environ)
return backend_app(environ, start_response)

View file

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from base64 import b64decode
import datetime
import xmltodict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import ACCOUNT_ID
@ -79,5 +81,24 @@ class STSBackend(BaseBackend):
def assume_role_with_web_identity(self, **kwargs):
return self.assume_role(**kwargs)
def assume_role_with_saml(self, **kwargs):
del kwargs["principal_arn"]
saml_assertion_encoded = kwargs.pop("saml_assertion")
saml_assertion_decoded = b64decode(saml_assertion_encoded)
saml_assertion = xmltodict.parse(saml_assertion_decoded.decode("utf-8"))
kwargs["duration"] = int(
saml_assertion["samlp:Response"]["Assertion"]["AttributeStatement"][
"Attribute"
][2]["AttributeValue"]
)
kwargs["role_session_name"] = saml_assertion["samlp:Response"]["Assertion"][
"AttributeStatement"
]["Attribute"][0]["AttributeValue"]
kwargs["external_id"] = None
kwargs["policy"] = None
role = AssumedRole(**kwargs)
self.assumed_roles.append(role)
return role
sts_backend = STSBackend()

View file

@ -71,6 +71,19 @@ class TokenResponse(BaseResponse):
template = self.response_template(ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE)
return template.render(role=role)
def assume_role_with_saml(self):
role_arn = self.querystring.get("RoleArn")[0]
principal_arn = self.querystring.get("PrincipalArn")[0]
saml_assertion = self.querystring.get("SAMLAssertion")[0]
role = sts_backend.assume_role_with_saml(
role_arn=role_arn,
principal_arn=principal_arn,
saml_assertion=saml_assertion,
)
template = self.response_template(ASSUME_ROLE_WITH_SAML_RESPONSE)
return template.render(role=role)
def get_caller_identity(self):
template = self.response_template(GET_CALLER_IDENTITY_RESPONSE)
@ -168,6 +181,30 @@ ASSUME_ROLE_WITH_WEB_IDENTITY_RESPONSE = """<AssumeRoleWithWebIdentityResponse x
</AssumeRoleWithWebIdentityResponse>"""
ASSUME_ROLE_WITH_SAML_RESPONSE = """<AssumeRoleWithSAMLResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithSAMLResult>
<Audience>https://signin.aws.amazon.com/saml</Audience>
<AssumedRoleUser>
<AssumedRoleId>{{ role.user_id }}</AssumedRoleId>
<Arn>{{ role.arn }}</Arn>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>{{ role.access_key_id }}</AccessKeyId>
<SecretAccessKey>{{ role.secret_access_key }}</SecretAccessKey>
<SessionToken>{{ role.session_token }}</SessionToken>
<Expiration>{{ role.expiration_ISO8601 }}</Expiration>
</Credentials>
<Subject>{{ role.user_id }}</Subject>
<NameQualifier>B64EncodedStringOfHashOfIssuerAccountIdAndUserId=</NameQualifier>
<SubjectType>persistent</SubjectType>
<Issuer>http://localhost:3000/</Issuer>
</AssumeRoleWithSAMLResult>
<ResponseMetadata>
<RequestId>c6104cbe-af31-11e0-8154-cbc7ccf896c7</RequestId>
</ResponseMetadata>
</AssumeRoleWithSAMLResponse>"""
GET_CALLER_IDENTITY_RESPONSE = """<GetCallerIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<GetCallerIdentityResult>
<Arn>{{ arn }}</Arn>

View file

@ -5,15 +5,23 @@ class TaggingService:
self.valueName = valueName
self.tags = {}
def get_tag_dict_for_resource(self, arn):
result = {}
if self.has_tags(arn):
for k, v in self.tags[arn].items():
result[k] = v
return result
def list_tags_for_resource(self, arn):
result = []
if arn in self.tags:
if self.has_tags(arn):
for k, v in self.tags[arn].items():
result.append({self.keyName: k, self.valueName: v})
return {self.tagName: result}
def delete_all_tags_for_resource(self, arn):
del self.tags[arn]
if self.has_tags(arn):
del self.tags[arn]
def has_tags(self, arn):
return arn in self.tags
@ -27,6 +35,12 @@ class TaggingService:
else:
self.tags[arn][t[self.keyName]] = None
def copy_tags(self, from_arn, to_arn):
if self.has_tags(from_arn):
self.tag_resource(
to_arn, self.list_tags_for_resource(from_arn)[self.tagName]
)
def untag_resource_using_names(self, arn, tag_names):
for name in tag_names:
if name in self.tags.get(arn, {}):