Merge branch 'master' into feature/dynamodb_transact_write_items
This commit is contained in:
commit
56aa454397
101 changed files with 8946 additions and 663 deletions
|
|
@ -21,6 +21,7 @@ from .datasync import mock_datasync # noqa
|
|||
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
|
||||
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
|
||||
from .dynamodbstreams import mock_dynamodbstreams # noqa
|
||||
from .elasticbeanstalk import mock_elasticbeanstalk # noqa
|
||||
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
|
||||
from .ec2_instance_connect import mock_ec2_instance_connect # noqa
|
||||
from .ecr import mock_ecr, mock_ecr_deprecated # noqa
|
||||
|
|
|
|||
|
|
@ -119,3 +119,57 @@ class ApiKeyAlreadyExists(RESTError):
|
|||
super(ApiKeyAlreadyExists, self).__init__(
|
||||
"ConflictException", "API Key already exists"
|
||||
)
|
||||
|
||||
|
||||
class InvalidDomainName(BadRequestException):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(InvalidDomainName, self).__init__(
|
||||
"BadRequestException", "No Domain Name specified"
|
||||
)
|
||||
|
||||
|
||||
class DomainNameNotFound(RESTError):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(DomainNameNotFound, self).__init__(
|
||||
"NotFoundException", "Invalid Domain Name specified"
|
||||
)
|
||||
|
||||
|
||||
class InvalidRestApiId(BadRequestException):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(InvalidRestApiId, self).__init__(
|
||||
"BadRequestException", "No Rest API Id specified"
|
||||
)
|
||||
|
||||
|
||||
class InvalidModelName(BadRequestException):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(InvalidModelName, self).__init__(
|
||||
"BadRequestException", "No Model Name specified"
|
||||
)
|
||||
|
||||
|
||||
class RestAPINotFound(RESTError):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(RestAPINotFound, self).__init__(
|
||||
"NotFoundException", "Invalid Rest API Id specified"
|
||||
)
|
||||
|
||||
|
||||
class ModelNotFound(RESTError):
|
||||
code = 404
|
||||
|
||||
def __init__(self):
|
||||
super(ModelNotFound, self).__init__(
|
||||
"NotFoundException", "Invalid Model Name specified"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,12 @@ from .exceptions import (
|
|||
NoIntegrationDefined,
|
||||
NoMethodDefined,
|
||||
ApiKeyAlreadyExists,
|
||||
DomainNameNotFound,
|
||||
InvalidDomainName,
|
||||
InvalidRestApiId,
|
||||
InvalidModelName,
|
||||
RestAPINotFound,
|
||||
ModelNotFound,
|
||||
)
|
||||
|
||||
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
|
||||
|
|
@ -455,6 +461,7 @@ class RestAPI(BaseModel):
|
|||
self.description = description
|
||||
self.create_date = int(time.time())
|
||||
self.api_key_source = kwargs.get("api_key_source") or "HEADER"
|
||||
self.policy = kwargs.get("policy") or None
|
||||
self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
|
||||
"types": ["EDGE"]
|
||||
}
|
||||
|
|
@ -463,8 +470,8 @@ class RestAPI(BaseModel):
|
|||
self.deployments = {}
|
||||
self.authorizers = {}
|
||||
self.stages = {}
|
||||
|
||||
self.resources = {}
|
||||
self.models = {}
|
||||
self.add_child("/") # Add default child
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -479,6 +486,7 @@ class RestAPI(BaseModel):
|
|||
"apiKeySource": self.api_key_source,
|
||||
"endpointConfiguration": self.endpoint_configuration,
|
||||
"tags": self.tags,
|
||||
"policy": self.policy,
|
||||
}
|
||||
|
||||
def add_child(self, path, parent_id=None):
|
||||
|
|
@ -493,6 +501,29 @@ class RestAPI(BaseModel):
|
|||
self.resources[child_id] = child
|
||||
return child
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
name,
|
||||
description=None,
|
||||
schema=None,
|
||||
content_type=None,
|
||||
cli_input_json=None,
|
||||
generate_cli_skeleton=None,
|
||||
):
|
||||
model_id = create_id()
|
||||
new_model = Model(
|
||||
id=model_id,
|
||||
name=name,
|
||||
description=description,
|
||||
schema=schema,
|
||||
content_type=content_type,
|
||||
cli_input_json=cli_input_json,
|
||||
generate_cli_skeleton=generate_cli_skeleton,
|
||||
)
|
||||
|
||||
self.models[name] = new_model
|
||||
return new_model
|
||||
|
||||
def get_resource_for_path(self, path_after_stage_name):
|
||||
for resource in self.resources.values():
|
||||
if resource.get_path() == path_after_stage_name:
|
||||
|
|
@ -609,6 +640,58 @@ class RestAPI(BaseModel):
|
|||
return self.deployments.pop(deployment_id)
|
||||
|
||||
|
||||
class DomainName(BaseModel, dict):
|
||||
def __init__(self, domain_name, **kwargs):
|
||||
super(DomainName, self).__init__()
|
||||
self["domainName"] = domain_name
|
||||
self["regionalDomainName"] = domain_name
|
||||
self["distributionDomainName"] = domain_name
|
||||
self["domainNameStatus"] = "AVAILABLE"
|
||||
self["domainNameStatusMessage"] = "Domain Name Available"
|
||||
self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2"
|
||||
self["distributionHostedZoneId"] = "Z2FDTNDATAQYW2"
|
||||
self["certificateUploadDate"] = int(time.time())
|
||||
if kwargs.get("certificate_name"):
|
||||
self["certificateName"] = kwargs.get("certificate_name")
|
||||
if kwargs.get("certificate_arn"):
|
||||
self["certificateArn"] = kwargs.get("certificate_arn")
|
||||
if kwargs.get("certificate_body"):
|
||||
self["certificateBody"] = kwargs.get("certificate_body")
|
||||
if kwargs.get("tags"):
|
||||
self["tags"] = kwargs.get("tags")
|
||||
if kwargs.get("security_policy"):
|
||||
self["securityPolicy"] = kwargs.get("security_policy")
|
||||
if kwargs.get("certificate_chain"):
|
||||
self["certificateChain"] = kwargs.get("certificate_chain")
|
||||
if kwargs.get("regional_certificate_name"):
|
||||
self["regionalCertificateName"] = kwargs.get("regional_certificate_name")
|
||||
if kwargs.get("certificate_private_key"):
|
||||
self["certificatePrivateKey"] = kwargs.get("certificate_private_key")
|
||||
if kwargs.get("regional_certificate_arn"):
|
||||
self["regionalCertificateArn"] = kwargs.get("regional_certificate_arn")
|
||||
if kwargs.get("endpoint_configuration"):
|
||||
self["endpointConfiguration"] = kwargs.get("endpoint_configuration")
|
||||
if kwargs.get("generate_cli_skeleton"):
|
||||
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
|
||||
|
||||
|
||||
class Model(BaseModel, dict):
|
||||
def __init__(self, id, name, **kwargs):
|
||||
super(Model, self).__init__()
|
||||
self["id"] = id
|
||||
self["name"] = name
|
||||
if kwargs.get("description"):
|
||||
self["description"] = kwargs.get("description")
|
||||
if kwargs.get("schema"):
|
||||
self["schema"] = kwargs.get("schema")
|
||||
if kwargs.get("content_type"):
|
||||
self["contentType"] = kwargs.get("content_type")
|
||||
if kwargs.get("cli_input_json"):
|
||||
self["cliInputJson"] = kwargs.get("cli_input_json")
|
||||
if kwargs.get("generate_cli_skeleton"):
|
||||
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
|
||||
|
||||
|
||||
class APIGatewayBackend(BaseBackend):
|
||||
def __init__(self, region_name):
|
||||
super(APIGatewayBackend, self).__init__()
|
||||
|
|
@ -616,6 +699,8 @@ class APIGatewayBackend(BaseBackend):
|
|||
self.keys = {}
|
||||
self.usage_plans = {}
|
||||
self.usage_plan_keys = {}
|
||||
self.domain_names = {}
|
||||
self.models = {}
|
||||
self.region_name = region_name
|
||||
|
||||
def reset(self):
|
||||
|
|
@ -630,6 +715,7 @@ class APIGatewayBackend(BaseBackend):
|
|||
api_key_source=None,
|
||||
endpoint_configuration=None,
|
||||
tags=None,
|
||||
policy=None,
|
||||
):
|
||||
api_id = create_id()
|
||||
rest_api = RestAPI(
|
||||
|
|
@ -640,12 +726,15 @@ class APIGatewayBackend(BaseBackend):
|
|||
api_key_source=api_key_source,
|
||||
endpoint_configuration=endpoint_configuration,
|
||||
tags=tags,
|
||||
policy=policy,
|
||||
)
|
||||
self.apis[api_id] = rest_api
|
||||
return rest_api
|
||||
|
||||
def get_rest_api(self, function_id):
|
||||
rest_api = self.apis[function_id]
|
||||
rest_api = self.apis.get(function_id)
|
||||
if rest_api is None:
|
||||
raise RestAPINotFound()
|
||||
return rest_api
|
||||
|
||||
def list_apis(self):
|
||||
|
|
@ -1001,6 +1090,98 @@ class APIGatewayBackend(BaseBackend):
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
def create_domain_name(
|
||||
self,
|
||||
domain_name,
|
||||
certificate_name=None,
|
||||
tags=None,
|
||||
certificate_arn=None,
|
||||
certificate_body=None,
|
||||
certificate_private_key=None,
|
||||
certificate_chain=None,
|
||||
regional_certificate_name=None,
|
||||
regional_certificate_arn=None,
|
||||
endpoint_configuration=None,
|
||||
security_policy=None,
|
||||
generate_cli_skeleton=None,
|
||||
):
|
||||
|
||||
if not domain_name:
|
||||
raise InvalidDomainName()
|
||||
|
||||
new_domain_name = DomainName(
|
||||
domain_name=domain_name,
|
||||
certificate_name=certificate_name,
|
||||
certificate_private_key=certificate_private_key,
|
||||
certificate_arn=certificate_arn,
|
||||
certificate_body=certificate_body,
|
||||
certificate_chain=certificate_chain,
|
||||
regional_certificate_name=regional_certificate_name,
|
||||
regional_certificate_arn=regional_certificate_arn,
|
||||
endpoint_configuration=endpoint_configuration,
|
||||
tags=tags,
|
||||
security_policy=security_policy,
|
||||
generate_cli_skeleton=generate_cli_skeleton,
|
||||
)
|
||||
|
||||
self.domain_names[domain_name] = new_domain_name
|
||||
return new_domain_name
|
||||
|
||||
def get_domain_names(self):
|
||||
return list(self.domain_names.values())
|
||||
|
||||
def get_domain_name(self, domain_name):
|
||||
domain_info = self.domain_names.get(domain_name)
|
||||
if domain_info is None:
|
||||
raise DomainNameNotFound
|
||||
else:
|
||||
return self.domain_names[domain_name]
|
||||
|
||||
def create_model(
|
||||
self,
|
||||
rest_api_id,
|
||||
name,
|
||||
content_type,
|
||||
description=None,
|
||||
schema=None,
|
||||
cli_input_json=None,
|
||||
generate_cli_skeleton=None,
|
||||
):
|
||||
|
||||
if not rest_api_id:
|
||||
raise InvalidRestApiId
|
||||
if not name:
|
||||
raise InvalidModelName
|
||||
|
||||
api = self.get_rest_api(rest_api_id)
|
||||
new_model = api.add_model(
|
||||
name=name,
|
||||
description=description,
|
||||
schema=schema,
|
||||
content_type=content_type,
|
||||
cli_input_json=cli_input_json,
|
||||
generate_cli_skeleton=generate_cli_skeleton,
|
||||
)
|
||||
|
||||
return new_model
|
||||
|
||||
def get_models(self, rest_api_id):
|
||||
if not rest_api_id:
|
||||
raise InvalidRestApiId
|
||||
api = self.get_rest_api(rest_api_id)
|
||||
models = api.models.values()
|
||||
return list(models)
|
||||
|
||||
def get_model(self, rest_api_id, model_name):
|
||||
if not rest_api_id:
|
||||
raise InvalidRestApiId
|
||||
api = self.get_rest_api(rest_api_id)
|
||||
model = api.models.get(model_name)
|
||||
if model is None:
|
||||
raise ModelNotFound
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
apigateway_backends = {}
|
||||
for region_name in Session().get_available_regions("apigateway"):
|
||||
|
|
|
|||
|
|
@ -11,6 +11,12 @@ from .exceptions import (
|
|||
AuthorizerNotFoundException,
|
||||
StageNotFoundException,
|
||||
ApiKeyAlreadyExists,
|
||||
DomainNameNotFound,
|
||||
InvalidDomainName,
|
||||
InvalidRestApiId,
|
||||
InvalidModelName,
|
||||
RestAPINotFound,
|
||||
ModelNotFound,
|
||||
)
|
||||
|
||||
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
|
||||
|
|
@ -53,6 +59,7 @@ class APIGatewayResponse(BaseResponse):
|
|||
api_key_source = self._get_param("apiKeySource")
|
||||
endpoint_configuration = self._get_param("endpointConfiguration")
|
||||
tags = self._get_param("tags")
|
||||
policy = self._get_param("policy")
|
||||
|
||||
# Param validation
|
||||
if api_key_source and api_key_source not in API_KEY_SOURCES:
|
||||
|
|
@ -88,6 +95,7 @@ class APIGatewayResponse(BaseResponse):
|
|||
api_key_source=api_key_source,
|
||||
endpoint_configuration=endpoint_configuration,
|
||||
tags=tags,
|
||||
policy=policy,
|
||||
)
|
||||
return 200, {}, json.dumps(rest_api.to_dict())
|
||||
|
||||
|
|
@ -527,3 +535,130 @@ class APIGatewayResponse(BaseResponse):
|
|||
usage_plan_id, key_id
|
||||
)
|
||||
return 200, {}, json.dumps(usage_plan_response)
|
||||
|
||||
def domain_names(self, request, full_url, headers):
|
||||
self.setup_class(request, full_url, headers)
|
||||
|
||||
try:
|
||||
if self.method == "GET":
|
||||
domain_names = self.backend.get_domain_names()
|
||||
return 200, {}, json.dumps({"item": domain_names})
|
||||
|
||||
elif self.method == "POST":
|
||||
domain_name = self._get_param("domainName")
|
||||
certificate_name = self._get_param("certificateName")
|
||||
tags = self._get_param("tags")
|
||||
certificate_arn = self._get_param("certificateArn")
|
||||
certificate_body = self._get_param("certificateBody")
|
||||
certificate_private_key = self._get_param("certificatePrivateKey")
|
||||
certificate_chain = self._get_param("certificateChain")
|
||||
regional_certificate_name = self._get_param("regionalCertificateName")
|
||||
regional_certificate_arn = self._get_param("regionalCertificateArn")
|
||||
endpoint_configuration = self._get_param("endpointConfiguration")
|
||||
security_policy = self._get_param("securityPolicy")
|
||||
generate_cli_skeleton = self._get_param("generateCliSkeleton")
|
||||
domain_name_resp = self.backend.create_domain_name(
|
||||
domain_name,
|
||||
certificate_name,
|
||||
tags,
|
||||
certificate_arn,
|
||||
certificate_body,
|
||||
certificate_private_key,
|
||||
certificate_chain,
|
||||
regional_certificate_name,
|
||||
regional_certificate_arn,
|
||||
endpoint_configuration,
|
||||
security_policy,
|
||||
generate_cli_skeleton,
|
||||
)
|
||||
return 200, {}, json.dumps(domain_name_resp)
|
||||
|
||||
except InvalidDomainName as error:
|
||||
return (
|
||||
error.code,
|
||||
{},
|
||||
'{{"message":"{0}","code":"{1}"}}'.format(
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
|
||||
def domain_name_induvidual(self, request, full_url, headers):
|
||||
self.setup_class(request, full_url, headers)
|
||||
|
||||
url_path_parts = self.path.split("/")
|
||||
domain_name = url_path_parts[2]
|
||||
domain_names = {}
|
||||
try:
|
||||
if self.method == "GET":
|
||||
if domain_name is not None:
|
||||
domain_names = self.backend.get_domain_name(domain_name)
|
||||
return 200, {}, json.dumps(domain_names)
|
||||
except DomainNameNotFound as error:
|
||||
return (
|
||||
error.code,
|
||||
{},
|
||||
'{{"message":"{0}","code":"{1}"}}'.format(
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
|
||||
def models(self, request, full_url, headers):
|
||||
self.setup_class(request, full_url, headers)
|
||||
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
|
||||
|
||||
try:
|
||||
if self.method == "GET":
|
||||
models = self.backend.get_models(rest_api_id)
|
||||
return 200, {}, json.dumps({"item": models})
|
||||
|
||||
elif self.method == "POST":
|
||||
name = self._get_param("name")
|
||||
description = self._get_param("description")
|
||||
schema = self._get_param("schema")
|
||||
content_type = self._get_param("contentType")
|
||||
cli_input_json = self._get_param("cliInputJson")
|
||||
generate_cli_skeleton = self._get_param("generateCliSkeleton")
|
||||
model = self.backend.create_model(
|
||||
rest_api_id,
|
||||
name,
|
||||
content_type,
|
||||
description,
|
||||
schema,
|
||||
cli_input_json,
|
||||
generate_cli_skeleton,
|
||||
)
|
||||
|
||||
return 200, {}, json.dumps(model)
|
||||
|
||||
except (InvalidRestApiId, InvalidModelName, RestAPINotFound) as error:
|
||||
return (
|
||||
error.code,
|
||||
{},
|
||||
'{{"message":"{0}","code":"{1}"}}'.format(
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
|
||||
def model_induvidual(self, request, full_url, headers):
|
||||
self.setup_class(request, full_url, headers)
|
||||
url_path_parts = self.path.split("/")
|
||||
rest_api_id = url_path_parts[2]
|
||||
model_name = url_path_parts[4]
|
||||
model_info = {}
|
||||
try:
|
||||
if self.method == "GET":
|
||||
model_info = self.backend.get_model(rest_api_id, model_name)
|
||||
return 200, {}, json.dumps(model_info)
|
||||
except (
|
||||
ModelNotFound,
|
||||
RestAPINotFound,
|
||||
InvalidRestApiId,
|
||||
InvalidModelName,
|
||||
) as error:
|
||||
return (
|
||||
error.code,
|
||||
{},
|
||||
'{{"message":"{0}","code":"{1}"}}'.format(
|
||||
error.message, error.error_type
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -21,6 +21,10 @@ url_paths = {
|
|||
"{0}/apikeys$": APIGatewayResponse().apikeys,
|
||||
"{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
|
||||
"{0}/usageplans$": APIGatewayResponse().usage_plans,
|
||||
"{0}/domainnames$": APIGatewayResponse().domain_names,
|
||||
"{0}/restapis/(?P<function_id>[^/]+)/models$": APIGatewayResponse().models,
|
||||
"{0}/restapis/(?P<function_id>[^/]+)/models/(?P<model_name>[^/]+)/?$": APIGatewayResponse().model_induvidual,
|
||||
"{0}/domainnames/(?P<domain_name>[^/]+)/?$": APIGatewayResponse().domain_name_induvidual,
|
||||
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
|
||||
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
|
||||
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,
|
||||
|
|
|
|||
|
|
@ -267,6 +267,9 @@ class FakeAutoScalingGroup(BaseModel):
|
|||
self.tags = tags if tags else []
|
||||
self.set_desired_capacity(desired_capacity)
|
||||
|
||||
def active_instances(self):
|
||||
return [x for x in self.instance_states if x.lifecycle_state == "InService"]
|
||||
|
||||
def _set_azs_and_vpcs(self, availability_zones, vpc_zone_identifier, update=False):
|
||||
# for updates, if only AZs are provided, they must not clash with
|
||||
# the AZs of existing VPCs
|
||||
|
|
@ -413,9 +416,11 @@ class FakeAutoScalingGroup(BaseModel):
|
|||
else:
|
||||
self.desired_capacity = new_capacity
|
||||
|
||||
curr_instance_count = len(self.instance_states)
|
||||
curr_instance_count = len(self.active_instances())
|
||||
|
||||
if self.desired_capacity == curr_instance_count:
|
||||
self.autoscaling_backend.update_attached_elbs(self.name)
|
||||
self.autoscaling_backend.update_attached_target_groups(self.name)
|
||||
return
|
||||
|
||||
if self.desired_capacity > curr_instance_count:
|
||||
|
|
@ -442,6 +447,8 @@ class FakeAutoScalingGroup(BaseModel):
|
|||
self.instance_states = list(
|
||||
set(self.instance_states) - set(instances_to_remove)
|
||||
)
|
||||
self.autoscaling_backend.update_attached_elbs(self.name)
|
||||
self.autoscaling_backend.update_attached_target_groups(self.name)
|
||||
|
||||
def get_propagated_tags(self):
|
||||
propagated_tags = {}
|
||||
|
|
@ -655,10 +662,16 @@ class AutoScalingBackend(BaseBackend):
|
|||
self.set_desired_capacity(group_name, 0)
|
||||
self.autoscaling_groups.pop(group_name, None)
|
||||
|
||||
def describe_auto_scaling_instances(self):
|
||||
def describe_auto_scaling_instances(self, instance_ids):
|
||||
instance_states = []
|
||||
for group in self.autoscaling_groups.values():
|
||||
instance_states.extend(group.instance_states)
|
||||
instance_states.extend(
|
||||
[
|
||||
x
|
||||
for x in group.instance_states
|
||||
if not instance_ids or x.instance.id in instance_ids
|
||||
]
|
||||
)
|
||||
return instance_states
|
||||
|
||||
def attach_instances(self, group_name, instance_ids):
|
||||
|
|
@ -697,7 +710,7 @@ class AutoScalingBackend(BaseBackend):
|
|||
|
||||
def detach_instances(self, group_name, instance_ids, should_decrement):
|
||||
group = self.autoscaling_groups[group_name]
|
||||
original_size = len(group.instance_states)
|
||||
original_size = group.desired_capacity
|
||||
|
||||
detached_instances = [
|
||||
x for x in group.instance_states if x.instance.id in instance_ids
|
||||
|
|
@ -714,13 +727,8 @@ class AutoScalingBackend(BaseBackend):
|
|||
|
||||
if should_decrement:
|
||||
group.desired_capacity = original_size - len(instance_ids)
|
||||
else:
|
||||
count_needed = len(instance_ids)
|
||||
group.replace_autoscaling_group_instances(
|
||||
count_needed, group.get_propagated_tags()
|
||||
)
|
||||
|
||||
self.update_attached_elbs(group_name)
|
||||
group.set_desired_capacity(group.desired_capacity)
|
||||
return detached_instances
|
||||
|
||||
def set_desired_capacity(self, group_name, desired_capacity):
|
||||
|
|
@ -785,7 +793,9 @@ class AutoScalingBackend(BaseBackend):
|
|||
|
||||
def update_attached_elbs(self, group_name):
|
||||
group = self.autoscaling_groups[group_name]
|
||||
group_instance_ids = set(state.instance.id for state in group.instance_states)
|
||||
group_instance_ids = set(
|
||||
state.instance.id for state in group.active_instances()
|
||||
)
|
||||
|
||||
# skip this if group.load_balancers is empty
|
||||
# otherwise elb_backend.describe_load_balancers returns all available load balancers
|
||||
|
|
@ -902,15 +912,15 @@ class AutoScalingBackend(BaseBackend):
|
|||
autoscaling_group_name,
|
||||
autoscaling_group,
|
||||
) in self.autoscaling_groups.items():
|
||||
original_instance_count = len(autoscaling_group.instance_states)
|
||||
original_active_instance_count = len(autoscaling_group.active_instances())
|
||||
autoscaling_group.instance_states = list(
|
||||
filter(
|
||||
lambda i_state: i_state.instance.id not in instance_ids,
|
||||
autoscaling_group.instance_states,
|
||||
)
|
||||
)
|
||||
difference = original_instance_count - len(
|
||||
autoscaling_group.instance_states
|
||||
difference = original_active_instance_count - len(
|
||||
autoscaling_group.active_instances()
|
||||
)
|
||||
if difference > 0:
|
||||
autoscaling_group.replace_autoscaling_group_instances(
|
||||
|
|
@ -918,6 +928,45 @@ class AutoScalingBackend(BaseBackend):
|
|||
)
|
||||
self.update_attached_elbs(autoscaling_group_name)
|
||||
|
||||
def enter_standby_instances(self, group_name, instance_ids, should_decrement):
|
||||
group = self.autoscaling_groups[group_name]
|
||||
original_size = group.desired_capacity
|
||||
standby_instances = []
|
||||
for instance_state in group.instance_states:
|
||||
if instance_state.instance.id in instance_ids:
|
||||
instance_state.lifecycle_state = "Standby"
|
||||
standby_instances.append(instance_state)
|
||||
if should_decrement:
|
||||
group.desired_capacity = group.desired_capacity - len(instance_ids)
|
||||
else:
|
||||
group.set_desired_capacity(group.desired_capacity)
|
||||
return standby_instances, original_size, group.desired_capacity
|
||||
|
||||
def exit_standby_instances(self, group_name, instance_ids):
|
||||
group = self.autoscaling_groups[group_name]
|
||||
original_size = group.desired_capacity
|
||||
standby_instances = []
|
||||
for instance_state in group.instance_states:
|
||||
if instance_state.instance.id in instance_ids:
|
||||
instance_state.lifecycle_state = "InService"
|
||||
standby_instances.append(instance_state)
|
||||
group.desired_capacity = group.desired_capacity + len(instance_ids)
|
||||
return standby_instances, original_size, group.desired_capacity
|
||||
|
||||
def terminate_instance(self, instance_id, should_decrement):
|
||||
instance = self.ec2_backend.get_instance(instance_id)
|
||||
instance_state = next(
|
||||
instance_state
|
||||
for group in self.autoscaling_groups.values()
|
||||
for instance_state in group.instance_states
|
||||
if instance_state.instance.id == instance.id
|
||||
)
|
||||
group = instance.autoscaling_group
|
||||
original_size = group.desired_capacity
|
||||
self.detach_instances(group.name, [instance.id], should_decrement)
|
||||
self.ec2_backend.terminate_instances([instance.id])
|
||||
return instance_state, original_size, group.desired_capacity
|
||||
|
||||
|
||||
autoscaling_backends = {}
|
||||
for region, ec2_backend in ec2_backends.items():
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
from __future__ import unicode_literals
|
||||
import datetime
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.utils import amz_crc32, amzn_request_id
|
||||
from moto.core.utils import (
|
||||
amz_crc32,
|
||||
amzn_request_id,
|
||||
iso_8601_datetime_with_milliseconds,
|
||||
)
|
||||
from .models import autoscaling_backends
|
||||
|
||||
|
||||
|
|
@ -226,7 +231,9 @@ class AutoScalingResponse(BaseResponse):
|
|||
return template.render()
|
||||
|
||||
def describe_auto_scaling_instances(self):
|
||||
instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
|
||||
instance_states = self.autoscaling_backend.describe_auto_scaling_instances(
|
||||
instance_ids=self._get_multi_param("InstanceIds.member")
|
||||
)
|
||||
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
|
||||
return template.render(instance_states=instance_states)
|
||||
|
||||
|
|
@ -289,6 +296,50 @@ class AutoScalingResponse(BaseResponse):
|
|||
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
|
||||
return template.render()
|
||||
|
||||
@amz_crc32
|
||||
@amzn_request_id
|
||||
def enter_standby(self):
|
||||
group_name = self._get_param("AutoScalingGroupName")
|
||||
instance_ids = self._get_multi_param("InstanceIds.member")
|
||||
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
|
||||
if should_decrement_string == "true":
|
||||
should_decrement = True
|
||||
else:
|
||||
should_decrement = False
|
||||
(
|
||||
standby_instances,
|
||||
original_size,
|
||||
desired_capacity,
|
||||
) = self.autoscaling_backend.enter_standby_instances(
|
||||
group_name, instance_ids, should_decrement
|
||||
)
|
||||
template = self.response_template(ENTER_STANDBY_TEMPLATE)
|
||||
return template.render(
|
||||
standby_instances=standby_instances,
|
||||
should_decrement=should_decrement,
|
||||
original_size=original_size,
|
||||
desired_capacity=desired_capacity,
|
||||
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
|
||||
)
|
||||
|
||||
@amz_crc32
|
||||
@amzn_request_id
|
||||
def exit_standby(self):
|
||||
group_name = self._get_param("AutoScalingGroupName")
|
||||
instance_ids = self._get_multi_param("InstanceIds.member")
|
||||
(
|
||||
standby_instances,
|
||||
original_size,
|
||||
desired_capacity,
|
||||
) = self.autoscaling_backend.exit_standby_instances(group_name, instance_ids)
|
||||
template = self.response_template(EXIT_STANDBY_TEMPLATE)
|
||||
return template.render(
|
||||
standby_instances=standby_instances,
|
||||
original_size=original_size,
|
||||
desired_capacity=desired_capacity,
|
||||
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
|
||||
)
|
||||
|
||||
def suspend_processes(self):
|
||||
autoscaling_group_name = self._get_param("AutoScalingGroupName")
|
||||
scaling_processes = self._get_multi_param("ScalingProcesses.member")
|
||||
|
|
@ -308,6 +359,29 @@ class AutoScalingResponse(BaseResponse):
|
|||
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
|
||||
return template.render()
|
||||
|
||||
@amz_crc32
|
||||
@amzn_request_id
|
||||
def terminate_instance_in_auto_scaling_group(self):
|
||||
instance_id = self._get_param("InstanceId")
|
||||
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
|
||||
if should_decrement_string == "true":
|
||||
should_decrement = True
|
||||
else:
|
||||
should_decrement = False
|
||||
(
|
||||
instance,
|
||||
original_size,
|
||||
desired_capacity,
|
||||
) = self.autoscaling_backend.terminate_instance(instance_id, should_decrement)
|
||||
template = self.response_template(TERMINATE_INSTANCES_TEMPLATE)
|
||||
return template.render(
|
||||
instance=instance,
|
||||
should_decrement=should_decrement,
|
||||
original_size=original_size,
|
||||
desired_capacity=desired_capacity,
|
||||
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
|
||||
)
|
||||
|
||||
|
||||
CREATE_LAUNCH_CONFIGURATION_TEMPLATE = """<CreateLaunchConfigurationResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
|
||||
<ResponseMetadata>
|
||||
|
|
@ -705,3 +779,73 @@ SET_INSTANCE_PROTECTION_TEMPLATE = """<SetInstanceProtectionResponse xmlns="http
|
|||
<RequestId></RequestId>
|
||||
</ResponseMetadata>
|
||||
</SetInstanceProtectionResponse>"""
|
||||
|
||||
ENTER_STANDBY_TEMPLATE = """<EnterStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
|
||||
<EnterStandbyResult>
|
||||
<Activities>
|
||||
{% for instance in standby_instances %}
|
||||
<member>
|
||||
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
|
||||
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
|
||||
{% if should_decrement %}
|
||||
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
|
||||
{% else %}
|
||||
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request.</Cause>
|
||||
{% endif %}
|
||||
<Description>Moving EC2 instance to Standby: {{ instance.instance.id }}</Description>
|
||||
<Progress>50</Progress>
|
||||
<StartTime>{{ timestamp }}</StartTime>
|
||||
<Details>{"Subnet ID":"??","Availability Zone":"{{ instance.instance.placement }}"}</Details>
|
||||
<StatusCode>InProgress</StatusCode>
|
||||
</member>
|
||||
{% endfor %}
|
||||
</Activities>
|
||||
</EnterStandbyResult>
|
||||
<ResponseMetadata>
|
||||
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
|
||||
</ResponseMetadata>
|
||||
</EnterStandbyResponse>"""
|
||||
|
||||
EXIT_STANDBY_TEMPLATE = """<ExitStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
|
||||
<ExitStandbyResult>
|
||||
<Activities>
|
||||
{% for instance in standby_instances %}
|
||||
<member>
|
||||
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
|
||||
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
|
||||
<Description>Moving EC2 instance out of Standby: {{ instance.instance.id }}</Description>
|
||||
<Progress>30</Progress>
|
||||
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved out of standby in response to a user request, increasing the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
|
||||
<StartTime>{{ timestamp }}</StartTime>
|
||||
<Details>{"Subnet ID":"??","Availability Zone":"{{ instance.instance.placement }}"}</Details>
|
||||
<StatusCode>PreInService</StatusCode>
|
||||
</member>
|
||||
{% endfor %}
|
||||
</Activities>
|
||||
</ExitStandbyResult>
|
||||
<ResponseMetadata>
|
||||
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
|
||||
</ResponseMetadata>
|
||||
</ExitStandbyResponse>"""
|
||||
|
||||
TERMINATE_INSTANCES_TEMPLATE = """<TerminateInstanceInAutoScalingGroupResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
|
||||
<TerminateInstanceInAutoScalingGroupResult>
|
||||
<Activity>
|
||||
<ActivityId>35b5c464-0b63-2fc7-1611-467d4a7f2497EXAMPLE</ActivityId>
|
||||
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
|
||||
{% if should_decrement %}
|
||||
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
|
||||
{% else %}
|
||||
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request.</Cause>
|
||||
{% endif %}
|
||||
<Description>Terminating EC2 instance: {{ instance.instance.id }}</Description>
|
||||
<Progress>0</Progress>
|
||||
<StartTime>{{ timestamp }}</StartTime>
|
||||
<Details>{"Subnet ID":"??","Availability Zone":"{{ instance.instance.placement }}"}</Details>
|
||||
<StatusCode>InProgress</StatusCode>
|
||||
</Activity>
|
||||
</TerminateInstanceInAutoScalingGroupResult>
|
||||
<ResponseMetadata>
|
||||
<RequestId>a1ba8fb9-31d6-4d9a-ace1-a7f76749df11EXAMPLE</RequestId>
|
||||
</ResponseMetadata>
|
||||
</TerminateInstanceInAutoScalingGroupResponse>"""
|
||||
|
|
|
|||
|
|
@ -1006,11 +1006,11 @@ class LambdaBackend(BaseBackend):
|
|||
return True
|
||||
return False
|
||||
|
||||
def add_policy_statement(self, function_name, raw):
|
||||
def add_permission(self, function_name, raw):
|
||||
fn = self.get_function(function_name)
|
||||
fn.policy.add_statement(raw)
|
||||
|
||||
def del_policy_statement(self, function_name, sid, revision=""):
|
||||
def remove_permission(self, function_name, sid, revision=""):
|
||||
fn = self.get_function(function_name)
|
||||
fn.policy.del_statement(sid, revision)
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class LambdaResponse(BaseResponse):
|
|||
function_name = path.split("/")[-2]
|
||||
if self.lambda_backend.get_function(function_name):
|
||||
statement = self.body
|
||||
self.lambda_backend.add_policy_statement(function_name, statement)
|
||||
self.lambda_backend.add_permission(function_name, statement)
|
||||
return 200, {}, json.dumps({"Statement": statement})
|
||||
else:
|
||||
return 404, {}, "{}"
|
||||
|
|
@ -166,9 +166,7 @@ class LambdaResponse(BaseResponse):
|
|||
statement_id = path.split("/")[-1].split("?")[0]
|
||||
revision = querystring.get("RevisionId", "")
|
||||
if self.lambda_backend.get_function(function_name):
|
||||
self.lambda_backend.del_policy_statement(
|
||||
function_name, statement_id, revision
|
||||
)
|
||||
self.lambda_backend.remove_permission(function_name, statement_id, revision)
|
||||
return 204, {}, "{}"
|
||||
else:
|
||||
return 404, {}, "{}"
|
||||
|
|
@ -184,9 +182,9 @@ class LambdaResponse(BaseResponse):
|
|||
function_name, qualifier, self.body, self.headers, response_headers
|
||||
)
|
||||
if payload:
|
||||
if request.headers["X-Amz-Invocation-Type"] == "Event":
|
||||
if request.headers.get("X-Amz-Invocation-Type") == "Event":
|
||||
status_code = 202
|
||||
elif request.headers["X-Amz-Invocation-Type"] == "DryRun":
|
||||
elif request.headers.get("X-Amz-Invocation-Type") == "DryRun":
|
||||
status_code = 204
|
||||
else:
|
||||
status_code = 200
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from moto.ec2 import ec2_backends
|
|||
from moto.ec2_instance_connect import ec2_instance_connect_backends
|
||||
from moto.ecr import ecr_backends
|
||||
from moto.ecs import ecs_backends
|
||||
from moto.elasticbeanstalk import eb_backends
|
||||
from moto.elb import elb_backends
|
||||
from moto.elbv2 import elbv2_backends
|
||||
from moto.emr import emr_backends
|
||||
|
|
@ -77,6 +78,7 @@ BACKENDS = {
|
|||
"ec2_instance_connect": ec2_instance_connect_backends,
|
||||
"ecr": ecr_backends,
|
||||
"ecs": ecs_backends,
|
||||
"elasticbeanstalk": eb_backends,
|
||||
"elb": elb_backends,
|
||||
"elbv2": elbv2_backends,
|
||||
"events": events_backends,
|
||||
|
|
|
|||
|
|
@ -239,8 +239,11 @@ class FakeStack(BaseModel):
|
|||
self.cross_stack_resources = cross_stack_resources or {}
|
||||
self.resource_map = self._create_resource_map()
|
||||
self.output_map = self._create_output_map()
|
||||
self._add_stack_event("CREATE_COMPLETE")
|
||||
self.status = "CREATE_COMPLETE"
|
||||
if create_change_set:
|
||||
self.status = "REVIEW_IN_PROGRESS"
|
||||
else:
|
||||
self.create_resources()
|
||||
self._add_stack_event("CREATE_COMPLETE")
|
||||
self.creation_time = datetime.utcnow()
|
||||
|
||||
def _create_resource_map(self):
|
||||
|
|
@ -253,7 +256,7 @@ class FakeStack(BaseModel):
|
|||
self.template_dict,
|
||||
self.cross_stack_resources,
|
||||
)
|
||||
resource_map.create()
|
||||
resource_map.load()
|
||||
return resource_map
|
||||
|
||||
def _create_output_map(self):
|
||||
|
|
@ -326,6 +329,10 @@ class FakeStack(BaseModel):
|
|||
def exports(self):
|
||||
return self.output_map.exports
|
||||
|
||||
def create_resources(self):
|
||||
self.resource_map.create()
|
||||
self.status = "CREATE_COMPLETE"
|
||||
|
||||
def update(self, template, role_arn=None, parameters=None, tags=None):
|
||||
self._add_stack_event(
|
||||
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
|
||||
|
|
@ -640,6 +647,7 @@ class CloudFormationBackend(BaseBackend):
|
|||
else:
|
||||
stack._add_stack_event("UPDATE_IN_PROGRESS")
|
||||
stack._add_stack_event("UPDATE_COMPLETE")
|
||||
stack.create_resources()
|
||||
return True
|
||||
|
||||
def describe_stacks(self, name_or_stack_id):
|
||||
|
|
|
|||
|
|
@ -531,14 +531,16 @@ class ResourceMap(collections_abc.Mapping):
|
|||
for condition_name in self.lazy_condition_map:
|
||||
self.lazy_condition_map[condition_name]
|
||||
|
||||
def create(self):
|
||||
def load(self):
|
||||
self.load_mapping()
|
||||
self.transform_mapping()
|
||||
self.load_parameters()
|
||||
self.load_conditions()
|
||||
|
||||
def create(self):
|
||||
# Since this is a lazy map, to create every object we just need to
|
||||
# iterate through self.
|
||||
# Assumes that self.load() has been called before
|
||||
self.tags.update(
|
||||
{
|
||||
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
|
||||
|
|
|
|||
|
|
@ -22,6 +22,14 @@ class Dimension(object):
|
|||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, item):
|
||||
if isinstance(item, Dimension):
|
||||
return self.name == item.name and self.value == item.value
|
||||
return False
|
||||
|
||||
def __ne__(self, item): # Only needed on Py2; Py3 defines it implicitly
|
||||
return self != item
|
||||
|
||||
|
||||
def daterange(start, stop, step=timedelta(days=1), inclusive=False):
|
||||
"""
|
||||
|
|
@ -124,6 +132,17 @@ class MetricDatum(BaseModel):
|
|||
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
|
||||
]
|
||||
|
||||
def filter(self, namespace, name, dimensions):
|
||||
if namespace and namespace != self.namespace:
|
||||
return False
|
||||
if name and name != self.name:
|
||||
return False
|
||||
if dimensions and any(
|
||||
Dimension(d["Name"], d["Value"]) not in self.dimensions for d in dimensions
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Dashboard(BaseModel):
|
||||
def __init__(self, name, body):
|
||||
|
|
@ -202,6 +221,15 @@ class CloudWatchBackend(BaseBackend):
|
|||
self.metric_data = []
|
||||
self.paged_metric_data = {}
|
||||
|
||||
@property
|
||||
# Retrieve a list of all OOTB metrics that are provided by metrics providers
|
||||
# Computed on the fly
|
||||
def aws_metric_data(self):
|
||||
md = []
|
||||
for name, service in metric_providers.items():
|
||||
md.extend(service.get_cloudwatch_metrics())
|
||||
return md
|
||||
|
||||
def put_metric_alarm(
|
||||
self,
|
||||
name,
|
||||
|
|
@ -295,6 +323,43 @@ class CloudWatchBackend(BaseBackend):
|
|||
)
|
||||
)
|
||||
|
||||
def get_metric_data(self, queries, start_time, end_time):
|
||||
period_data = [
|
||||
md for md in self.metric_data if start_time <= md.timestamp <= end_time
|
||||
]
|
||||
results = []
|
||||
for query in queries:
|
||||
query_ns = query["metric_stat._metric._namespace"]
|
||||
query_name = query["metric_stat._metric._metric_name"]
|
||||
query_data = [
|
||||
md
|
||||
for md in period_data
|
||||
if md.namespace == query_ns and md.name == query_name
|
||||
]
|
||||
metric_values = [m.value for m in query_data]
|
||||
result_vals = []
|
||||
stat = query["metric_stat._stat"]
|
||||
if len(metric_values) > 0:
|
||||
if stat == "Average":
|
||||
result_vals.append(sum(metric_values) / len(metric_values))
|
||||
elif stat == "Minimum":
|
||||
result_vals.append(min(metric_values))
|
||||
elif stat == "Maximum":
|
||||
result_vals.append(max(metric_values))
|
||||
elif stat == "Sum":
|
||||
result_vals.append(sum(metric_values))
|
||||
|
||||
label = query["metric_stat._metric._metric_name"] + " " + stat
|
||||
results.append(
|
||||
{
|
||||
"id": query["id"],
|
||||
"label": label,
|
||||
"vals": result_vals,
|
||||
"timestamps": [datetime.now() for _ in result_vals],
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
def get_metric_statistics(
|
||||
self, namespace, metric_name, start_time, end_time, period, stats
|
||||
):
|
||||
|
|
@ -334,7 +399,7 @@ class CloudWatchBackend(BaseBackend):
|
|||
return data
|
||||
|
||||
def get_all_metrics(self):
|
||||
return self.metric_data
|
||||
return self.metric_data + self.aws_metric_data
|
||||
|
||||
def put_dashboard(self, name, body):
|
||||
self.dashboards[name] = Dashboard(name, body)
|
||||
|
|
@ -386,7 +451,7 @@ class CloudWatchBackend(BaseBackend):
|
|||
|
||||
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
|
||||
|
||||
def list_metrics(self, next_token, namespace, metric_name):
|
||||
def list_metrics(self, next_token, namespace, metric_name, dimensions):
|
||||
if next_token:
|
||||
if next_token not in self.paged_metric_data:
|
||||
raise RESTError(
|
||||
|
|
@ -397,15 +462,16 @@ class CloudWatchBackend(BaseBackend):
|
|||
del self.paged_metric_data[next_token] # Cant reuse same token twice
|
||||
return self._get_paginated(metrics)
|
||||
else:
|
||||
metrics = self.get_filtered_metrics(metric_name, namespace)
|
||||
metrics = self.get_filtered_metrics(metric_name, namespace, dimensions)
|
||||
return self._get_paginated(metrics)
|
||||
|
||||
def get_filtered_metrics(self, metric_name, namespace):
|
||||
def get_filtered_metrics(self, metric_name, namespace, dimensions):
|
||||
metrics = self.get_all_metrics()
|
||||
if namespace:
|
||||
metrics = [md for md in metrics if md.namespace == namespace]
|
||||
if metric_name:
|
||||
metrics = [md for md in metrics if md.name == metric_name]
|
||||
metrics = [
|
||||
md
|
||||
for md in metrics
|
||||
if md.filter(namespace=namespace, name=metric_name, dimensions=dimensions)
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_paginated(self, metrics):
|
||||
|
|
@ -431,7 +497,9 @@ class LogGroup(BaseModel):
|
|||
properties = cloudformation_json["Properties"]
|
||||
log_group_name = properties["LogGroupName"]
|
||||
tags = properties.get("Tags", {})
|
||||
return logs_backends[region_name].create_log_group(log_group_name, tags)
|
||||
return logs_backends[region_name].create_log_group(
|
||||
log_group_name, tags, **properties
|
||||
)
|
||||
|
||||
|
||||
cloudwatch_backends = {}
|
||||
|
|
@ -443,3 +511,8 @@ for region in Session().get_available_regions(
|
|||
cloudwatch_backends[region] = CloudWatchBackend()
|
||||
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
|
||||
cloudwatch_backends[region] = CloudWatchBackend()
|
||||
|
||||
# List of services that provide OOTB CW metrics
|
||||
# See the S3Backend constructor for an example
|
||||
# TODO: We might have to separate this out per region for non-global services
|
||||
metric_providers = {}
|
||||
|
|
|
|||
|
|
@ -92,6 +92,18 @@ class CloudWatchResponse(BaseResponse):
|
|||
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
|
||||
return template.render()
|
||||
|
||||
@amzn_request_id
|
||||
def get_metric_data(self):
|
||||
start = dtparse(self._get_param("StartTime"))
|
||||
end = dtparse(self._get_param("EndTime"))
|
||||
queries = self._get_list_prefix("MetricDataQueries.member")
|
||||
results = self.cloudwatch_backend.get_metric_data(
|
||||
start_time=start, end_time=end, queries=queries
|
||||
)
|
||||
|
||||
template = self.response_template(GET_METRIC_DATA_TEMPLATE)
|
||||
return template.render(results=results)
|
||||
|
||||
@amzn_request_id
|
||||
def get_metric_statistics(self):
|
||||
namespace = self._get_param("Namespace")
|
||||
|
|
@ -124,9 +136,10 @@ class CloudWatchResponse(BaseResponse):
|
|||
def list_metrics(self):
|
||||
namespace = self._get_param("Namespace")
|
||||
metric_name = self._get_param("MetricName")
|
||||
dimensions = self._get_multi_param("Dimensions.member")
|
||||
next_token = self._get_param("NextToken")
|
||||
next_token, metrics = self.cloudwatch_backend.list_metrics(
|
||||
next_token, namespace, metric_name
|
||||
next_token, namespace, metric_name, dimensions
|
||||
)
|
||||
template = self.response_template(LIST_METRICS_TEMPLATE)
|
||||
return template.render(metrics=metrics, next_token=next_token)
|
||||
|
|
@ -285,6 +298,35 @@ PUT_METRIC_DATA_TEMPLATE = """<PutMetricDataResponse xmlns="http://monitoring.am
|
|||
</ResponseMetadata>
|
||||
</PutMetricDataResponse>"""
|
||||
|
||||
GET_METRIC_DATA_TEMPLATE = """<GetMetricDataResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
|
||||
<ResponseMetadata>
|
||||
<RequestId>
|
||||
{{ request_id }}
|
||||
</RequestId>
|
||||
</ResponseMetadata>
|
||||
<GetMetricDataResult>
|
||||
<MetricDataResults>
|
||||
{% for result in results %}
|
||||
<member>
|
||||
<Id>{{ result.id }}</Id>
|
||||
<Label>{{ result.label }}</Label>
|
||||
<StatusCode>Complete</StatusCode>
|
||||
<Timestamps>
|
||||
{% for val in result.timestamps %}
|
||||
<member>{{ val }}</member>
|
||||
{% endfor %}
|
||||
</Timestamps>
|
||||
<Values>
|
||||
{% for val in result.vals %}
|
||||
<member>{{ val }}</member>
|
||||
{% endfor %}
|
||||
</Values>
|
||||
</member>
|
||||
{% endfor %}
|
||||
</MetricDataResults>
|
||||
</GetMetricDataResult>
|
||||
</GetMetricDataResponse>"""
|
||||
|
||||
GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
|
||||
<ResponseMetadata>
|
||||
<RequestId>
|
||||
|
|
@ -342,7 +384,7 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
|
|||
</member>
|
||||
{% endfor %}
|
||||
</Dimensions>
|
||||
<MetricName>{{ metric.name }}</MetricName>
|
||||
<MetricName>Metric:{{ metric.name }}</MetricName>
|
||||
<Namespace>{{ metric.namespace }}</Namespace>
|
||||
</member>
|
||||
{% endfor %}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from moto.core.utils import get_random_hex
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
def get_random_identity_id(region):
|
||||
return "{0}:{1}".format(region, get_random_hex(length=19))
|
||||
return "{0}:{1}".format(region, uuid4())
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from io import BytesIO
|
|||
from collections import defaultdict
|
||||
from botocore.handlers import BUILTIN_HANDLERS
|
||||
from botocore.awsrequest import AWSResponse
|
||||
from six.moves.urllib.parse import urlparse
|
||||
from werkzeug.wrappers import Request
|
||||
|
||||
import mock
|
||||
from moto import settings
|
||||
|
|
@ -175,6 +177,26 @@ class CallbackResponse(responses.CallbackResponse):
|
|||
"""
|
||||
Need to override this so we can pass decode_content=False
|
||||
"""
|
||||
if not isinstance(request, Request):
|
||||
url = urlparse(request.url)
|
||||
if request.body is None:
|
||||
body = None
|
||||
elif isinstance(request.body, six.text_type):
|
||||
body = six.BytesIO(six.b(request.body))
|
||||
else:
|
||||
body = six.BytesIO(request.body)
|
||||
req = Request.from_values(
|
||||
path="?".join([url.path, url.query]),
|
||||
input_stream=body,
|
||||
content_length=request.headers.get("Content-Length"),
|
||||
content_type=request.headers.get("Content-Type"),
|
||||
method=request.method,
|
||||
base_url="{scheme}://{netloc}".format(
|
||||
scheme=url.scheme, netloc=url.netloc
|
||||
),
|
||||
headers=[(k, v) for k, v in six.iteritems(request.headers)],
|
||||
)
|
||||
request = req
|
||||
headers = self.get_headers()
|
||||
|
||||
result = self.callback(request)
|
||||
|
|
|
|||
|
|
@ -328,3 +328,25 @@ def py2_strip_unicode_keys(blob):
|
|||
blob = new_set
|
||||
|
||||
return blob
|
||||
|
||||
|
||||
def tags_from_query_string(
|
||||
querystring_dict, prefix="Tag", key_suffix="Key", value_suffix="Value"
|
||||
):
|
||||
response_values = {}
|
||||
for key, value in querystring_dict.items():
|
||||
if key.startswith(prefix) and key.endswith(key_suffix):
|
||||
tag_index = key.replace(prefix + ".", "").replace("." + key_suffix, "")
|
||||
tag_key = querystring_dict.get(
|
||||
"{prefix}.{index}.{key_suffix}".format(
|
||||
prefix=prefix, index=tag_index, key_suffix=key_suffix,
|
||||
)
|
||||
)[0]
|
||||
tag_value_key = "{prefix}.{index}.{value_suffix}".format(
|
||||
prefix=prefix, index=tag_index, value_suffix=value_suffix,
|
||||
)
|
||||
if tag_value_key in querystring_dict:
|
||||
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
|
||||
else:
|
||||
response_values[tag_key] = None
|
||||
return response_values
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import unicode_literals
|
||||
from .models import dynamodb_backends as dynamodb_backends2
|
||||
from moto.dynamodb2.models import dynamodb_backends as dynamodb_backends2
|
||||
from ..core.models import base_decorator, deprecated_base_decorator
|
||||
|
||||
dynamodb_backend2 = dynamodb_backends2["us-east-1"]
|
||||
|
|
|
|||
|
|
@ -2,9 +2,132 @@ class InvalidIndexNameError(ValueError):
|
|||
pass
|
||||
|
||||
|
||||
class InvalidUpdateExpression(ValueError):
|
||||
pass
|
||||
class MockValidationException(ValueError):
|
||||
def __init__(self, message):
|
||||
self.exception_msg = message
|
||||
|
||||
|
||||
class ItemSizeTooLarge(Exception):
|
||||
message = "Item size has exceeded the maximum allowed size"
|
||||
class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException):
|
||||
invalid_update_expression_msg = (
|
||||
"The document path provided in the update expression is invalid for update"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super(InvalidUpdateExpressionInvalidDocumentPath, self).__init__(
|
||||
self.invalid_update_expression_msg
|
||||
)
|
||||
|
||||
|
||||
class InvalidUpdateExpression(MockValidationException):
|
||||
invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}"
|
||||
|
||||
def __init__(self, update_expression_error):
|
||||
self.update_expression_error = update_expression_error
|
||||
super(InvalidUpdateExpression, self).__init__(
|
||||
self.invalid_update_expr_msg.format(
|
||||
update_expression_error=update_expression_error
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AttributeDoesNotExist(MockValidationException):
|
||||
attr_does_not_exist_msg = (
|
||||
"The provided expression refers to an attribute that does not exist in the item"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super(AttributeDoesNotExist, self).__init__(self.attr_does_not_exist_msg)
|
||||
|
||||
|
||||
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
|
||||
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
|
||||
|
||||
def __init__(self, attribute_name):
|
||||
self.not_defined_attribute_name = attribute_name
|
||||
super(ExpressionAttributeNameNotDefined, self).__init__(
|
||||
self.name_not_defined_msg.format(n=attribute_name)
|
||||
)
|
||||
|
||||
|
||||
class AttributeIsReservedKeyword(InvalidUpdateExpression):
|
||||
attribute_is_keyword_msg = (
|
||||
"Attribute name is a reserved keyword; reserved keyword: {keyword}"
|
||||
)
|
||||
|
||||
def __init__(self, keyword):
|
||||
self.keyword = keyword
|
||||
super(AttributeIsReservedKeyword, self).__init__(
|
||||
self.attribute_is_keyword_msg.format(keyword=keyword)
|
||||
)
|
||||
|
||||
|
||||
class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
|
||||
attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}"
|
||||
|
||||
def __init__(self, attribute_value):
|
||||
self.attribute_value = attribute_value
|
||||
super(ExpressionAttributeValueNotDefined, self).__init__(
|
||||
self.attr_value_not_defined_msg.format(attribute_value=attribute_value)
|
||||
)
|
||||
|
||||
|
||||
class UpdateExprSyntaxError(InvalidUpdateExpression):
|
||||
update_expr_syntax_error_msg = "Syntax error; {error_detail}"
|
||||
|
||||
def __init__(self, error_detail):
|
||||
self.error_detail = error_detail
|
||||
super(UpdateExprSyntaxError, self).__init__(
|
||||
self.update_expr_syntax_error_msg.format(error_detail=error_detail)
|
||||
)
|
||||
|
||||
|
||||
class InvalidTokenException(UpdateExprSyntaxError):
|
||||
token_detail_msg = 'token: "{token}", near: "{near}"'
|
||||
|
||||
def __init__(self, token, near):
|
||||
self.token = token
|
||||
self.near = near
|
||||
super(InvalidTokenException, self).__init__(
|
||||
self.token_detail_msg.format(token=token, near=near)
|
||||
)
|
||||
|
||||
|
||||
class InvalidExpressionAttributeNameKey(MockValidationException):
|
||||
invalid_expr_attr_name_msg = (
|
||||
'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"'
|
||||
)
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
super(InvalidExpressionAttributeNameKey, self).__init__(
|
||||
self.invalid_expr_attr_name_msg.format(key=key)
|
||||
)
|
||||
|
||||
|
||||
class ItemSizeTooLarge(MockValidationException):
|
||||
item_size_too_large_msg = "Item size has exceeded the maximum allowed size"
|
||||
|
||||
def __init__(self):
|
||||
super(ItemSizeTooLarge, self).__init__(self.item_size_too_large_msg)
|
||||
|
||||
|
||||
class ItemSizeToUpdateTooLarge(MockValidationException):
|
||||
item_size_to_update_too_large_msg = (
|
||||
"Item size to update has exceeded the maximum allowed size"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super(ItemSizeToUpdateTooLarge, self).__init__(
|
||||
self.item_size_to_update_too_large_msg
|
||||
)
|
||||
|
||||
|
||||
class IncorrectOperandType(InvalidUpdateExpression):
|
||||
inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}"
|
||||
|
||||
def __init__(self, operator_or_function, operand_type):
|
||||
self.operator_or_function = operator_or_function
|
||||
self.operand_type = operand_type
|
||||
super(IncorrectOperandType, self).__init__(
|
||||
self.inv_operand_msg.format(f=operator_or_function, t=operand_type)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import decimal
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
import six
|
||||
|
||||
from boto3 import Session
|
||||
from botocore.exceptions import ParamValidationError
|
||||
|
|
@ -14,10 +13,17 @@ from moto.compat import OrderedDict
|
|||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.core.utils import unix_time
|
||||
from moto.core.exceptions import JsonRESTError
|
||||
from .comparisons import get_comparison_func
|
||||
from .comparisons import get_filter_expression
|
||||
from .comparisons import get_expected
|
||||
from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge
|
||||
from moto.dynamodb2.comparisons import get_filter_expression
|
||||
from moto.dynamodb2.comparisons import get_expected
|
||||
from moto.dynamodb2.exceptions import (
|
||||
InvalidIndexNameError,
|
||||
ItemSizeTooLarge,
|
||||
ItemSizeToUpdateTooLarge,
|
||||
)
|
||||
from moto.dynamodb2.models.utilities import bytesize, attribute_is_list
|
||||
from moto.dynamodb2.models.dynamo_type import DynamoType
|
||||
from moto.dynamodb2.parsing.expressions import UpdateExpressionParser
|
||||
from moto.dynamodb2.parsing.validators import UpdateExpressionValidator
|
||||
|
||||
|
||||
class DynamoJsonEncoder(json.JSONEncoder):
|
||||
|
|
@ -30,223 +36,6 @@ def dynamo_json_dump(dynamo_object):
|
|||
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
|
||||
|
||||
|
||||
def bytesize(val):
|
||||
return len(str(val).encode("utf-8"))
|
||||
|
||||
|
||||
def attribute_is_list(attr):
|
||||
"""
|
||||
Checks if attribute denotes a list, and returns the name of the list and the given list index if so
|
||||
:param attr: attr or attr[index]
|
||||
:return: attr, index or None
|
||||
"""
|
||||
list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr)
|
||||
if list_index_update:
|
||||
attr = list_index_update.group(1)
|
||||
return attr, list_index_update.group(2) if list_index_update else None
|
||||
|
||||
|
||||
class DynamoType(object):
|
||||
"""
|
||||
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
|
||||
"""
|
||||
|
||||
def __init__(self, type_as_dict):
|
||||
if type(type_as_dict) == DynamoType:
|
||||
self.type = type_as_dict.type
|
||||
self.value = type_as_dict.value
|
||||
else:
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if self.is_list():
|
||||
self.value = [DynamoType(val) for val in self.value]
|
||||
elif self.is_map():
|
||||
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
|
||||
|
||||
def get(self, key):
|
||||
if not key:
|
||||
return self
|
||||
else:
|
||||
key_head = key.split(".")[0]
|
||||
key_tail = ".".join(key.split(".")[1:])
|
||||
if key_head not in self.value:
|
||||
self.value[key_head] = DynamoType({"NONE": None})
|
||||
return self.value[key_head].get(key_tail)
|
||||
|
||||
def set(self, key, new_value, index=None):
|
||||
if index:
|
||||
index = int(index)
|
||||
if type(self.value) is not list:
|
||||
raise InvalidUpdateExpression
|
||||
if index >= len(self.value):
|
||||
self.value.append(new_value)
|
||||
# {'L': [DynamoType, ..]} ==> DynamoType.set()
|
||||
self.value[min(index, len(self.value) - 1)].set(key, new_value)
|
||||
else:
|
||||
attr = (key or "").split(".").pop(0)
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
if not key:
|
||||
# {'S': value} ==> {'S': new_value}
|
||||
self.type = new_value.type
|
||||
self.value = new_value.value
|
||||
else:
|
||||
if attr not in self.value: # nonexistingattribute
|
||||
type_of_new_attr = "M" if "." in key else new_value.type
|
||||
self.value[attr] = DynamoType({type_of_new_attr: {}})
|
||||
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
|
||||
self.value[attr].set(
|
||||
".".join(key.split(".")[1:]), new_value, list_index
|
||||
)
|
||||
|
||||
def delete(self, key, index=None):
|
||||
if index:
|
||||
if not key:
|
||||
if int(index) < len(self.value):
|
||||
del self.value[int(index)]
|
||||
elif "." in key:
|
||||
self.value[int(index)].delete(".".join(key.split(".")[1:]))
|
||||
else:
|
||||
self.value[int(index)].delete(key)
|
||||
else:
|
||||
attr = key.split(".")[0]
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
|
||||
if list_index:
|
||||
self.value[attr].delete(".".join(key.split(".")[1:]), list_index)
|
||||
elif "." in key:
|
||||
self.value[attr].delete(".".join(key.split(".")[1:]))
|
||||
else:
|
||||
self.value.pop(key)
|
||||
|
||||
def filter(self, projection_expressions):
|
||||
nested_projections = [
|
||||
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
|
||||
]
|
||||
if self.is_map():
|
||||
expressions_to_delete = []
|
||||
for attr in self.value:
|
||||
if (
|
||||
attr not in projection_expressions
|
||||
and attr not in nested_projections
|
||||
):
|
||||
expressions_to_delete.append(attr)
|
||||
elif attr in nested_projections:
|
||||
relevant_expressions = [
|
||||
expr[len(attr + ".") :]
|
||||
for expr in projection_expressions
|
||||
if expr.startswith(attr + ".")
|
||||
]
|
||||
self.value[attr].filter(relevant_expressions)
|
||||
for expr in expressions_to_delete:
|
||||
self.value.pop(expr)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.value))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.type == other.type and self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.type != other.type or self.value != other.value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.cast_value < other.cast_value
|
||||
|
||||
def __le__(self, other):
|
||||
return self.cast_value <= other.cast_value
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.cast_value > other.cast_value
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.cast_value >= other.cast_value
|
||||
|
||||
def __repr__(self):
|
||||
return "DynamoType: {0}".format(self.to_json())
|
||||
|
||||
@property
|
||||
def cast_value(self):
|
||||
if self.is_number():
|
||||
try:
|
||||
return int(self.value)
|
||||
except ValueError:
|
||||
return float(self.value)
|
||||
elif self.is_set():
|
||||
sub_type = self.type[0]
|
||||
return set([DynamoType({sub_type: v}).cast_value for v in self.value])
|
||||
elif self.is_list():
|
||||
return [DynamoType(v).cast_value for v in self.value]
|
||||
elif self.is_map():
|
||||
return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()])
|
||||
else:
|
||||
return self.value
|
||||
|
||||
def child_attr(self, key):
|
||||
"""
|
||||
Get Map or List children by key. str for Map, int for List.
|
||||
|
||||
Returns DynamoType or None.
|
||||
"""
|
||||
if isinstance(key, six.string_types) and self.is_map():
|
||||
if "." in key and key.split(".")[0] in self.value:
|
||||
return self.value[key.split(".")[0]].child_attr(
|
||||
".".join(key.split(".")[1:])
|
||||
)
|
||||
elif "." not in key and key in self.value:
|
||||
return DynamoType(self.value[key])
|
||||
|
||||
if isinstance(key, int) and self.is_list():
|
||||
idx = key
|
||||
if 0 <= idx < len(self.value):
|
||||
return DynamoType(self.value[idx])
|
||||
|
||||
return None
|
||||
|
||||
def size(self):
|
||||
if self.is_number():
|
||||
value_size = len(str(self.value))
|
||||
elif self.is_set():
|
||||
sub_type = self.type[0]
|
||||
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
|
||||
elif self.is_list():
|
||||
value_size = sum([v.size() for v in self.value])
|
||||
elif self.is_map():
|
||||
value_size = sum(
|
||||
[bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]
|
||||
)
|
||||
elif type(self.value) == bool:
|
||||
value_size = 1
|
||||
else:
|
||||
value_size = bytesize(self.value)
|
||||
return value_size
|
||||
|
||||
def to_json(self):
|
||||
return {self.type: self.value}
|
||||
|
||||
def compare(self, range_comparison, range_objs):
|
||||
"""
|
||||
Compares this type against comparison filters
|
||||
"""
|
||||
range_values = [obj.cast_value for obj in range_objs]
|
||||
comparison_func = get_comparison_func(range_comparison)
|
||||
return comparison_func(self.cast_value, *range_values)
|
||||
|
||||
def is_number(self):
|
||||
return self.type == "N"
|
||||
|
||||
def is_set(self):
|
||||
return self.type == "SS" or self.type == "NS" or self.type == "BS"
|
||||
|
||||
def is_list(self):
|
||||
return self.type == "L"
|
||||
|
||||
def is_map(self):
|
||||
return self.type == "M"
|
||||
|
||||
def same_type(self, other):
|
||||
return self.type == other.type
|
||||
|
||||
|
||||
# https://github.com/spulec/moto/issues/1874
|
||||
# Ensure that the total size of an item does not exceed 400kb
|
||||
class LimitedSizeDict(dict):
|
||||
|
|
@ -285,6 +74,9 @@ class Item(BaseModel):
|
|||
def __repr__(self):
|
||||
return "Item: {0}".format(self.to_json())
|
||||
|
||||
def size(self):
|
||||
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
|
||||
|
||||
def to_json(self):
|
||||
attributes = {}
|
||||
for attribute_key, attribute in self.attrs.items():
|
||||
|
|
@ -367,7 +159,10 @@ class Item(BaseModel):
|
|||
if "." in key and attr not in self.attrs:
|
||||
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
|
||||
elif attr not in self.attrs:
|
||||
self.attrs[attr] = dyn_value # set new top-level attribute
|
||||
try:
|
||||
self.attrs[attr] = dyn_value # set new top-level attribute
|
||||
except ItemSizeTooLarge:
|
||||
raise ItemSizeToUpdateTooLarge()
|
||||
else:
|
||||
self.attrs[attr].set(
|
||||
".".join(key.split(".")[1:]), dyn_value, list_index
|
||||
|
|
@ -1129,6 +924,14 @@ class Table(BaseModel):
|
|||
break
|
||||
|
||||
last_evaluated_key = None
|
||||
size_limit = 1000000 # DynamoDB has a 1MB size limit
|
||||
item_size = sum(res.size() for res in results)
|
||||
if item_size > size_limit:
|
||||
item_size = idx = 0
|
||||
while item_size + results[idx].size() < size_limit:
|
||||
item_size += results[idx].size()
|
||||
idx += 1
|
||||
limit = min(limit, idx) if limit else idx
|
||||
if limit and len(results) > limit:
|
||||
results = results[:limit]
|
||||
last_evaluated_key = {self.hash_key_attr: results[-1].hash_key}
|
||||
|
|
@ -1414,6 +1217,13 @@ class DynamoDBBackend(BaseBackend):
|
|||
):
|
||||
table = self.get_table(table_name)
|
||||
|
||||
# Support spaces between operators in an update expression
|
||||
# E.g. `a = b + c` -> `a=b+c`
|
||||
if update_expression:
|
||||
# Parse expression to get validation errors
|
||||
update_expression_ast = UpdateExpressionParser.make(update_expression)
|
||||
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
|
||||
|
||||
if all([table.hash_key_attr in key, table.range_key_attr in key]):
|
||||
# Covers cases where table has hash and range keys, ``key`` param
|
||||
# will be a dict
|
||||
|
|
@ -1456,6 +1266,12 @@ class DynamoDBBackend(BaseBackend):
|
|||
item = table.get_item(hash_value, range_value)
|
||||
|
||||
if update_expression:
|
||||
UpdateExpressionValidator(
|
||||
update_expression_ast,
|
||||
expression_attribute_names=expression_attribute_names,
|
||||
expression_attribute_values=expression_attribute_values,
|
||||
item=item,
|
||||
).validate()
|
||||
item.update(
|
||||
update_expression,
|
||||
expression_attribute_names,
|
||||
237
moto/dynamodb2/models/dynamo_type.py
Normal file
237
moto/dynamodb2/models/dynamo_type.py
Normal file
|
|
@ -0,0 +1,237 @@
|
|||
import six
|
||||
|
||||
from moto.dynamodb2.comparisons import get_comparison_func
|
||||
from moto.dynamodb2.exceptions import InvalidUpdateExpression
|
||||
from moto.dynamodb2.models.utilities import attribute_is_list, bytesize
|
||||
|
||||
|
||||
class DynamoType(object):
|
||||
"""
|
||||
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
|
||||
"""
|
||||
|
||||
def __init__(self, type_as_dict):
|
||||
if type(type_as_dict) == DynamoType:
|
||||
self.type = type_as_dict.type
|
||||
self.value = type_as_dict.value
|
||||
else:
|
||||
self.type = list(type_as_dict)[0]
|
||||
self.value = list(type_as_dict.values())[0]
|
||||
if self.is_list():
|
||||
self.value = [DynamoType(val) for val in self.value]
|
||||
elif self.is_map():
|
||||
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
|
||||
|
||||
def get(self, key):
|
||||
if not key:
|
||||
return self
|
||||
else:
|
||||
key_head = key.split(".")[0]
|
||||
key_tail = ".".join(key.split(".")[1:])
|
||||
if key_head not in self.value:
|
||||
self.value[key_head] = DynamoType({"NONE": None})
|
||||
return self.value[key_head].get(key_tail)
|
||||
|
||||
def set(self, key, new_value, index=None):
|
||||
if index:
|
||||
index = int(index)
|
||||
if type(self.value) is not list:
|
||||
raise InvalidUpdateExpression
|
||||
if index >= len(self.value):
|
||||
self.value.append(new_value)
|
||||
# {'L': [DynamoType, ..]} ==> DynamoType.set()
|
||||
self.value[min(index, len(self.value) - 1)].set(key, new_value)
|
||||
else:
|
||||
attr = (key or "").split(".").pop(0)
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
if not key:
|
||||
# {'S': value} ==> {'S': new_value}
|
||||
self.type = new_value.type
|
||||
self.value = new_value.value
|
||||
else:
|
||||
if attr not in self.value: # nonexistingattribute
|
||||
type_of_new_attr = "M" if "." in key else new_value.type
|
||||
self.value[attr] = DynamoType({type_of_new_attr: {}})
|
||||
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
|
||||
self.value[attr].set(
|
||||
".".join(key.split(".")[1:]), new_value, list_index
|
||||
)
|
||||
|
||||
def delete(self, key, index=None):
|
||||
if index:
|
||||
if not key:
|
||||
if int(index) < len(self.value):
|
||||
del self.value[int(index)]
|
||||
elif "." in key:
|
||||
self.value[int(index)].delete(".".join(key.split(".")[1:]))
|
||||
else:
|
||||
self.value[int(index)].delete(key)
|
||||
else:
|
||||
attr = key.split(".")[0]
|
||||
attr, list_index = attribute_is_list(attr)
|
||||
|
||||
if list_index:
|
||||
self.value[attr].delete(".".join(key.split(".")[1:]), list_index)
|
||||
elif "." in key:
|
||||
self.value[attr].delete(".".join(key.split(".")[1:]))
|
||||
else:
|
||||
self.value.pop(key)
|
||||
|
||||
def filter(self, projection_expressions):
|
||||
nested_projections = [
|
||||
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
|
||||
]
|
||||
if self.is_map():
|
||||
expressions_to_delete = []
|
||||
for attr in self.value:
|
||||
if (
|
||||
attr not in projection_expressions
|
||||
and attr not in nested_projections
|
||||
):
|
||||
expressions_to_delete.append(attr)
|
||||
elif attr in nested_projections:
|
||||
relevant_expressions = [
|
||||
expr[len(attr + ".") :]
|
||||
for expr in projection_expressions
|
||||
if expr.startswith(attr + ".")
|
||||
]
|
||||
self.value[attr].filter(relevant_expressions)
|
||||
for expr in expressions_to_delete:
|
||||
self.value.pop(expr)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.type, self.value))
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.type == other.type and self.value == other.value
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.type != other.type or self.value != other.value
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.cast_value < other.cast_value
|
||||
|
||||
def __le__(self, other):
|
||||
return self.cast_value <= other.cast_value
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.cast_value > other.cast_value
|
||||
|
||||
def __ge__(self, other):
|
||||
return self.cast_value >= other.cast_value
|
||||
|
||||
def __repr__(self):
|
||||
return "DynamoType: {0}".format(self.to_json())
|
||||
|
||||
def __add__(self, other):
|
||||
if self.type != other.type:
|
||||
raise TypeError("Different types of operandi is not allowed.")
|
||||
if self.type == "N":
|
||||
return DynamoType({"N": "{v}".format(v=int(self.value) + int(other.value))})
|
||||
else:
|
||||
raise TypeError("Sum only supported for Numbers.")
|
||||
|
||||
def __sub__(self, other):
|
||||
if self.type != other.type:
|
||||
raise TypeError("Different types of operandi is not allowed.")
|
||||
if self.type == "N":
|
||||
return DynamoType({"N": "{v}".format(v=int(self.value) - int(other.value))})
|
||||
else:
|
||||
raise TypeError("Sum only supported for Numbers.")
|
||||
|
||||
def __getitem__(self, item):
|
||||
if isinstance(item, six.string_types):
|
||||
# If our DynamoType is a map it should be subscriptable with a key
|
||||
if self.type == "M":
|
||||
return self.value[item]
|
||||
elif isinstance(item, int):
|
||||
# If our DynamoType is a list is should be subscriptable with an index
|
||||
if self.type == "L":
|
||||
return self.value[item]
|
||||
raise TypeError(
|
||||
"This DynamoType {dt} is not subscriptable by a {it}".format(
|
||||
dt=self.type, it=type(item)
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def cast_value(self):
|
||||
if self.is_number():
|
||||
try:
|
||||
return int(self.value)
|
||||
except ValueError:
|
||||
return float(self.value)
|
||||
elif self.is_set():
|
||||
sub_type = self.type[0]
|
||||
return set([DynamoType({sub_type: v}).cast_value for v in self.value])
|
||||
elif self.is_list():
|
||||
return [DynamoType(v).cast_value for v in self.value]
|
||||
elif self.is_map():
|
||||
return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()])
|
||||
else:
|
||||
return self.value
|
||||
|
||||
def child_attr(self, key):
|
||||
"""
|
||||
Get Map or List children by key. str for Map, int for List.
|
||||
|
||||
Returns DynamoType or None.
|
||||
"""
|
||||
if isinstance(key, six.string_types) and self.is_map():
|
||||
if "." in key and key.split(".")[0] in self.value:
|
||||
return self.value[key.split(".")[0]].child_attr(
|
||||
".".join(key.split(".")[1:])
|
||||
)
|
||||
elif "." not in key and key in self.value:
|
||||
return DynamoType(self.value[key])
|
||||
|
||||
if isinstance(key, int) and self.is_list():
|
||||
idx = key
|
||||
if 0 <= idx < len(self.value):
|
||||
return DynamoType(self.value[idx])
|
||||
|
||||
return None
|
||||
|
||||
def size(self):
|
||||
if self.is_number():
|
||||
value_size = len(str(self.value))
|
||||
elif self.is_set():
|
||||
sub_type = self.type[0]
|
||||
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
|
||||
elif self.is_list():
|
||||
value_size = sum([v.size() for v in self.value])
|
||||
elif self.is_map():
|
||||
value_size = sum(
|
||||
[bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]
|
||||
)
|
||||
elif type(self.value) == bool:
|
||||
value_size = 1
|
||||
else:
|
||||
value_size = bytesize(self.value)
|
||||
return value_size
|
||||
|
||||
def to_json(self):
|
||||
return {self.type: self.value}
|
||||
|
||||
def compare(self, range_comparison, range_objs):
|
||||
"""
|
||||
Compares this type against comparison filters
|
||||
"""
|
||||
range_values = [obj.cast_value for obj in range_objs]
|
||||
comparison_func = get_comparison_func(range_comparison)
|
||||
return comparison_func(self.cast_value, *range_values)
|
||||
|
||||
def is_number(self):
|
||||
return self.type == "N"
|
||||
|
||||
def is_set(self):
|
||||
return self.type == "SS" or self.type == "NS" or self.type == "BS"
|
||||
|
||||
def is_list(self):
|
||||
return self.type == "L"
|
||||
|
||||
def is_map(self):
|
||||
return self.type == "M"
|
||||
|
||||
def same_type(self, other):
|
||||
return self.type == other.type
|
||||
17
moto/dynamodb2/models/utilities.py
Normal file
17
moto/dynamodb2/models/utilities.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import re
|
||||
|
||||
|
||||
def bytesize(val):
|
||||
return len(str(val).encode("utf-8"))
|
||||
|
||||
|
||||
def attribute_is_list(attr):
|
||||
"""
|
||||
Checks if attribute denotes a list, and returns the name of the list and the given list index if so
|
||||
:param attr: attr or attr[index]
|
||||
:return: attr, index or None
|
||||
"""
|
||||
list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr)
|
||||
if list_index_update:
|
||||
attr = list_index_update.group(1)
|
||||
return attr, list_index_update.group(2) if list_index_update else None
|
||||
23
moto/dynamodb2/parsing/README.md
Normal file
23
moto/dynamodb2/parsing/README.md
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Parsing dev documentation
|
||||
|
||||
Parsing happens in a structured manner and happens in different phases.
|
||||
This document explains these phases.
|
||||
|
||||
|
||||
## 1) Expression gets parsed into a tokenlist (tokenized)
|
||||
A string gets parsed from left to right and gets converted into a list of tokens.
|
||||
The tokens are available in `tokens.py`.
|
||||
|
||||
## 2) Tokenlist get transformed to expression tree (AST)
|
||||
This is the parsing of the token list. This parsing will result in an Abstract Syntax Tree (AST).
|
||||
The different node types are available in `ast_nodes.py`. The AST is a representation that has all
|
||||
the information that is in the expression but its tree form allows processing it in a structured manner.
|
||||
|
||||
## 3) The AST gets validated (full semantic correctness)
|
||||
The AST is used for validation. The paths and attributes are validated to be correct. At the end of the
|
||||
validation all the values will be resolved.
|
||||
|
||||
## 4) Update Expression gets executed using the validated AST
|
||||
Finally the AST is used to execute the update expression. There should be no reason for this step to fail
|
||||
since validation has completed. Due to this we have the update expressions behaving atomically (i.e. all the
|
||||
actions of the update expresion are performed or none of them are performed).
|
||||
0
moto/dynamodb2/parsing/__init__.py
Normal file
0
moto/dynamodb2/parsing/__init__.py
Normal file
360
moto/dynamodb2/parsing/ast_nodes.py
Normal file
360
moto/dynamodb2/parsing/ast_nodes.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
import abc
|
||||
from abc import abstractmethod
|
||||
from collections import deque
|
||||
|
||||
import six
|
||||
|
||||
from moto.dynamodb2.models import DynamoType
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Node:
|
||||
def __init__(self, children=None):
|
||||
self.type = self.__class__.__name__
|
||||
assert children is None or isinstance(children, list)
|
||||
self.children = children
|
||||
self.parent = None
|
||||
|
||||
if isinstance(children, list):
|
||||
for child in children:
|
||||
if isinstance(child, Node):
|
||||
child.set_parent(self)
|
||||
|
||||
def set_parent(self, parent_node):
|
||||
self.parent = parent_node
|
||||
|
||||
|
||||
class LeafNode(Node):
|
||||
"""A LeafNode is a Node where none of the children are Nodes themselves."""
|
||||
|
||||
def __init__(self, children=None):
|
||||
super(LeafNode, self).__init__(children)
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Expression(Node):
|
||||
"""
|
||||
Abstract Syntax Tree representing the expression
|
||||
|
||||
For the Grammar start here and jump down into the classes at the righ-hand side to look further. Nodes marked with
|
||||
a star are abstract and won't appear in the final AST.
|
||||
|
||||
Expression* => UpdateExpression
|
||||
Expression* => ConditionExpression
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpression(Expression):
|
||||
"""
|
||||
UpdateExpression => UpdateExpressionClause*
|
||||
UpdateExpression => UpdateExpressionClause* UpdateExpression
|
||||
"""
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class UpdateExpressionClause(UpdateExpression):
|
||||
"""
|
||||
UpdateExpressionClause* => UpdateExpressionSetClause
|
||||
UpdateExpressionClause* => UpdateExpressionRemoveClause
|
||||
UpdateExpressionClause* => UpdateExpressionAddClause
|
||||
UpdateExpressionClause* => UpdateExpressionDeleteClause
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionSetClause(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionSetClause => SET SetActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionSetActions(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionSetClause => SET SetActions
|
||||
|
||||
SetActions => SetAction
|
||||
SetActions => SetAction , SetActions
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionSetAction(UpdateExpressionClause):
|
||||
"""
|
||||
SetAction => Path = Value
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionRemoveActions(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionSetClause => REMOVE RemoveActions
|
||||
|
||||
RemoveActions => RemoveAction
|
||||
RemoveActions => RemoveAction , RemoveActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionRemoveAction(UpdateExpressionClause):
|
||||
"""
|
||||
RemoveAction => Path
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionAddActions(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionAddClause => ADD RemoveActions
|
||||
|
||||
AddActions => AddAction
|
||||
AddActions => AddAction , AddActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionAddAction(UpdateExpressionClause):
|
||||
"""
|
||||
AddAction => Path Value
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionDeleteActions(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionDeleteClause => DELETE RemoveActions
|
||||
|
||||
DeleteActions => DeleteAction
|
||||
DeleteActions => DeleteAction , DeleteActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionDeleteAction(UpdateExpressionClause):
|
||||
"""
|
||||
DeleteAction => Path Value
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionPath(UpdateExpressionClause):
|
||||
pass
|
||||
|
||||
|
||||
class UpdateExpressionValue(UpdateExpressionClause):
|
||||
"""
|
||||
Value => Operand
|
||||
Value => Operand + Value
|
||||
Value => Operand - Value
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionGroupedValue(UpdateExpressionClause):
|
||||
"""
|
||||
GroupedValue => ( Value )
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionRemoveClause(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionRemoveClause => REMOVE RemoveActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionAddClause(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionAddClause => ADD AddActions
|
||||
"""
|
||||
|
||||
|
||||
class UpdateExpressionDeleteClause(UpdateExpressionClause):
|
||||
"""
|
||||
UpdateExpressionDeleteClause => DELETE DeleteActions
|
||||
"""
|
||||
|
||||
|
||||
class ExpressionPathDescender(Node):
|
||||
"""Node identifying descender into nested structure (.) in expression"""
|
||||
|
||||
|
||||
class ExpressionSelector(LeafNode):
|
||||
"""Node identifying selector [selection_index] in expresion"""
|
||||
|
||||
def __init__(self, selection_index):
|
||||
try:
|
||||
super(ExpressionSelector, self).__init__(children=[int(selection_index)])
|
||||
except ValueError:
|
||||
assert (
|
||||
False
|
||||
), "Expression selector must be an int, this is a bug in the moto library."
|
||||
|
||||
def get_index(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class ExpressionAttribute(LeafNode):
|
||||
"""An attribute identifier as used in the DDB item"""
|
||||
|
||||
def __init__(self, attribute):
|
||||
super(ExpressionAttribute, self).__init__(children=[attribute])
|
||||
|
||||
def get_attribute_name(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class ExpressionAttributeName(LeafNode):
|
||||
"""An ExpressionAttributeName is an alias for an attribute identifier"""
|
||||
|
||||
def __init__(self, attribute_name):
|
||||
super(ExpressionAttributeName, self).__init__(children=[attribute_name])
|
||||
|
||||
def get_attribute_name_placeholder(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class ExpressionAttributeValue(LeafNode):
|
||||
"""An ExpressionAttributeValue is an alias for an value"""
|
||||
|
||||
def __init__(self, value):
|
||||
super(ExpressionAttributeValue, self).__init__(children=[value])
|
||||
|
||||
def get_value_name(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class ExpressionValueOperator(LeafNode):
|
||||
"""An ExpressionValueOperator is an operation that works on 2 values"""
|
||||
|
||||
def __init__(self, value):
|
||||
super(ExpressionValueOperator, self).__init__(children=[value])
|
||||
|
||||
def get_operator(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class UpdateExpressionFunction(Node):
|
||||
"""
|
||||
A Node representing a function of an Update Expression. The first child is the function name the others are the
|
||||
arguments.
|
||||
"""
|
||||
|
||||
def get_function_name(self):
|
||||
return self.children[0]
|
||||
|
||||
def get_nth_argument(self, n=1):
|
||||
"""Return nth element where n is a 1-based index."""
|
||||
assert n >= 1
|
||||
return self.children[n]
|
||||
|
||||
|
||||
class DDBTypedValue(Node):
|
||||
"""
|
||||
A node representing a DDBTyped value. This can be any structure as supported by DyanmoDB. The node only has 1 child
|
||||
which is the value of type `DynamoType`.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
assert isinstance(value, DynamoType), "DDBTypedValue must be of DynamoType"
|
||||
super(DDBTypedValue, self).__init__(children=[value])
|
||||
|
||||
def get_value(self):
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class NoneExistingPath(LeafNode):
|
||||
"""A placeholder for Paths that did not exist in the Item."""
|
||||
|
||||
def __init__(self, creatable=False):
|
||||
super(NoneExistingPath, self).__init__(children=[creatable])
|
||||
|
||||
def is_creatable(self):
|
||||
"""Can this path be created if need be. For example path creating element in a dictionary or creating a new
|
||||
attribute under root level of an item."""
|
||||
return self.children[0]
|
||||
|
||||
|
||||
class DepthFirstTraverser(object):
|
||||
"""
|
||||
Helper class that allows depth first traversal and to implement custom processing for certain AST nodes. The
|
||||
processor of a node must return the new resulting node. This node will be placed in the tree. Processing of a
|
||||
node using this traverser should therefore only transform child nodes. The returned node will get the same parent
|
||||
as the node before processing had.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _processing_map(self):
|
||||
"""
|
||||
A map providing a processing function per node class type to a function that takes in a Node object and
|
||||
processes it. A Node can only be processed by a single function and they are considered in order. Therefore if
|
||||
multiple classes from a single class hierarchy strain are used the more specific classes have to be put before
|
||||
the less specific ones. That requires overriding `nodes_to_be_processed`. If no multiple classes form a single
|
||||
class hierarchy strain are used the default implementation of `nodes_to_be_processed` should be OK.
|
||||
Returns:
|
||||
dict: Mapping a Node Class to a processing function.
|
||||
"""
|
||||
pass
|
||||
|
||||
def nodes_to_be_processed(self):
|
||||
"""Cached accessor for getting Node types that need to be processed."""
|
||||
return tuple(k for k in self._processing_map().keys())
|
||||
|
||||
def process(self, node):
|
||||
"""Process a Node"""
|
||||
for class_key, processor in self._processing_map().items():
|
||||
if isinstance(node, class_key):
|
||||
return processor(node)
|
||||
|
||||
def pre_processing_of_child(self, parent_node, child_id):
|
||||
"""Hook that is called pre-processing of the child at position `child_id`"""
|
||||
pass
|
||||
|
||||
def traverse_node_recursively(self, node, child_id=-1):
|
||||
"""
|
||||
Traverse nodes depth first processing nodes bottom up (if root node is considered the top).
|
||||
|
||||
Args:
|
||||
node(Node): The node which is the last node to be processed but which allows to identify all the
|
||||
work (which is in the children)
|
||||
child_id(int): The index in the list of children from the parent that this node corresponds to
|
||||
|
||||
Returns:
|
||||
Node: The node of the new processed AST
|
||||
"""
|
||||
if isinstance(node, Node):
|
||||
parent_node = node.parent
|
||||
if node.children is not None:
|
||||
for i, child_node in enumerate(node.children):
|
||||
self.pre_processing_of_child(node, i)
|
||||
self.traverse_node_recursively(child_node, i)
|
||||
# noinspection PyTypeChecker
|
||||
if isinstance(node, self.nodes_to_be_processed()):
|
||||
node = self.process(node)
|
||||
node.parent = parent_node
|
||||
parent_node.children[child_id] = node
|
||||
return node
|
||||
|
||||
def traverse(self, node):
|
||||
return self.traverse_node_recursively(node)
|
||||
|
||||
|
||||
class NodeDepthLeftTypeFetcher(object):
|
||||
"""Helper class to fetch a node of a specific type. Depth left-first traversal"""
|
||||
|
||||
def __init__(self, node_type, root_node):
|
||||
assert issubclass(node_type, Node)
|
||||
self.node_type = node_type
|
||||
self.root_node = root_node
|
||||
self.queue = deque()
|
||||
self.add_nodes_left_to_right_depth_first(self.root_node)
|
||||
|
||||
def add_nodes_left_to_right_depth_first(self, node):
|
||||
if isinstance(node, Node) and node.children is not None:
|
||||
for child_node in node.children:
|
||||
self.add_nodes_left_to_right_depth_first(child_node)
|
||||
self.queue.append(child_node)
|
||||
self.queue.append(node)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
return self.__next__()
|
||||
|
||||
def __next__(self):
|
||||
while len(self.queue) > 0:
|
||||
candidate = self.queue.popleft()
|
||||
if isinstance(candidate, self.node_type):
|
||||
return candidate
|
||||
else:
|
||||
raise StopIteration
|
||||
1040
moto/dynamodb2/parsing/expressions.py
Normal file
1040
moto/dynamodb2/parsing/expressions.py
Normal file
File diff suppressed because it is too large
Load diff
29
moto/dynamodb2/parsing/reserved_keywords.py
Normal file
29
moto/dynamodb2/parsing/reserved_keywords.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
class ReservedKeywords(list):
|
||||
"""
|
||||
DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree.
|
||||
Not earlier since an update expression like "SET path = VALUE 1" fails with:
|
||||
'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"'
|
||||
"""
|
||||
|
||||
KEYWORDS = None
|
||||
|
||||
@classmethod
|
||||
def get_reserved_keywords(cls):
|
||||
if cls.KEYWORDS is None:
|
||||
cls.KEYWORDS = cls._get_reserved_keywords()
|
||||
return cls.KEYWORDS
|
||||
|
||||
@classmethod
|
||||
def _get_reserved_keywords(cls):
|
||||
"""
|
||||
Get a list of reserved keywords of DynamoDB
|
||||
"""
|
||||
try:
|
||||
import importlib.resources as pkg_resources
|
||||
except ImportError:
|
||||
import importlib_resources as pkg_resources
|
||||
|
||||
reserved_keywords = pkg_resources.read_text(
|
||||
"moto.dynamodb2.parsing", "reserved_keywords.txt"
|
||||
)
|
||||
return reserved_keywords.split()
|
||||
573
moto/dynamodb2/parsing/reserved_keywords.txt
Normal file
573
moto/dynamodb2/parsing/reserved_keywords.txt
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
ABORT
|
||||
ABSOLUTE
|
||||
ACTION
|
||||
ADD
|
||||
AFTER
|
||||
AGENT
|
||||
AGGREGATE
|
||||
ALL
|
||||
ALLOCATE
|
||||
ALTER
|
||||
ANALYZE
|
||||
AND
|
||||
ANY
|
||||
ARCHIVE
|
||||
ARE
|
||||
ARRAY
|
||||
AS
|
||||
ASC
|
||||
ASCII
|
||||
ASENSITIVE
|
||||
ASSERTION
|
||||
ASYMMETRIC
|
||||
AT
|
||||
ATOMIC
|
||||
ATTACH
|
||||
ATTRIBUTE
|
||||
AUTH
|
||||
AUTHORIZATION
|
||||
AUTHORIZE
|
||||
AUTO
|
||||
AVG
|
||||
BACK
|
||||
BACKUP
|
||||
BASE
|
||||
BATCH
|
||||
BEFORE
|
||||
BEGIN
|
||||
BETWEEN
|
||||
BIGINT
|
||||
BINARY
|
||||
BIT
|
||||
BLOB
|
||||
BLOCK
|
||||
BOOLEAN
|
||||
BOTH
|
||||
BREADTH
|
||||
BUCKET
|
||||
BULK
|
||||
BY
|
||||
BYTE
|
||||
CALL
|
||||
CALLED
|
||||
CALLING
|
||||
CAPACITY
|
||||
CASCADE
|
||||
CASCADED
|
||||
CASE
|
||||
CAST
|
||||
CATALOG
|
||||
CHAR
|
||||
CHARACTER
|
||||
CHECK
|
||||
CLASS
|
||||
CLOB
|
||||
CLOSE
|
||||
CLUSTER
|
||||
CLUSTERED
|
||||
CLUSTERING
|
||||
CLUSTERS
|
||||
COALESCE
|
||||
COLLATE
|
||||
COLLATION
|
||||
COLLECTION
|
||||
COLUMN
|
||||
COLUMNS
|
||||
COMBINE
|
||||
COMMENT
|
||||
COMMIT
|
||||
COMPACT
|
||||
COMPILE
|
||||
COMPRESS
|
||||
CONDITION
|
||||
CONFLICT
|
||||
CONNECT
|
||||
CONNECTION
|
||||
CONSISTENCY
|
||||
CONSISTENT
|
||||
CONSTRAINT
|
||||
CONSTRAINTS
|
||||
CONSTRUCTOR
|
||||
CONSUMED
|
||||
CONTINUE
|
||||
CONVERT
|
||||
COPY
|
||||
CORRESPONDING
|
||||
COUNT
|
||||
COUNTER
|
||||
CREATE
|
||||
CROSS
|
||||
CUBE
|
||||
CURRENT
|
||||
CURSOR
|
||||
CYCLE
|
||||
DATA
|
||||
DATABASE
|
||||
DATE
|
||||
DATETIME
|
||||
DAY
|
||||
DEALLOCATE
|
||||
DEC
|
||||
DECIMAL
|
||||
DECLARE
|
||||
DEFAULT
|
||||
DEFERRABLE
|
||||
DEFERRED
|
||||
DEFINE
|
||||
DEFINED
|
||||
DEFINITION
|
||||
DELETE
|
||||
DELIMITED
|
||||
DEPTH
|
||||
DEREF
|
||||
DESC
|
||||
DESCRIBE
|
||||
DESCRIPTOR
|
||||
DETACH
|
||||
DETERMINISTIC
|
||||
DIAGNOSTICS
|
||||
DIRECTORIES
|
||||
DISABLE
|
||||
DISCONNECT
|
||||
DISTINCT
|
||||
DISTRIBUTE
|
||||
DO
|
||||
DOMAIN
|
||||
DOUBLE
|
||||
DROP
|
||||
DUMP
|
||||
DURATION
|
||||
DYNAMIC
|
||||
EACH
|
||||
ELEMENT
|
||||
ELSE
|
||||
ELSEIF
|
||||
EMPTY
|
||||
ENABLE
|
||||
END
|
||||
EQUAL
|
||||
EQUALS
|
||||
ERROR
|
||||
ESCAPE
|
||||
ESCAPED
|
||||
EVAL
|
||||
EVALUATE
|
||||
EXCEEDED
|
||||
EXCEPT
|
||||
EXCEPTION
|
||||
EXCEPTIONS
|
||||
EXCLUSIVE
|
||||
EXEC
|
||||
EXECUTE
|
||||
EXISTS
|
||||
EXIT
|
||||
EXPLAIN
|
||||
EXPLODE
|
||||
EXPORT
|
||||
EXPRESSION
|
||||
EXTENDED
|
||||
EXTERNAL
|
||||
EXTRACT
|
||||
FAIL
|
||||
FALSE
|
||||
FAMILY
|
||||
FETCH
|
||||
FIELDS
|
||||
FILE
|
||||
FILTER
|
||||
FILTERING
|
||||
FINAL
|
||||
FINISH
|
||||
FIRST
|
||||
FIXED
|
||||
FLATTERN
|
||||
FLOAT
|
||||
FOR
|
||||
FORCE
|
||||
FOREIGN
|
||||
FORMAT
|
||||
FORWARD
|
||||
FOUND
|
||||
FREE
|
||||
FROM
|
||||
FULL
|
||||
FUNCTION
|
||||
FUNCTIONS
|
||||
GENERAL
|
||||
GENERATE
|
||||
GET
|
||||
GLOB
|
||||
GLOBAL
|
||||
GO
|
||||
GOTO
|
||||
GRANT
|
||||
GREATER
|
||||
GROUP
|
||||
GROUPING
|
||||
HANDLER
|
||||
HASH
|
||||
HAVE
|
||||
HAVING
|
||||
HEAP
|
||||
HIDDEN
|
||||
HOLD
|
||||
HOUR
|
||||
IDENTIFIED
|
||||
IDENTITY
|
||||
IF
|
||||
IGNORE
|
||||
IMMEDIATE
|
||||
IMPORT
|
||||
IN
|
||||
INCLUDING
|
||||
INCLUSIVE
|
||||
INCREMENT
|
||||
INCREMENTAL
|
||||
INDEX
|
||||
INDEXED
|
||||
INDEXES
|
||||
INDICATOR
|
||||
INFINITE
|
||||
INITIALLY
|
||||
INLINE
|
||||
INNER
|
||||
INNTER
|
||||
INOUT
|
||||
INPUT
|
||||
INSENSITIVE
|
||||
INSERT
|
||||
INSTEAD
|
||||
INT
|
||||
INTEGER
|
||||
INTERSECT
|
||||
INTERVAL
|
||||
INTO
|
||||
INVALIDATE
|
||||
IS
|
||||
ISOLATION
|
||||
ITEM
|
||||
ITEMS
|
||||
ITERATE
|
||||
JOIN
|
||||
KEY
|
||||
KEYS
|
||||
LAG
|
||||
LANGUAGE
|
||||
LARGE
|
||||
LAST
|
||||
LATERAL
|
||||
LEAD
|
||||
LEADING
|
||||
LEAVE
|
||||
LEFT
|
||||
LENGTH
|
||||
LESS
|
||||
LEVEL
|
||||
LIKE
|
||||
LIMIT
|
||||
LIMITED
|
||||
LINES
|
||||
LIST
|
||||
LOAD
|
||||
LOCAL
|
||||
LOCALTIME
|
||||
LOCALTIMESTAMP
|
||||
LOCATION
|
||||
LOCATOR
|
||||
LOCK
|
||||
LOCKS
|
||||
LOG
|
||||
LOGED
|
||||
LONG
|
||||
LOOP
|
||||
LOWER
|
||||
MAP
|
||||
MATCH
|
||||
MATERIALIZED
|
||||
MAX
|
||||
MAXLEN
|
||||
MEMBER
|
||||
MERGE
|
||||
METHOD
|
||||
METRICS
|
||||
MIN
|
||||
MINUS
|
||||
MINUTE
|
||||
MISSING
|
||||
MOD
|
||||
MODE
|
||||
MODIFIES
|
||||
MODIFY
|
||||
MODULE
|
||||
MONTH
|
||||
MULTI
|
||||
MULTISET
|
||||
NAME
|
||||
NAMES
|
||||
NATIONAL
|
||||
NATURAL
|
||||
NCHAR
|
||||
NCLOB
|
||||
NEW
|
||||
NEXT
|
||||
NO
|
||||
NONE
|
||||
NOT
|
||||
NULL
|
||||
NULLIF
|
||||
NUMBER
|
||||
NUMERIC
|
||||
OBJECT
|
||||
OF
|
||||
OFFLINE
|
||||
OFFSET
|
||||
OLD
|
||||
ON
|
||||
ONLINE
|
||||
ONLY
|
||||
OPAQUE
|
||||
OPEN
|
||||
OPERATOR
|
||||
OPTION
|
||||
OR
|
||||
ORDER
|
||||
ORDINALITY
|
||||
OTHER
|
||||
OTHERS
|
||||
OUT
|
||||
OUTER
|
||||
OUTPUT
|
||||
OVER
|
||||
OVERLAPS
|
||||
OVERRIDE
|
||||
OWNER
|
||||
PAD
|
||||
PARALLEL
|
||||
PARAMETER
|
||||
PARAMETERS
|
||||
PARTIAL
|
||||
PARTITION
|
||||
PARTITIONED
|
||||
PARTITIONS
|
||||
PATH
|
||||
PERCENT
|
||||
PERCENTILE
|
||||
PERMISSION
|
||||
PERMISSIONS
|
||||
PIPE
|
||||
PIPELINED
|
||||
PLAN
|
||||
POOL
|
||||
POSITION
|
||||
PRECISION
|
||||
PREPARE
|
||||
PRESERVE
|
||||
PRIMARY
|
||||
PRIOR
|
||||
PRIVATE
|
||||
PRIVILEGES
|
||||
PROCEDURE
|
||||
PROCESSED
|
||||
PROJECT
|
||||
PROJECTION
|
||||
PROPERTY
|
||||
PROVISIONING
|
||||
PUBLIC
|
||||
PUT
|
||||
QUERY
|
||||
QUIT
|
||||
QUORUM
|
||||
RAISE
|
||||
RANDOM
|
||||
RANGE
|
||||
RANK
|
||||
RAW
|
||||
READ
|
||||
READS
|
||||
REAL
|
||||
REBUILD
|
||||
RECORD
|
||||
RECURSIVE
|
||||
REDUCE
|
||||
REF
|
||||
REFERENCE
|
||||
REFERENCES
|
||||
REFERENCING
|
||||
REGEXP
|
||||
REGION
|
||||
REINDEX
|
||||
RELATIVE
|
||||
RELEASE
|
||||
REMAINDER
|
||||
RENAME
|
||||
REPEAT
|
||||
REPLACE
|
||||
REQUEST
|
||||
RESET
|
||||
RESIGNAL
|
||||
RESOURCE
|
||||
RESPONSE
|
||||
RESTORE
|
||||
RESTRICT
|
||||
RESULT
|
||||
RETURN
|
||||
RETURNING
|
||||
RETURNS
|
||||
REVERSE
|
||||
REVOKE
|
||||
RIGHT
|
||||
ROLE
|
||||
ROLES
|
||||
ROLLBACK
|
||||
ROLLUP
|
||||
ROUTINE
|
||||
ROW
|
||||
ROWS
|
||||
RULE
|
||||
RULES
|
||||
SAMPLE
|
||||
SATISFIES
|
||||
SAVE
|
||||
SAVEPOINT
|
||||
SCAN
|
||||
SCHEMA
|
||||
SCOPE
|
||||
SCROLL
|
||||
SEARCH
|
||||
SECOND
|
||||
SECTION
|
||||
SEGMENT
|
||||
SEGMENTS
|
||||
SELECT
|
||||
SELF
|
||||
SEMI
|
||||
SENSITIVE
|
||||
SEPARATE
|
||||
SEQUENCE
|
||||
SERIALIZABLE
|
||||
SESSION
|
||||
SET
|
||||
SETS
|
||||
SHARD
|
||||
SHARE
|
||||
SHARED
|
||||
SHORT
|
||||
SHOW
|
||||
SIGNAL
|
||||
SIMILAR
|
||||
SIZE
|
||||
SKEWED
|
||||
SMALLINT
|
||||
SNAPSHOT
|
||||
SOME
|
||||
SOURCE
|
||||
SPACE
|
||||
SPACES
|
||||
SPARSE
|
||||
SPECIFIC
|
||||
SPECIFICTYPE
|
||||
SPLIT
|
||||
SQL
|
||||
SQLCODE
|
||||
SQLERROR
|
||||
SQLEXCEPTION
|
||||
SQLSTATE
|
||||
SQLWARNING
|
||||
START
|
||||
STATE
|
||||
STATIC
|
||||
STATUS
|
||||
STORAGE
|
||||
STORE
|
||||
STORED
|
||||
STREAM
|
||||
STRING
|
||||
STRUCT
|
||||
STYLE
|
||||
SUB
|
||||
SUBMULTISET
|
||||
SUBPARTITION
|
||||
SUBSTRING
|
||||
SUBTYPE
|
||||
SUM
|
||||
SUPER
|
||||
SYMMETRIC
|
||||
SYNONYM
|
||||
SYSTEM
|
||||
TABLE
|
||||
TABLESAMPLE
|
||||
TEMP
|
||||
TEMPORARY
|
||||
TERMINATED
|
||||
TEXT
|
||||
THAN
|
||||
THEN
|
||||
THROUGHPUT
|
||||
TIME
|
||||
TIMESTAMP
|
||||
TIMEZONE
|
||||
TINYINT
|
||||
TO
|
||||
TOKEN
|
||||
TOTAL
|
||||
TOUCH
|
||||
TRAILING
|
||||
TRANSACTION
|
||||
TRANSFORM
|
||||
TRANSLATE
|
||||
TRANSLATION
|
||||
TREAT
|
||||
TRIGGER
|
||||
TRIM
|
||||
TRUE
|
||||
TRUNCATE
|
||||
TTL
|
||||
TUPLE
|
||||
TYPE
|
||||
UNDER
|
||||
UNDO
|
||||
UNION
|
||||
UNIQUE
|
||||
UNIT
|
||||
UNKNOWN
|
||||
UNLOGGED
|
||||
UNNEST
|
||||
UNPROCESSED
|
||||
UNSIGNED
|
||||
UNTIL
|
||||
UPDATE
|
||||
UPPER
|
||||
URL
|
||||
USAGE
|
||||
USE
|
||||
USER
|
||||
USERS
|
||||
USING
|
||||
UUID
|
||||
VACUUM
|
||||
VALUE
|
||||
VALUED
|
||||
VALUES
|
||||
VARCHAR
|
||||
VARIABLE
|
||||
VARIANCE
|
||||
VARINT
|
||||
VARYING
|
||||
VIEW
|
||||
VIEWS
|
||||
VIRTUAL
|
||||
VOID
|
||||
WAIT
|
||||
WHEN
|
||||
WHENEVER
|
||||
WHERE
|
||||
WHILE
|
||||
WINDOW
|
||||
WITH
|
||||
WITHIN
|
||||
WITHOUT
|
||||
WORK
|
||||
WRAPPED
|
||||
WRITE
|
||||
YEAR
|
||||
ZONE
|
||||
223
moto/dynamodb2/parsing/tokens.py
Normal file
223
moto/dynamodb2/parsing/tokens.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
import re
|
||||
import sys
|
||||
|
||||
from moto.dynamodb2.exceptions import (
|
||||
InvalidTokenException,
|
||||
InvalidExpressionAttributeNameKey,
|
||||
)
|
||||
|
||||
|
||||
class Token(object):
|
||||
_TOKEN_INSTANCE = None
|
||||
MINUS_SIGN = "-"
|
||||
PLUS_SIGN = "+"
|
||||
SPACE_SIGN = " "
|
||||
EQUAL_SIGN = "="
|
||||
OPEN_ROUND_BRACKET = "("
|
||||
CLOSE_ROUND_BRACKET = ")"
|
||||
COMMA = ","
|
||||
SPACE = " "
|
||||
DOT = "."
|
||||
OPEN_SQUARE_BRACKET = "["
|
||||
CLOSE_SQUARE_BRACKET = "]"
|
||||
|
||||
SPECIAL_CHARACTERS = [
|
||||
MINUS_SIGN,
|
||||
PLUS_SIGN,
|
||||
SPACE_SIGN,
|
||||
EQUAL_SIGN,
|
||||
OPEN_ROUND_BRACKET,
|
||||
CLOSE_ROUND_BRACKET,
|
||||
COMMA,
|
||||
SPACE,
|
||||
DOT,
|
||||
OPEN_SQUARE_BRACKET,
|
||||
CLOSE_SQUARE_BRACKET,
|
||||
]
|
||||
|
||||
# Attribute: an identifier that is an attribute
|
||||
ATTRIBUTE = 0
|
||||
# Place holder for attribute name
|
||||
ATTRIBUTE_NAME = 1
|
||||
# Placeholder for attribute value starts with :
|
||||
ATTRIBUTE_VALUE = 2
|
||||
# WhiteSpace shall be grouped together
|
||||
WHITESPACE = 3
|
||||
# Placeholder for a number
|
||||
NUMBER = 4
|
||||
|
||||
PLACEHOLDER_NAMES = {
|
||||
ATTRIBUTE: "Attribute",
|
||||
ATTRIBUTE_NAME: "AttributeName",
|
||||
ATTRIBUTE_VALUE: "AttributeValue",
|
||||
WHITESPACE: "Whitespace",
|
||||
NUMBER: "Number",
|
||||
}
|
||||
|
||||
def __init__(self, token_type, value):
|
||||
assert (
|
||||
token_type in self.SPECIAL_CHARACTERS
|
||||
or token_type in self.PLACEHOLDER_NAMES
|
||||
)
|
||||
self.type = token_type
|
||||
self.value = value
|
||||
|
||||
def __repr__(self):
|
||||
if isinstance(self.type, int):
|
||||
return 'Token("{tt}", "{tv}")'.format(
|
||||
tt=self.PLACEHOLDER_NAMES[self.type], tv=self.value
|
||||
)
|
||||
else:
|
||||
return 'Token("{tt}", "{tv}")'.format(tt=self.type, tv=self.value)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.type == other.type and self.value == other.value
|
||||
|
||||
|
||||
class ExpressionTokenizer(object):
|
||||
"""
|
||||
Takes a string and returns a list of tokens. While attribute names in DynamoDB must be between 1 and 255 characters
|
||||
long there are no other restrictions for attribute names. For expressions however there are additional rules. If an
|
||||
attribute name does not adhere then it must be passed via an ExpressionAttributeName. This tokenizer is aware of the
|
||||
rules of Expression attributes.
|
||||
|
||||
We consider a Token as a tuple which has the tokenType
|
||||
|
||||
From https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
|
||||
1) If an attribute name begins with a number or contains a space, a special character, or a reserved word, you
|
||||
must use an expression attribute name to replace that attribute's name in the expression.
|
||||
=> So spaces,+,- or other special characters do identify tokens in update expressions
|
||||
|
||||
2) When using a dot (.) in an attribute name you must use expression-attribute-names. A dot in an expression
|
||||
will be interpreted as a separator in a document path
|
||||
|
||||
3) For a nested structure if you want to use expression_attribute_names you must specify one per part of the
|
||||
path. Since for members of expression_attribute_names the . is part of the name
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_simple_token_character(cls, character):
|
||||
return character.isalnum() or character in ("_", ":", "#")
|
||||
|
||||
@classmethod
|
||||
def is_possible_token_boundary(cls, character):
|
||||
return (
|
||||
character in Token.SPECIAL_CHARACTERS
|
||||
or not cls.is_simple_token_character(character)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_expression_attribute(cls, input_string):
|
||||
return re.compile("^[a-zA-Z][a-zA-Z0-9_]*$").match(input_string) is not None
|
||||
|
||||
@classmethod
|
||||
def is_expression_attribute_name(cls, input_string):
|
||||
"""
|
||||
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
|
||||
An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric
|
||||
characters.
|
||||
"""
|
||||
return input_string.startswith("#") and cls.is_expression_attribute(
|
||||
input_string[1:]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_expression_attribute_value(cls, input_string):
|
||||
return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None
|
||||
|
||||
def raise_unexpected_token(self):
|
||||
"""If during parsing an unexpected token is encountered"""
|
||||
if len(self.token_list) == 0:
|
||||
near = ""
|
||||
else:
|
||||
if len(self.token_list) == 1:
|
||||
near = self.token_list[-1].value
|
||||
else:
|
||||
if self.token_list[-1].type == Token.WHITESPACE:
|
||||
# Last token was whitespace take 2nd last token value as well to help User orientate
|
||||
near = self.token_list[-2].value + self.token_list[-1].value
|
||||
else:
|
||||
near = self.token_list[-1].value
|
||||
|
||||
problematic_token = self.staged_characters[0]
|
||||
raise InvalidTokenException(problematic_token, near + self.staged_characters)
|
||||
|
||||
def __init__(self, input_expression_str):
|
||||
self.input_expression_str = input_expression_str
|
||||
self.token_list = []
|
||||
self.staged_characters = ""
|
||||
|
||||
@classmethod
|
||||
def is_py2(cls):
|
||||
return sys.version_info[0] == 2
|
||||
|
||||
@classmethod
|
||||
def make_list(cls, input_expression_str):
|
||||
if cls.is_py2():
|
||||
pass
|
||||
else:
|
||||
assert isinstance(input_expression_str, str)
|
||||
|
||||
return ExpressionTokenizer(input_expression_str)._make_list()
|
||||
|
||||
def add_token(self, token_type, token_value):
|
||||
self.token_list.append(Token(token_type, token_value))
|
||||
|
||||
def add_token_from_stage(self, token_type):
|
||||
self.add_token(token_type, self.staged_characters)
|
||||
self.staged_characters = ""
|
||||
|
||||
@classmethod
|
||||
def is_numeric(cls, input_str):
|
||||
return re.compile("[0-9]+").match(input_str) is not None
|
||||
|
||||
def process_staged_characters(self):
|
||||
if len(self.staged_characters) == 0:
|
||||
return
|
||||
if self.staged_characters.startswith("#"):
|
||||
if self.is_expression_attribute_name(self.staged_characters):
|
||||
self.add_token_from_stage(Token.ATTRIBUTE_NAME)
|
||||
else:
|
||||
raise InvalidExpressionAttributeNameKey(self.staged_characters)
|
||||
elif self.is_numeric(self.staged_characters):
|
||||
self.add_token_from_stage(Token.NUMBER)
|
||||
elif self.is_expression_attribute(self.staged_characters):
|
||||
self.add_token_from_stage(Token.ATTRIBUTE)
|
||||
elif self.is_expression_attribute_value(self.staged_characters):
|
||||
self.add_token_from_stage(Token.ATTRIBUTE_VALUE)
|
||||
else:
|
||||
self.raise_unexpected_token()
|
||||
|
||||
def _make_list(self):
|
||||
"""
|
||||
Just go through characters if a character is not a token boundary stage it for adding it as a grouped token
|
||||
later if it is a tokenboundary process staged characters and then process the token boundary as well.
|
||||
"""
|
||||
for character in self.input_expression_str:
|
||||
if not self.is_possible_token_boundary(character):
|
||||
self.staged_characters += character
|
||||
else:
|
||||
self.process_staged_characters()
|
||||
|
||||
if character == Token.SPACE:
|
||||
if (
|
||||
len(self.token_list) > 0
|
||||
and self.token_list[-1].type == Token.WHITESPACE
|
||||
):
|
||||
self.token_list[-1].value = (
|
||||
self.token_list[-1].value + character
|
||||
)
|
||||
else:
|
||||
self.add_token(Token.WHITESPACE, character)
|
||||
elif character in Token.SPECIAL_CHARACTERS:
|
||||
self.add_token(character, character)
|
||||
elif not self.is_simple_token_character(character):
|
||||
self.staged_characters += character
|
||||
self.raise_unexpected_token()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Encountered character which was not implemented : " + character
|
||||
)
|
||||
self.process_staged_characters()
|
||||
return self.token_list
|
||||
341
moto/dynamodb2/parsing/validators.py
Normal file
341
moto/dynamodb2/parsing/validators.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
"""
|
||||
See docstring class Validator below for more details on validation
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
|
||||
from moto.dynamodb2.exceptions import (
|
||||
AttributeIsReservedKeyword,
|
||||
ExpressionAttributeValueNotDefined,
|
||||
AttributeDoesNotExist,
|
||||
ExpressionAttributeNameNotDefined,
|
||||
IncorrectOperandType,
|
||||
InvalidUpdateExpressionInvalidDocumentPath,
|
||||
)
|
||||
from moto.dynamodb2.models import DynamoType
|
||||
from moto.dynamodb2.parsing.ast_nodes import (
|
||||
ExpressionAttribute,
|
||||
UpdateExpressionPath,
|
||||
UpdateExpressionSetAction,
|
||||
UpdateExpressionAddAction,
|
||||
UpdateExpressionDeleteAction,
|
||||
UpdateExpressionRemoveAction,
|
||||
DDBTypedValue,
|
||||
ExpressionAttributeValue,
|
||||
ExpressionAttributeName,
|
||||
DepthFirstTraverser,
|
||||
NoneExistingPath,
|
||||
UpdateExpressionFunction,
|
||||
ExpressionPathDescender,
|
||||
UpdateExpressionValue,
|
||||
ExpressionValueOperator,
|
||||
ExpressionSelector,
|
||||
)
|
||||
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
|
||||
|
||||
|
||||
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
|
||||
def __init__(self, expression_attribute_values):
|
||||
self.expression_attribute_values = expression_attribute_values
|
||||
|
||||
def _processing_map(self):
|
||||
return {
|
||||
ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
|
||||
}
|
||||
|
||||
def replace_expression_attribute_value_with_value(self, node):
|
||||
"""A node representing an Expression Attribute Value. Resolve and replace value"""
|
||||
assert isinstance(node, ExpressionAttributeValue)
|
||||
attribute_value_name = node.get_value_name()
|
||||
try:
|
||||
target = self.expression_attribute_values[attribute_value_name]
|
||||
except KeyError:
|
||||
raise ExpressionAttributeValueNotDefined(
|
||||
attribute_value=attribute_value_name
|
||||
)
|
||||
return DDBTypedValue(DynamoType(target))
|
||||
|
||||
|
||||
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
|
||||
def _processing_map(self):
|
||||
return {
|
||||
UpdateExpressionSetAction: self.disable_resolving,
|
||||
UpdateExpressionPath: self.process_expression_path_node,
|
||||
}
|
||||
|
||||
def __init__(self, expression_attribute_names, item):
|
||||
self.expression_attribute_names = expression_attribute_names
|
||||
self.item = item
|
||||
self.resolving = False
|
||||
|
||||
def pre_processing_of_child(self, parent_node, child_id):
|
||||
"""
|
||||
We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first.
|
||||
Because first argument is path to be set, 2nd argument would be the value.
|
||||
"""
|
||||
if isinstance(
|
||||
parent_node,
|
||||
(
|
||||
UpdateExpressionSetAction,
|
||||
UpdateExpressionRemoveAction,
|
||||
UpdateExpressionDeleteAction,
|
||||
UpdateExpressionAddAction,
|
||||
),
|
||||
):
|
||||
if child_id == 0:
|
||||
self.resolving = False
|
||||
else:
|
||||
self.resolving = True
|
||||
|
||||
def disable_resolving(self, node=None):
|
||||
self.resolving = False
|
||||
return node
|
||||
|
||||
def process_expression_path_node(self, node):
|
||||
"""Resolve ExpressionAttribute if not part of a path and resolving is enabled."""
|
||||
if self.resolving:
|
||||
return self.resolve_expression_path(node)
|
||||
else:
|
||||
# Still resolve but return original note to make sure path is correct Just make sure nodes are creatable.
|
||||
result_node = self.resolve_expression_path(node)
|
||||
if (
|
||||
isinstance(result_node, NoneExistingPath)
|
||||
and not result_node.is_creatable()
|
||||
):
|
||||
raise InvalidUpdateExpressionInvalidDocumentPath()
|
||||
|
||||
return node
|
||||
|
||||
def resolve_expression_path(self, node):
|
||||
assert isinstance(node, UpdateExpressionPath)
|
||||
|
||||
target = deepcopy(self.item.attrs)
|
||||
for child in node.children:
|
||||
# First replace placeholder with attribute_name
|
||||
attr_name = None
|
||||
if isinstance(child, ExpressionAttributeName):
|
||||
attr_placeholder = child.get_attribute_name_placeholder()
|
||||
try:
|
||||
attr_name = self.expression_attribute_names[attr_placeholder]
|
||||
except KeyError:
|
||||
raise ExpressionAttributeNameNotDefined(attr_placeholder)
|
||||
elif isinstance(child, ExpressionAttribute):
|
||||
attr_name = child.get_attribute_name()
|
||||
self.raise_exception_if_keyword(attr_name)
|
||||
if attr_name is not None:
|
||||
# Resolv attribute_name
|
||||
try:
|
||||
target = target[attr_name]
|
||||
except (KeyError, TypeError):
|
||||
if child == node.children[-1]:
|
||||
return NoneExistingPath(creatable=True)
|
||||
return NoneExistingPath()
|
||||
else:
|
||||
if isinstance(child, ExpressionPathDescender):
|
||||
continue
|
||||
elif isinstance(child, ExpressionSelector):
|
||||
index = child.get_index()
|
||||
if target.is_list():
|
||||
try:
|
||||
target = target[index]
|
||||
except IndexError:
|
||||
# When a list goes out of bounds when assigning that is no problem when at the assignment
|
||||
# side. It will just append to the list.
|
||||
if child == node.children[-1]:
|
||||
return NoneExistingPath(creatable=True)
|
||||
return NoneExistingPath()
|
||||
else:
|
||||
raise InvalidUpdateExpressionInvalidDocumentPath
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Path resolution for {t}".format(t=type(child))
|
||||
)
|
||||
return DDBTypedValue(DynamoType(target))
|
||||
|
||||
@classmethod
|
||||
def raise_exception_if_keyword(cls, attribute):
|
||||
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
|
||||
raise AttributeIsReservedKeyword(attribute)
|
||||
|
||||
|
||||
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
|
||||
"""
|
||||
At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
|
||||
expression as per the official AWS docs:
|
||||
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
|
||||
Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
|
||||
"""
|
||||
|
||||
def _processing_map(self):
|
||||
return {UpdateExpressionFunction: self.process_function}
|
||||
|
||||
def process_function(self, node):
|
||||
assert isinstance(node, UpdateExpressionFunction)
|
||||
function_name = node.get_function_name()
|
||||
first_arg = node.get_nth_argument(1)
|
||||
second_arg = node.get_nth_argument(2)
|
||||
|
||||
if function_name == "if_not_exists":
|
||||
if isinstance(first_arg, NoneExistingPath):
|
||||
result = second_arg
|
||||
else:
|
||||
result = first_arg
|
||||
assert isinstance(result, (DDBTypedValue, NoneExistingPath))
|
||||
return result
|
||||
elif function_name == "list_append":
|
||||
first_arg = self.get_list_from_ddb_typed_value(first_arg, function_name)
|
||||
second_arg = self.get_list_from_ddb_typed_value(second_arg, function_name)
|
||||
for list_element in second_arg.value:
|
||||
first_arg.value.append(list_element)
|
||||
return DDBTypedValue(first_arg)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unsupported function for moto {name}".format(name=function_name)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_list_from_ddb_typed_value(cls, node, function_name):
|
||||
assert isinstance(node, DDBTypedValue)
|
||||
dynamo_value = node.get_value()
|
||||
assert isinstance(dynamo_value, DynamoType)
|
||||
if not dynamo_value.is_list():
|
||||
raise IncorrectOperandType(function_name, dynamo_value.type)
|
||||
return dynamo_value
|
||||
|
||||
|
||||
class NoneExistingPathChecker(DepthFirstTraverser):
|
||||
"""
|
||||
Pass through the AST and make sure there are no none-existing paths.
|
||||
"""
|
||||
|
||||
def _processing_map(self):
|
||||
return {NoneExistingPath: self.raise_none_existing_path}
|
||||
|
||||
def raise_none_existing_path(self, node):
|
||||
raise AttributeDoesNotExist
|
||||
|
||||
|
||||
class ExecuteOperations(DepthFirstTraverser):
|
||||
def _processing_map(self):
|
||||
return {UpdateExpressionValue: self.process_update_expression_value}
|
||||
|
||||
def process_update_expression_value(self, node):
|
||||
"""
|
||||
If an UpdateExpressionValue only has a single child the node will be replaced with the childe.
|
||||
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
|
||||
Args:
|
||||
node(Node):
|
||||
|
||||
Returns:
|
||||
Node: The resulting node of the operation if present or the child.
|
||||
"""
|
||||
assert isinstance(node, UpdateExpressionValue)
|
||||
if len(node.children) == 1:
|
||||
return node.children[0]
|
||||
elif len(node.children) == 3:
|
||||
operator_node = node.children[1]
|
||||
assert isinstance(operator_node, ExpressionValueOperator)
|
||||
operator = operator_node.get_operator()
|
||||
left_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[0])
|
||||
right_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[2])
|
||||
if operator == "+":
|
||||
return self.get_sum(left_operand, right_operand)
|
||||
elif operator == "-":
|
||||
return self.get_subtraction(left_operand, right_operand)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Moto does not support operator {operator}".format(
|
||||
operator=operator
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"UpdateExpressionValue only has implementations for 1 or 3 children."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_dynamo_value_from_ddb_typed_value(cls, node):
|
||||
assert isinstance(node, DDBTypedValue)
|
||||
dynamo_value = node.get_value()
|
||||
assert isinstance(dynamo_value, DynamoType)
|
||||
return dynamo_value
|
||||
|
||||
@classmethod
|
||||
def get_sum(cls, left_operand, right_operand):
|
||||
"""
|
||||
Args:
|
||||
left_operand(DynamoType):
|
||||
right_operand(DynamoType):
|
||||
|
||||
Returns:
|
||||
DDBTypedValue:
|
||||
"""
|
||||
try:
|
||||
return DDBTypedValue(left_operand + right_operand)
|
||||
except TypeError:
|
||||
raise IncorrectOperandType("+", left_operand.type)
|
||||
|
||||
@classmethod
|
||||
def get_subtraction(cls, left_operand, right_operand):
|
||||
"""
|
||||
Args:
|
||||
left_operand(DynamoType):
|
||||
right_operand(DynamoType):
|
||||
|
||||
Returns:
|
||||
DDBTypedValue:
|
||||
"""
|
||||
try:
|
||||
return DDBTypedValue(left_operand - right_operand)
|
||||
except TypeError:
|
||||
raise IncorrectOperandType("-", left_operand.type)
|
||||
|
||||
|
||||
class Validator(object):
|
||||
"""
|
||||
A validator is used to validate expressions which are passed in as an AST.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, expression, expression_attribute_names, expression_attribute_values, item
|
||||
):
|
||||
"""
|
||||
Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
|
||||
validation.
|
||||
|
||||
Args:
|
||||
expression(Node): The root node of the AST representing the expression to be validated
|
||||
expression_attribute_names(ExpressionAttributeNames):
|
||||
expression_attribute_values(ExpressionAttributeValues):
|
||||
item(Item): The item which will be updated (pointed to by Key of update_item)
|
||||
"""
|
||||
self.expression_attribute_names = expression_attribute_names
|
||||
self.expression_attribute_values = expression_attribute_values
|
||||
self.item = item
|
||||
self.processors = self.get_ast_processors()
|
||||
self.node_to_validate = deepcopy(expression)
|
||||
|
||||
@abstractmethod
|
||||
def get_ast_processors(self):
|
||||
"""Get the different processors that go through the AST tree and processes the nodes."""
|
||||
|
||||
def validate(self):
|
||||
n = self.node_to_validate
|
||||
for processor in self.processors:
|
||||
n = processor.traverse(n)
|
||||
return n
|
||||
|
||||
|
||||
class UpdateExpressionValidator(Validator):
|
||||
def get_ast_processors(self):
|
||||
"""Get the different processors that go through the AST tree and processes the nodes."""
|
||||
processors = [
|
||||
ExpressionAttributeValueProcessor(self.expression_attribute_values),
|
||||
ExpressionAttributeResolvingProcessor(
|
||||
self.expression_attribute_names, self.item
|
||||
),
|
||||
UpdateExpressionFunctionEvaluator(),
|
||||
NoneExistingPathChecker(),
|
||||
ExecuteOperations(),
|
||||
]
|
||||
return processors
|
||||
|
|
@ -9,8 +9,8 @@ import six
|
|||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.utils import camelcase_to_underscores, amzn_request_id
|
||||
from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge
|
||||
from .models import dynamodb_backends, dynamo_json_dump
|
||||
from .exceptions import InvalidIndexNameError, ItemSizeTooLarge, MockValidationException
|
||||
from moto.dynamodb2.models import dynamodb_backends, dynamo_json_dump
|
||||
|
||||
|
||||
TRANSACTION_MAX_ITEMS = 25
|
||||
|
|
@ -92,16 +92,24 @@ class DynamoHandler(BaseResponse):
|
|||
def list_tables(self):
|
||||
body = self.body
|
||||
limit = body.get("Limit", 100)
|
||||
if body.get("ExclusiveStartTableName"):
|
||||
last = body.get("ExclusiveStartTableName")
|
||||
start = list(self.dynamodb_backend.tables.keys()).index(last) + 1
|
||||
all_tables = list(self.dynamodb_backend.tables.keys())
|
||||
|
||||
exclusive_start_table_name = body.get("ExclusiveStartTableName")
|
||||
if exclusive_start_table_name:
|
||||
try:
|
||||
last_table_index = all_tables.index(exclusive_start_table_name)
|
||||
except ValueError:
|
||||
start = len(all_tables)
|
||||
else:
|
||||
start = last_table_index + 1
|
||||
else:
|
||||
start = 0
|
||||
all_tables = list(self.dynamodb_backend.tables.keys())
|
||||
|
||||
if limit:
|
||||
tables = all_tables[start : start + limit]
|
||||
else:
|
||||
tables = all_tables[start:]
|
||||
|
||||
response = {"TableNames": tables}
|
||||
if limit and len(all_tables) > start + limit:
|
||||
response["LastEvaluatedTableName"] = tables[-1]
|
||||
|
|
@ -298,7 +306,7 @@ class DynamoHandler(BaseResponse):
|
|||
)
|
||||
except ItemSizeTooLarge:
|
||||
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
|
||||
return self.error(er, ItemSizeTooLarge.message)
|
||||
return self.error(er, ItemSizeTooLarge.item_size_too_large_msg)
|
||||
except KeyError as ke:
|
||||
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
|
||||
return self.error(er, ke.args[0])
|
||||
|
|
@ -462,8 +470,10 @@ class DynamoHandler(BaseResponse):
|
|||
for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {}))
|
||||
)
|
||||
|
||||
if " AND " in key_condition_expression:
|
||||
expressions = key_condition_expression.split(" AND ", 1)
|
||||
if " and " in key_condition_expression.lower():
|
||||
expressions = re.split(
|
||||
" AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0]
|
||||
hash_key_var = reverse_attribute_lookup.get(
|
||||
|
|
@ -748,11 +758,6 @@ class DynamoHandler(BaseResponse):
|
|||
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
|
||||
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
|
||||
|
||||
# Support spaces between operators in an update expression
|
||||
# E.g. `a = b + c` -> `a=b+c`
|
||||
if update_expression:
|
||||
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
|
||||
|
||||
try:
|
||||
item = self.dynamodb_backend.update_item(
|
||||
name,
|
||||
|
|
@ -764,15 +769,9 @@ class DynamoHandler(BaseResponse):
|
|||
expected=expected,
|
||||
condition_expression=condition_expression,
|
||||
)
|
||||
except InvalidUpdateExpression:
|
||||
except MockValidationException as mve:
|
||||
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
|
||||
return self.error(
|
||||
er,
|
||||
"The document path provided in the update expression is invalid for update",
|
||||
)
|
||||
except ItemSizeTooLarge:
|
||||
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
|
||||
return self.error(er, ItemSizeTooLarge.message)
|
||||
return self.error(er, mve.exception_msg)
|
||||
except ValueError:
|
||||
er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
|
||||
return self.error(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import base64
|
|||
from boto3 import Session
|
||||
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from moto.dynamodb2.models import dynamodb_backends
|
||||
from moto.dynamodb2.models import dynamodb_backends, DynamoJsonEncoder
|
||||
|
||||
|
||||
class ShardIterator(BaseModel):
|
||||
|
|
@ -137,7 +137,7 @@ class DynamoDBStreamsBackend(BaseBackend):
|
|||
|
||||
def get_records(self, iterator_arn, limit):
|
||||
shard_iterator = self.shard_iterators[iterator_arn]
|
||||
return json.dumps(shard_iterator.get(limit))
|
||||
return json.dumps(shard_iterator.get(limit), cls=DynamoJsonEncoder)
|
||||
|
||||
|
||||
dynamodbstreams_backends = {}
|
||||
|
|
|
|||
|
|
@ -231,6 +231,14 @@ class InvalidVolumeAttachmentError(EC2ClientError):
|
|||
)
|
||||
|
||||
|
||||
class VolumeInUseError(EC2ClientError):
|
||||
def __init__(self, volume_id, instance_id):
|
||||
super(VolumeInUseError, self).__init__(
|
||||
"VolumeInUse",
|
||||
"Volume {0} is currently attached to {1}".format(volume_id, instance_id),
|
||||
)
|
||||
|
||||
|
||||
class InvalidDomainError(EC2ClientError):
|
||||
def __init__(self, domain):
|
||||
super(InvalidDomainError, self).__init__(
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ from .exceptions import (
|
|||
InvalidSubnetIdError,
|
||||
InvalidSubnetRangeError,
|
||||
InvalidVolumeIdError,
|
||||
VolumeInUseError,
|
||||
InvalidVolumeAttachmentError,
|
||||
InvalidVpcCidrBlockAssociationIdError,
|
||||
InvalidVPCPeeringConnectionIdError,
|
||||
|
|
@ -556,6 +557,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
|
|||
# worst case we'll get IP address exaustion... rarely
|
||||
pass
|
||||
|
||||
def add_block_device(self, size, device_path):
|
||||
volume = self.ec2_backend.create_volume(size, self.region_name)
|
||||
self.ec2_backend.attach_volume(volume.id, self.id, device_path)
|
||||
|
||||
def setup_defaults(self):
|
||||
# Default have an instance with root volume should you not wish to
|
||||
# override with attach volume cmd.
|
||||
|
|
@ -563,9 +568,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
|
|||
self.ec2_backend.attach_volume(volume.id, self.id, "/dev/sda1")
|
||||
|
||||
def teardown_defaults(self):
|
||||
volume_id = self.block_device_mapping["/dev/sda1"].volume_id
|
||||
self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1")
|
||||
self.ec2_backend.delete_volume(volume_id)
|
||||
if "/dev/sda1" in self.block_device_mapping:
|
||||
volume_id = self.block_device_mapping["/dev/sda1"].volume_id
|
||||
self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1")
|
||||
self.ec2_backend.delete_volume(volume_id)
|
||||
|
||||
@property
|
||||
def get_block_device_mapping(self):
|
||||
|
|
@ -620,6 +626,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
|
|||
subnet_id=properties.get("SubnetId"),
|
||||
key_name=properties.get("KeyName"),
|
||||
private_ip=properties.get("PrivateIpAddress"),
|
||||
block_device_mappings=properties.get("BlockDeviceMappings", {}),
|
||||
)
|
||||
instance = reservation.instances[0]
|
||||
for tag in properties.get("Tags", []):
|
||||
|
|
@ -775,7 +782,14 @@ class Instance(TaggedEC2Resource, BotoInstance):
|
|||
if "SubnetId" in nic:
|
||||
subnet = self.ec2_backend.get_subnet(nic["SubnetId"])
|
||||
else:
|
||||
subnet = None
|
||||
# Get default Subnet
|
||||
subnet = [
|
||||
subnet
|
||||
for subnet in self.ec2_backend.get_all_subnets(
|
||||
filters={"availabilityZone": self._placement.zone}
|
||||
)
|
||||
if subnet.default_for_az
|
||||
][0]
|
||||
|
||||
group_id = nic.get("SecurityGroupId")
|
||||
group_ids = [group_id] if group_id else []
|
||||
|
|
@ -872,7 +886,14 @@ class InstanceBackend(object):
|
|||
)
|
||||
new_reservation.instances.append(new_instance)
|
||||
new_instance.add_tags(instance_tags)
|
||||
new_instance.setup_defaults()
|
||||
if "block_device_mappings" in kwargs:
|
||||
for block_device in kwargs["block_device_mappings"]:
|
||||
new_instance.add_block_device(
|
||||
block_device["Ebs"]["VolumeSize"], block_device["DeviceName"]
|
||||
)
|
||||
else:
|
||||
new_instance.setup_defaults()
|
||||
|
||||
return new_reservation
|
||||
|
||||
def start_instances(self, instance_ids):
|
||||
|
|
@ -936,6 +957,12 @@ class InstanceBackend(object):
|
|||
value = getattr(instance, key)
|
||||
return instance, value
|
||||
|
||||
def describe_instance_credit_specifications(self, instance_ids):
|
||||
queried_instances = []
|
||||
for instance in self.get_multi_instances_by_id(instance_ids):
|
||||
queried_instances.append(instance)
|
||||
return queried_instances
|
||||
|
||||
def all_instances(self, filters=None):
|
||||
instances = []
|
||||
for reservation in self.all_reservations():
|
||||
|
|
@ -1498,6 +1525,11 @@ class RegionsAndZonesBackend(object):
|
|||
regions.append(Region(region, "ec2.{}.amazonaws.com.cn".format(region)))
|
||||
|
||||
zones = {
|
||||
"af-south-1": [
|
||||
Zone(region_name="af-south-1", name="af-south-1a", zone_id="afs1-az1"),
|
||||
Zone(region_name="af-south-1", name="af-south-1b", zone_id="afs1-az2"),
|
||||
Zone(region_name="af-south-1", name="af-south-1c", zone_id="afs1-az3"),
|
||||
],
|
||||
"ap-south-1": [
|
||||
Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"),
|
||||
Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3"),
|
||||
|
|
@ -2385,6 +2417,9 @@ class EBSBackend(object):
|
|||
|
||||
def delete_volume(self, volume_id):
|
||||
if volume_id in self.volumes:
|
||||
volume = self.volumes[volume_id]
|
||||
if volume.attachment:
|
||||
raise VolumeInUseError(volume_id, volume.attachment.instance.id)
|
||||
return self.volumes.pop(volume_id)
|
||||
raise InvalidVolumeIdError(volume_id)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -35,6 +35,7 @@ DESCRIBE_ZONES_RESPONSE = """<DescribeAvailabilityZonesResponse xmlns="http://ec
|
|||
<zoneName>{{ zone.name }}</zoneName>
|
||||
<zoneState>available</zoneState>
|
||||
<regionName>{{ zone.region_name }}</regionName>
|
||||
<zoneId>{{ zone.zone_id }}</zoneId>
|
||||
<messageSet/>
|
||||
</item>
|
||||
{% endfor %}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class InstanceResponse(BaseResponse):
|
|||
private_ip = self._get_param("PrivateIpAddress")
|
||||
associate_public_ip = self._get_param("AssociatePublicIpAddress")
|
||||
key_name = self._get_param("KeyName")
|
||||
ebs_optimized = self._get_param("EbsOptimized")
|
||||
ebs_optimized = self._get_param("EbsOptimized") or False
|
||||
instance_initiated_shutdown_behavior = self._get_param(
|
||||
"InstanceInitiatedShutdownBehavior"
|
||||
)
|
||||
|
|
@ -168,6 +168,14 @@ class InstanceResponse(BaseResponse):
|
|||
|
||||
return template.render(instance=instance, attribute=attribute, value=value)
|
||||
|
||||
def describe_instance_credit_specifications(self):
|
||||
instance_ids = self._get_multi_param("InstanceId")
|
||||
instance = self.ec2_backend.describe_instance_credit_specifications(
|
||||
instance_ids
|
||||
)
|
||||
template = self.response_template(EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS)
|
||||
return template.render(instances=instance)
|
||||
|
||||
def modify_instance_attribute(self):
|
||||
handlers = [
|
||||
self._dot_value_instance_attribute_handler,
|
||||
|
|
@ -671,6 +679,18 @@ EC2_DESCRIBE_INSTANCE_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="h
|
|||
</{{ attribute }}>
|
||||
</DescribeInstanceAttributeResponse>"""
|
||||
|
||||
EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS = """<DescribeInstanceCreditSpecificationsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
|
||||
<requestId>1b234b5c-d6ef-7gh8-90i1-j2345678901</requestId>
|
||||
<instanceCreditSpecificationSet>
|
||||
{% for instance in instances %}
|
||||
<item>
|
||||
<instanceId>{{ instance.id }}</instanceId>
|
||||
<cpuCredits>standard</cpuCredits>
|
||||
</item>
|
||||
{% endfor %}
|
||||
</instanceCreditSpecificationSet>
|
||||
</DescribeInstanceCreditSpecificationsResponse>"""
|
||||
|
||||
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
|
||||
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
|
||||
<instanceId>{{ instance.id }}</instanceId>
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ from __future__ import unicode_literals
|
|||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.ec2.models import validate_resource_ids
|
||||
from moto.ec2.utils import tags_from_query_string, filters_from_querystring
|
||||
from moto.ec2.utils import filters_from_querystring
|
||||
from moto.core.utils import tags_from_query_string
|
||||
|
||||
|
||||
class TagResponse(BaseResponse):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@ from __future__ import unicode_literals
|
|||
from .responses import EC2Response
|
||||
|
||||
|
||||
url_bases = ["https?://ec2\.(.+)\.amazonaws\.com(|\.cn)"]
|
||||
url_bases = [r"https?://ec2\.(.+)\.amazonaws\.com(|\.cn)"]
|
||||
|
||||
url_paths = {"{0}/": EC2Response.dispatch}
|
||||
|
|
|
|||
|
|
@ -196,22 +196,6 @@ def split_route_id(route_id):
|
|||
return values[0], values[1]
|
||||
|
||||
|
||||
def tags_from_query_string(querystring_dict):
|
||||
prefix = "Tag"
|
||||
suffix = "Key"
|
||||
response_values = {}
|
||||
for key, value in querystring_dict.items():
|
||||
if key.startswith(prefix) and key.endswith(suffix):
|
||||
tag_index = key.replace(prefix + ".", "").replace("." + suffix, "")
|
||||
tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0]
|
||||
tag_value_key = "Tag.{0}.Value".format(tag_index)
|
||||
if tag_value_key in querystring_dict:
|
||||
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
|
||||
else:
|
||||
response_values[tag_key] = None
|
||||
return response_values
|
||||
|
||||
|
||||
def dhcp_configuration_from_querystring(querystring, option="DhcpConfiguration"):
|
||||
"""
|
||||
turn:
|
||||
|
|
|
|||
|
|
@ -604,7 +604,10 @@ class EC2ContainerServiceBackend(BaseBackend):
|
|||
raise Exception("{0} is not a task_definition".format(task_definition_name))
|
||||
|
||||
def run_task(self, cluster_str, task_definition_str, count, overrides, started_by):
|
||||
cluster_name = cluster_str.split("/")[-1]
|
||||
if cluster_str:
|
||||
cluster_name = cluster_str.split("/")[-1]
|
||||
else:
|
||||
cluster_name = "default"
|
||||
if cluster_name in self.clusters:
|
||||
cluster = self.clusters[cluster_name]
|
||||
else:
|
||||
|
|
|
|||
4
moto/elasticbeanstalk/__init__.py
Normal file
4
moto/elasticbeanstalk/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from .models import eb_backends
|
||||
from moto.core.models import base_decorator
|
||||
|
||||
mock_elasticbeanstalk = base_decorator(eb_backends)
|
||||
15
moto/elasticbeanstalk/exceptions.py
Normal file
15
moto/elasticbeanstalk/exceptions.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from moto.core.exceptions import RESTError
|
||||
|
||||
|
||||
class InvalidParameterValueError(RESTError):
|
||||
def __init__(self, message):
|
||||
super(InvalidParameterValueError, self).__init__(
|
||||
"InvalidParameterValue", message
|
||||
)
|
||||
|
||||
|
||||
class ResourceNotFoundException(RESTError):
|
||||
def __init__(self, message):
|
||||
super(ResourceNotFoundException, self).__init__(
|
||||
"ResourceNotFoundException", message
|
||||
)
|
||||
152
moto/elasticbeanstalk/models.py
Normal file
152
moto/elasticbeanstalk/models.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
import weakref
|
||||
|
||||
from boto3 import Session
|
||||
|
||||
from moto.core import BaseBackend, BaseModel
|
||||
from .exceptions import InvalidParameterValueError, ResourceNotFoundException
|
||||
|
||||
|
||||
class FakeEnvironment(BaseModel):
|
||||
def __init__(
|
||||
self, application, environment_name, solution_stack_name, tags,
|
||||
):
|
||||
self.application = weakref.proxy(
|
||||
application
|
||||
) # weakref to break circular dependencies
|
||||
self.environment_name = environment_name
|
||||
self.solution_stack_name = solution_stack_name
|
||||
self.tags = tags
|
||||
|
||||
@property
|
||||
def application_name(self):
|
||||
return self.application.application_name
|
||||
|
||||
@property
|
||||
def environment_arn(self):
|
||||
return (
|
||||
"arn:aws:elasticbeanstalk:{region}:{account_id}:"
|
||||
"environment/{application_name}/{environment_name}".format(
|
||||
region=self.region,
|
||||
account_id="123456789012",
|
||||
application_name=self.application_name,
|
||||
environment_name=self.environment_name,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def platform_arn(self):
|
||||
return "TODO" # TODO
|
||||
|
||||
@property
|
||||
def region(self):
|
||||
return self.application.region
|
||||
|
||||
|
||||
class FakeApplication(BaseModel):
|
||||
def __init__(self, backend, application_name):
|
||||
self.backend = weakref.proxy(backend) # weakref to break cycles
|
||||
self.application_name = application_name
|
||||
self.environments = dict()
|
||||
|
||||
def create_environment(
|
||||
self, environment_name, solution_stack_name, tags,
|
||||
):
|
||||
if environment_name in self.environments:
|
||||
raise InvalidParameterValueError
|
||||
|
||||
env = FakeEnvironment(
|
||||
application=self,
|
||||
environment_name=environment_name,
|
||||
solution_stack_name=solution_stack_name,
|
||||
tags=tags,
|
||||
)
|
||||
self.environments[environment_name] = env
|
||||
|
||||
return env
|
||||
|
||||
@property
|
||||
def region(self):
|
||||
return self.backend.region
|
||||
|
||||
|
||||
class EBBackend(BaseBackend):
|
||||
def __init__(self, region):
|
||||
self.region = region
|
||||
self.applications = dict()
|
||||
|
||||
def reset(self):
|
||||
# preserve region
|
||||
region = self.region
|
||||
self._reset_model_refs()
|
||||
self.__dict__ = {}
|
||||
self.__init__(region)
|
||||
|
||||
def create_application(self, application_name):
|
||||
if application_name in self.applications:
|
||||
raise InvalidParameterValueError(
|
||||
"Application {} already exists.".format(application_name)
|
||||
)
|
||||
new_app = FakeApplication(backend=self, application_name=application_name,)
|
||||
self.applications[application_name] = new_app
|
||||
return new_app
|
||||
|
||||
def create_environment(self, app, environment_name, stack_name, tags):
|
||||
return app.create_environment(
|
||||
environment_name=environment_name,
|
||||
solution_stack_name=stack_name,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
def describe_environments(self):
|
||||
envs = []
|
||||
for app in self.applications.values():
|
||||
for env in app.environments.values():
|
||||
envs.append(env)
|
||||
return envs
|
||||
|
||||
def list_available_solution_stacks(self):
|
||||
# Implemented in response.py
|
||||
pass
|
||||
|
||||
def update_tags_for_resource(self, resource_arn, tags_to_add, tags_to_remove):
|
||||
try:
|
||||
res = self._find_environment_by_arn(resource_arn)
|
||||
except KeyError:
|
||||
raise ResourceNotFoundException(
|
||||
"Resource not found for ARN '{}'.".format(resource_arn)
|
||||
)
|
||||
|
||||
for key, value in tags_to_add.items():
|
||||
res.tags[key] = value
|
||||
|
||||
for key in tags_to_remove:
|
||||
del res.tags[key]
|
||||
|
||||
def list_tags_for_resource(self, resource_arn):
|
||||
try:
|
||||
res = self._find_environment_by_arn(resource_arn)
|
||||
except KeyError:
|
||||
raise ResourceNotFoundException(
|
||||
"Resource not found for ARN '{}'.".format(resource_arn)
|
||||
)
|
||||
return res.tags
|
||||
|
||||
def _find_environment_by_arn(self, arn):
|
||||
for app in self.applications.keys():
|
||||
for env in self.applications[app].environments.values():
|
||||
if env.environment_arn == arn:
|
||||
return env
|
||||
raise KeyError()
|
||||
|
||||
|
||||
eb_backends = {}
|
||||
for region in Session().get_available_regions("elasticbeanstalk"):
|
||||
eb_backends[region] = EBBackend(region)
|
||||
for region in Session().get_available_regions(
|
||||
"elasticbeanstalk", partition_name="aws-us-gov"
|
||||
):
|
||||
eb_backends[region] = EBBackend(region)
|
||||
for region in Session().get_available_regions(
|
||||
"elasticbeanstalk", partition_name="aws-cn"
|
||||
):
|
||||
eb_backends[region] = EBBackend(region)
|
||||
1386
moto/elasticbeanstalk/responses.py
Normal file
1386
moto/elasticbeanstalk/responses.py
Normal file
File diff suppressed because it is too large
Load diff
11
moto/elasticbeanstalk/urls.py
Normal file
11
moto/elasticbeanstalk/urls.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from .responses import EBResponse
|
||||
|
||||
url_bases = [
|
||||
r"https?://elasticbeanstalk.(?P<region>[a-zA-Z0-9\-_]+).amazonaws.com",
|
||||
]
|
||||
|
||||
url_paths = {
|
||||
"{0}/$": EBResponse.dispatch,
|
||||
}
|
||||
|
|
@ -1,6 +1,9 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
from boto.ec2.elb.attributes import (
|
||||
LbAttributes,
|
||||
ConnectionSettingAttribute,
|
||||
|
|
@ -83,7 +86,7 @@ class FakeLoadBalancer(BaseModel):
|
|||
self.zones = zones
|
||||
self.listeners = []
|
||||
self.backends = []
|
||||
self.created_time = datetime.datetime.now()
|
||||
self.created_time = datetime.datetime.now(pytz.utc)
|
||||
self.scheme = scheme
|
||||
self.attributes = FakeLoadBalancer.get_default_attributes()
|
||||
self.policies = Policies()
|
||||
|
|
|
|||
|
|
@ -442,7 +442,7 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
|
|||
{% endfor %}
|
||||
</SecurityGroups>
|
||||
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
|
||||
<CreatedTime>{{ load_balancer.created_time }}</CreatedTime>
|
||||
<CreatedTime>{{ load_balancer.created_time.isoformat() }}</CreatedTime>
|
||||
<HealthCheck>
|
||||
{% if load_balancer.health_check %}
|
||||
<Interval>{{ load_balancer.health_check.interval }}</Interval>
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ from six.moves.urllib.parse import urlparse
|
|||
from moto.core.responses import AWSServiceSpec
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.responses import xml_to_json_response
|
||||
from moto.core.utils import tags_from_query_string
|
||||
from .exceptions import EmrError
|
||||
from .models import emr_backends
|
||||
from .utils import steps_from_query_string, tags_from_query_string
|
||||
from .utils import steps_from_query_string
|
||||
|
||||
|
||||
def generate_boto3_response(operation):
|
||||
|
|
@ -91,7 +92,7 @@ class ElasticMapReduceResponse(BaseResponse):
|
|||
@generate_boto3_response("AddTags")
|
||||
def add_tags(self):
|
||||
cluster_id = self._get_param("ResourceId")
|
||||
tags = tags_from_query_string(self.querystring)
|
||||
tags = tags_from_query_string(self.querystring, prefix="Tags")
|
||||
self.backend.add_tags(cluster_id, tags)
|
||||
template = self.response_template(ADD_TAGS_TEMPLATE)
|
||||
return template.render()
|
||||
|
|
|
|||
|
|
@ -22,22 +22,6 @@ def random_instance_group_id(size=13):
|
|||
return "i-{0}".format(random_id())
|
||||
|
||||
|
||||
def tags_from_query_string(querystring_dict):
|
||||
prefix = "Tags"
|
||||
suffix = "Key"
|
||||
response_values = {}
|
||||
for key, value in querystring_dict.items():
|
||||
if key.startswith(prefix) and key.endswith(suffix):
|
||||
tag_index = key.replace(prefix + ".", "").replace("." + suffix, "")
|
||||
tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0]
|
||||
tag_value_key = "Tags.{0}.Value".format(tag_index)
|
||||
if tag_value_key in querystring_dict:
|
||||
response_values[tag_key] = querystring_dict.get(tag_value_key)[0]
|
||||
else:
|
||||
response_values[tag_key] = None
|
||||
return response_values
|
||||
|
||||
|
||||
def steps_from_query_string(querystring_dict):
|
||||
steps = []
|
||||
for step in querystring_dict:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,10 @@ class Rule(BaseModel):
|
|||
self.role_arn = kwargs.get("RoleArn")
|
||||
self.targets = []
|
||||
|
||||
@property
|
||||
def physical_resource_id(self):
|
||||
return self.name
|
||||
|
||||
# This song and dance for targets is because we need order for Limits and NextTokens, but can't use OrderedDicts
|
||||
# with Python 2.6, so tracking it with an array it is.
|
||||
def _check_target_exists(self, target_id):
|
||||
|
|
@ -59,6 +63,14 @@ class Rule(BaseModel):
|
|||
if index is not None:
|
||||
self.targets.pop(index)
|
||||
|
||||
def get_cfn_attribute(self, attribute_name):
|
||||
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||
|
||||
if attribute_name == "Arn":
|
||||
return self.arn
|
||||
|
||||
raise UnformattedGetAttTemplateException()
|
||||
|
||||
@classmethod
|
||||
def create_from_cloudformation_json(
|
||||
cls, resource_name, cloudformation_json, region_name
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ class GlueBackend(BaseBackend):
|
|||
except KeyError:
|
||||
raise DatabaseNotFoundException(database_name)
|
||||
|
||||
def get_databases(self):
|
||||
return [self.databases[key] for key in self.databases] if self.databases else []
|
||||
|
||||
def create_table(self, database_name, table_name, table_input):
|
||||
database = self.get_database(database_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ class GlueResponse(BaseResponse):
|
|||
database = self.glue_backend.get_database(database_name)
|
||||
return json.dumps({"Database": {"Name": database.name}})
|
||||
|
||||
def get_databases(self):
|
||||
database_list = self.glue_backend.get_databases()
|
||||
return json.dumps(
|
||||
{"DatabaseList": [{"Name": database.name} for database in database_list]}
|
||||
)
|
||||
|
||||
def create_table(self):
|
||||
database_name = self.parameters.get("DatabaseName")
|
||||
table_input = self.parameters.get("TableInput")
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ class IoTClientError(JsonRESTError):
|
|||
|
||||
|
||||
class ResourceNotFoundException(IoTClientError):
|
||||
def __init__(self):
|
||||
def __init__(self, msg=None):
|
||||
self.code = 404
|
||||
super(ResourceNotFoundException, self).__init__(
|
||||
"ResourceNotFoundException", "The specified resource does not exist"
|
||||
"ResourceNotFoundException", msg or "The specified resource does not exist"
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -805,6 +805,14 @@ class IoTBackend(BaseBackend):
|
|||
return thing_names
|
||||
|
||||
def list_thing_principals(self, thing_name):
|
||||
|
||||
things = [_ for _ in self.things.values() if _.thing_name == thing_name]
|
||||
if len(things) == 0:
|
||||
raise ResourceNotFoundException(
|
||||
"Failed to list principals for thing %s because the thing does not exist in your account"
|
||||
% thing_name
|
||||
)
|
||||
|
||||
principals = [
|
||||
k[0] for k, v in self.principal_things.items() if k[1] == thing_name
|
||||
]
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class LogStream:
|
|||
return None, 0
|
||||
|
||||
events = sorted(
|
||||
filter(filter_func, self.events), key=lambda event: event.timestamp,
|
||||
filter(filter_func, self.events), key=lambda event: event.timestamp
|
||||
)
|
||||
|
||||
direction, index = get_index_and_direction_from_token(next_token)
|
||||
|
|
@ -169,11 +169,7 @@ class LogStream:
|
|||
if end_index > final_index:
|
||||
end_index = final_index
|
||||
elif end_index < 0:
|
||||
return (
|
||||
[],
|
||||
"b/{:056d}".format(0),
|
||||
"f/{:056d}".format(0),
|
||||
)
|
||||
return ([], "b/{:056d}".format(0), "f/{:056d}".format(0))
|
||||
|
||||
events_page = [
|
||||
event.to_response_dict() for event in events[start_index : end_index + 1]
|
||||
|
|
@ -219,7 +215,7 @@ class LogStream:
|
|||
|
||||
|
||||
class LogGroup:
|
||||
def __init__(self, region, name, tags):
|
||||
def __init__(self, region, name, tags, **kwargs):
|
||||
self.name = name
|
||||
self.region = region
|
||||
self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format(
|
||||
|
|
@ -228,9 +224,9 @@ class LogGroup:
|
|||
self.creationTime = int(unix_time_millis())
|
||||
self.tags = tags
|
||||
self.streams = dict() # {name: LogStream}
|
||||
self.retentionInDays = (
|
||||
None # AWS defaults to Never Expire for log group retention
|
||||
)
|
||||
self.retention_in_days = kwargs.get(
|
||||
"RetentionInDays"
|
||||
) # AWS defaults to Never Expire for log group retention
|
||||
|
||||
def create_log_stream(self, log_stream_name):
|
||||
if log_stream_name in self.streams:
|
||||
|
|
@ -368,12 +364,12 @@ class LogGroup:
|
|||
"storedBytes": sum(s.storedBytes for s in self.streams.values()),
|
||||
}
|
||||
# AWS only returns retentionInDays if a value is set for the log group (ie. not Never Expire)
|
||||
if self.retentionInDays:
|
||||
log_group["retentionInDays"] = self.retentionInDays
|
||||
if self.retention_in_days:
|
||||
log_group["retentionInDays"] = self.retention_in_days
|
||||
return log_group
|
||||
|
||||
def set_retention_policy(self, retention_in_days):
|
||||
self.retentionInDays = retention_in_days
|
||||
self.retention_in_days = retention_in_days
|
||||
|
||||
def list_tags(self):
|
||||
return self.tags if self.tags else {}
|
||||
|
|
@ -401,10 +397,12 @@ class LogsBackend(BaseBackend):
|
|||
self.__dict__ = {}
|
||||
self.__init__(region_name)
|
||||
|
||||
def create_log_group(self, log_group_name, tags):
|
||||
def create_log_group(self, log_group_name, tags, **kwargs):
|
||||
if log_group_name in self.groups:
|
||||
raise ResourceAlreadyExistsException()
|
||||
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
|
||||
self.groups[log_group_name] = LogGroup(
|
||||
self.region_name, log_group_name, tags, **kwargs
|
||||
)
|
||||
return self.groups[log_group_name]
|
||||
|
||||
def ensure_log_group(self, log_group_name, tags):
|
||||
|
|
|
|||
|
|
@ -865,7 +865,10 @@ class RDS2Backend(BaseBackend):
|
|||
def stop_database(self, db_instance_identifier, db_snapshot_identifier=None):
|
||||
database = self.describe_databases(db_instance_identifier)[0]
|
||||
# todo: certain rds types not allowed to be stopped at this time.
|
||||
if database.is_replica or database.multi_az:
|
||||
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations
|
||||
if database.is_replica or (
|
||||
database.multi_az and database.engine.lower().startswith("sqlserver")
|
||||
):
|
||||
# todo: more db types not supported by stop/start instance api
|
||||
raise InvalidDBClusterStateFaultError(db_instance_identifier)
|
||||
if database.status != "available":
|
||||
|
|
|
|||
|
|
@ -145,10 +145,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
|
|||
# Do S3, resource type s3
|
||||
if not resource_type_filters or "s3" in resource_type_filters:
|
||||
for bucket in self.s3_backend.buckets.values():
|
||||
tags = []
|
||||
for tag in bucket.tags.tag_set.tags:
|
||||
tags.append({"Key": tag.key, "Value": tag.value})
|
||||
|
||||
tags = self.s3_backend.tagger.list_tags_for_resource(bucket.arn)["Tags"]
|
||||
if not tags or not tag_filter(
|
||||
tags
|
||||
): # Skip if no tags, or invalid filter
|
||||
|
|
@ -362,8 +359,9 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
|
|||
|
||||
# Do S3, resource type s3
|
||||
for bucket in self.s3_backend.buckets.values():
|
||||
for tag in bucket.tags.tag_set.tags:
|
||||
yield tag.key
|
||||
tags = self.s3_backend.tagger.get_tag_dict_for_resource(bucket.arn)
|
||||
for key, _ in tags.items():
|
||||
yield key
|
||||
|
||||
# EC2 tags
|
||||
def get_ec2_keys(res_id):
|
||||
|
|
@ -414,9 +412,10 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
|
|||
|
||||
# Do S3, resource type s3
|
||||
for bucket in self.s3_backend.buckets.values():
|
||||
for tag in bucket.tags.tag_set.tags:
|
||||
if tag.key == tag_key:
|
||||
yield tag.value
|
||||
tags = self.s3_backend.tagger.get_tag_dict_for_resource(bucket.arn)
|
||||
for key, value in tags.items():
|
||||
if key == tag_key:
|
||||
yield value
|
||||
|
||||
# EC2 tags
|
||||
def get_ec2_values(res_id):
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ import six
|
|||
from bisect import insort
|
||||
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
|
||||
from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime
|
||||
from moto.cloudwatch.models import metric_providers, MetricDatum
|
||||
from moto.utilities.tagging_service import TaggingService
|
||||
from .exceptions import (
|
||||
BucketAlreadyExists,
|
||||
MissingBucket,
|
||||
|
|
@ -34,7 +36,6 @@ from .exceptions import (
|
|||
MalformedXML,
|
||||
InvalidStorageClass,
|
||||
InvalidTargetBucketForLogging,
|
||||
DuplicateTagKeys,
|
||||
CrossLocationLoggingProhibitted,
|
||||
NoSuchPublicAccessBlockConfiguration,
|
||||
InvalidPublicAccessBlockConfiguration,
|
||||
|
|
@ -94,6 +95,7 @@ class FakeKey(BaseModel):
|
|||
version_id=0,
|
||||
max_buffer_size=DEFAULT_KEY_BUFFER_SIZE,
|
||||
multipart=None,
|
||||
bucket_name=None,
|
||||
):
|
||||
self.name = name
|
||||
self.last_modified = datetime.datetime.utcnow()
|
||||
|
|
@ -105,8 +107,8 @@ class FakeKey(BaseModel):
|
|||
self._etag = etag
|
||||
self._version_id = version_id
|
||||
self._is_versioned = is_versioned
|
||||
self._tagging = FakeTagging()
|
||||
self.multipart = multipart
|
||||
self.bucket_name = bucket_name
|
||||
|
||||
self._value_buffer = tempfile.SpooledTemporaryFile(max_size=max_buffer_size)
|
||||
self._max_buffer_size = max_buffer_size
|
||||
|
|
@ -126,6 +128,13 @@ class FakeKey(BaseModel):
|
|||
self.lock.release()
|
||||
return r
|
||||
|
||||
@property
|
||||
def arn(self):
|
||||
# S3 Objects don't have an ARN, but we do need something unique when creating tags against this resource
|
||||
return "arn:aws:s3:::{}/{}/{}".format(
|
||||
self.bucket_name, self.name, self.version_id
|
||||
)
|
||||
|
||||
@value.setter
|
||||
def value(self, new_value):
|
||||
self._value_buffer.seek(0)
|
||||
|
|
@ -152,9 +161,6 @@ class FakeKey(BaseModel):
|
|||
self._metadata = {}
|
||||
self._metadata.update(metadata)
|
||||
|
||||
def set_tagging(self, tagging):
|
||||
self._tagging = tagging
|
||||
|
||||
def set_storage_class(self, storage):
|
||||
if storage is not None and storage not in STORAGE_CLASS:
|
||||
raise InvalidStorageClass(storage=storage)
|
||||
|
|
@ -210,10 +216,6 @@ class FakeKey(BaseModel):
|
|||
def metadata(self):
|
||||
return self._metadata
|
||||
|
||||
@property
|
||||
def tagging(self):
|
||||
return self._tagging
|
||||
|
||||
@property
|
||||
def response_dict(self):
|
||||
res = {
|
||||
|
|
@ -471,26 +473,10 @@ def get_canned_acl(acl):
|
|||
return FakeAcl(grants=grants)
|
||||
|
||||
|
||||
class FakeTagging(BaseModel):
|
||||
def __init__(self, tag_set=None):
|
||||
self.tag_set = tag_set or FakeTagSet()
|
||||
|
||||
|
||||
class FakeTagSet(BaseModel):
|
||||
def __init__(self, tags=None):
|
||||
self.tags = tags or []
|
||||
|
||||
|
||||
class FakeTag(BaseModel):
|
||||
def __init__(self, key, value=None):
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
|
||||
class LifecycleFilter(BaseModel):
|
||||
def __init__(self, prefix=None, tag=None, and_filter=None):
|
||||
self.prefix = prefix
|
||||
self.tag = tag
|
||||
(self.tag_key, self.tag_value) = tag if tag else (None, None)
|
||||
self.and_filter = and_filter
|
||||
|
||||
def to_config_dict(self):
|
||||
|
|
@ -499,11 +485,11 @@ class LifecycleFilter(BaseModel):
|
|||
"predicate": {"type": "LifecyclePrefixPredicate", "prefix": self.prefix}
|
||||
}
|
||||
|
||||
elif self.tag:
|
||||
elif self.tag_key:
|
||||
return {
|
||||
"predicate": {
|
||||
"type": "LifecycleTagPredicate",
|
||||
"tag": {"key": self.tag.key, "value": self.tag.value},
|
||||
"tag": {"key": self.tag_key, "value": self.tag_value},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -527,12 +513,9 @@ class LifecycleAndFilter(BaseModel):
|
|||
if self.prefix is not None:
|
||||
data.append({"type": "LifecyclePrefixPredicate", "prefix": self.prefix})
|
||||
|
||||
for tag in self.tags:
|
||||
for key, value in self.tags.items():
|
||||
data.append(
|
||||
{
|
||||
"type": "LifecycleTagPredicate",
|
||||
"tag": {"key": tag.key, "value": tag.value},
|
||||
}
|
||||
{"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value},}
|
||||
)
|
||||
|
||||
return data
|
||||
|
|
@ -787,7 +770,6 @@ class FakeBucket(BaseModel):
|
|||
self.policy = None
|
||||
self.website_configuration = None
|
||||
self.acl = get_canned_acl("private")
|
||||
self.tags = FakeTagging()
|
||||
self.cors = []
|
||||
self.logging = {}
|
||||
self.notification_configuration = None
|
||||
|
|
@ -879,7 +861,7 @@ class FakeBucket(BaseModel):
|
|||
and_filter = None
|
||||
if rule["Filter"].get("And"):
|
||||
filters += 1
|
||||
and_tags = []
|
||||
and_tags = {}
|
||||
if rule["Filter"]["And"].get("Tag"):
|
||||
if not isinstance(rule["Filter"]["And"]["Tag"], list):
|
||||
rule["Filter"]["And"]["Tag"] = [
|
||||
|
|
@ -887,7 +869,7 @@ class FakeBucket(BaseModel):
|
|||
]
|
||||
|
||||
for t in rule["Filter"]["And"]["Tag"]:
|
||||
and_tags.append(FakeTag(t["Key"], t.get("Value", "")))
|
||||
and_tags[t["Key"]] = t.get("Value", "")
|
||||
|
||||
try:
|
||||
and_prefix = (
|
||||
|
|
@ -901,7 +883,7 @@ class FakeBucket(BaseModel):
|
|||
filter_tag = None
|
||||
if rule["Filter"].get("Tag"):
|
||||
filters += 1
|
||||
filter_tag = FakeTag(
|
||||
filter_tag = (
|
||||
rule["Filter"]["Tag"]["Key"],
|
||||
rule["Filter"]["Tag"].get("Value", ""),
|
||||
)
|
||||
|
|
@ -988,16 +970,6 @@ class FakeBucket(BaseModel):
|
|||
def delete_cors(self):
|
||||
self.cors = []
|
||||
|
||||
def set_tags(self, tagging):
|
||||
self.tags = tagging
|
||||
|
||||
def delete_tags(self):
|
||||
self.tags = FakeTagging()
|
||||
|
||||
@property
|
||||
def tagging(self):
|
||||
return self.tags
|
||||
|
||||
def set_logging(self, logging_config, bucket_backend):
|
||||
if not logging_config:
|
||||
self.logging = {}
|
||||
|
|
@ -1085,6 +1057,10 @@ class FakeBucket(BaseModel):
|
|||
def set_acl(self, acl):
|
||||
self.acl = acl
|
||||
|
||||
@property
|
||||
def arn(self):
|
||||
return "arn:aws:s3:::{}".format(self.name)
|
||||
|
||||
@property
|
||||
def physical_resource_id(self):
|
||||
return self.name
|
||||
|
|
@ -1110,7 +1086,7 @@ class FakeBucket(BaseModel):
|
|||
int(time.mktime(self.creation_date.timetuple()))
|
||||
), # PY2 and 3 compatible
|
||||
"configurationItemMD5Hash": "",
|
||||
"arn": "arn:aws:s3:::{}".format(self.name),
|
||||
"arn": self.arn,
|
||||
"resourceType": "AWS::S3::Bucket",
|
||||
"resourceId": self.name,
|
||||
"resourceName": self.name,
|
||||
|
|
@ -1119,7 +1095,7 @@ class FakeBucket(BaseModel):
|
|||
"resourceCreationTime": str(self.creation_date),
|
||||
"relatedEvents": [],
|
||||
"relationships": [],
|
||||
"tags": {tag.key: tag.value for tag in self.tagging.tag_set.tags},
|
||||
"tags": s3_backend.tagger.get_tag_dict_for_resource(self.arn),
|
||||
"configuration": {
|
||||
"name": self.name,
|
||||
"owner": {"id": OWNER},
|
||||
|
|
@ -1181,6 +1157,40 @@ class S3Backend(BaseBackend):
|
|||
def __init__(self):
|
||||
self.buckets = {}
|
||||
self.account_public_access_block = None
|
||||
self.tagger = TaggingService()
|
||||
|
||||
# Register this class as a CloudWatch Metric Provider
|
||||
# Must provide a method 'get_cloudwatch_metrics' that will return a list of metrics, based on the data available
|
||||
metric_providers["S3"] = self
|
||||
|
||||
def get_cloudwatch_metrics(self):
|
||||
metrics = []
|
||||
for name, bucket in self.buckets.items():
|
||||
metrics.append(
|
||||
MetricDatum(
|
||||
namespace="AWS/S3",
|
||||
name="BucketSizeBytes",
|
||||
value=bucket.keys.item_size(),
|
||||
dimensions=[
|
||||
{"Name": "StorageType", "Value": "StandardStorage"},
|
||||
{"Name": "BucketName", "Value": name},
|
||||
],
|
||||
timestamp=datetime.datetime.now(),
|
||||
)
|
||||
)
|
||||
metrics.append(
|
||||
MetricDatum(
|
||||
namespace="AWS/S3",
|
||||
name="NumberOfObjects",
|
||||
value=len(bucket.keys),
|
||||
dimensions=[
|
||||
{"Name": "StorageType", "Value": "AllStorageTypes"},
|
||||
{"Name": "BucketName", "Value": name},
|
||||
],
|
||||
timestamp=datetime.datetime.now(),
|
||||
)
|
||||
)
|
||||
return metrics
|
||||
|
||||
def create_bucket(self, bucket_name, region_name):
|
||||
if bucket_name in self.buckets:
|
||||
|
|
@ -1350,23 +1360,32 @@ class S3Backend(BaseBackend):
|
|||
else:
|
||||
return None
|
||||
|
||||
def set_key_tagging(self, bucket_name, key_name, tagging, version_id=None):
|
||||
key = self.get_key(bucket_name, key_name, version_id)
|
||||
def get_key_tags(self, key):
|
||||
return self.tagger.list_tags_for_resource(key.arn)
|
||||
|
||||
def set_key_tags(self, key, tags, key_name=None):
|
||||
if key is None:
|
||||
raise MissingKey(key_name)
|
||||
key.set_tagging(tagging)
|
||||
self.tagger.delete_all_tags_for_resource(key.arn)
|
||||
self.tagger.tag_resource(
|
||||
key.arn, [{"Key": key, "Value": value} for key, value in tags.items()],
|
||||
)
|
||||
return key
|
||||
|
||||
def put_bucket_tagging(self, bucket_name, tagging):
|
||||
tag_keys = [tag.key for tag in tagging.tag_set.tags]
|
||||
if len(tag_keys) != len(set(tag_keys)):
|
||||
raise DuplicateTagKeys()
|
||||
def get_bucket_tags(self, bucket_name):
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
bucket.set_tags(tagging)
|
||||
return self.tagger.list_tags_for_resource(bucket.arn)
|
||||
|
||||
def put_bucket_tags(self, bucket_name, tags):
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
self.tagger.delete_all_tags_for_resource(bucket.arn)
|
||||
self.tagger.tag_resource(
|
||||
bucket.arn, [{"Key": key, "Value": value} for key, value in tags.items()],
|
||||
)
|
||||
|
||||
def delete_bucket_tagging(self, bucket_name):
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
bucket.delete_tags()
|
||||
self.tagger.delete_all_tags_for_resource(bucket.arn)
|
||||
|
||||
def put_bucket_cors(self, bucket_name, cors_rules):
|
||||
bucket = self.get_bucket(bucket_name)
|
||||
|
|
@ -1574,6 +1593,7 @@ class S3Backend(BaseBackend):
|
|||
key = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id)
|
||||
|
||||
new_key = key.copy(dest_key_name, dest_bucket.is_versioned)
|
||||
self.tagger.copy_tags(key.arn, new_key.arn)
|
||||
|
||||
if storage is not None:
|
||||
new_key.set_storage_class(storage)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import six
|
|||
from botocore.awsrequest import AWSPreparedRequest
|
||||
|
||||
from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys
|
||||
from six.moves.urllib.parse import parse_qs, urlparse, unquote
|
||||
from six.moves.urllib.parse import parse_qs, urlparse, unquote, parse_qsl
|
||||
|
||||
import xmltodict
|
||||
|
||||
|
|
@ -24,6 +24,7 @@ from moto.s3bucket_path.utils import (
|
|||
|
||||
from .exceptions import (
|
||||
BucketAlreadyExists,
|
||||
DuplicateTagKeys,
|
||||
S3ClientError,
|
||||
MissingBucket,
|
||||
MissingKey,
|
||||
|
|
@ -43,9 +44,6 @@ from .models import (
|
|||
FakeGrant,
|
||||
FakeAcl,
|
||||
FakeKey,
|
||||
FakeTagging,
|
||||
FakeTagSet,
|
||||
FakeTag,
|
||||
)
|
||||
from .utils import (
|
||||
bucket_name_from_url,
|
||||
|
|
@ -134,7 +132,8 @@ ACTION_MAP = {
|
|||
|
||||
|
||||
def parse_key_name(pth):
|
||||
return pth.lstrip("/")
|
||||
# strip the first '/' left by urlparse
|
||||
return pth[1:] if pth.startswith("/") else pth
|
||||
|
||||
|
||||
def is_delete_keys(request, path, bucket_name):
|
||||
|
|
@ -378,13 +377,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
|
||||
return template.render(obj=bucket)
|
||||
elif "tagging" in querystring:
|
||||
bucket = self.backend.get_bucket(bucket_name)
|
||||
tags = self.backend.get_bucket_tags(bucket_name)["Tags"]
|
||||
# "Special Error" if no tags:
|
||||
if len(bucket.tagging.tag_set.tags) == 0:
|
||||
if len(tags) == 0:
|
||||
template = self.response_template(S3_NO_BUCKET_TAGGING)
|
||||
return 404, {}, template.render(bucket_name=bucket_name)
|
||||
template = self.response_template(S3_BUCKET_TAGGING_RESPONSE)
|
||||
return template.render(bucket=bucket)
|
||||
template = self.response_template(S3_OBJECT_TAGGING_RESPONSE)
|
||||
return template.render(tags=tags)
|
||||
elif "logging" in querystring:
|
||||
bucket = self.backend.get_bucket(bucket_name)
|
||||
if not bucket.logging:
|
||||
|
|
@ -652,7 +651,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
return ""
|
||||
elif "tagging" in querystring:
|
||||
tagging = self._bucket_tagging_from_xml(body)
|
||||
self.backend.put_bucket_tagging(bucket_name, tagging)
|
||||
self.backend.put_bucket_tags(bucket_name, tagging)
|
||||
return ""
|
||||
elif "website" in querystring:
|
||||
self.backend.set_bucket_website_configuration(bucket_name, body)
|
||||
|
|
@ -777,6 +776,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
return 409, {}, template.render(bucket=removed_bucket)
|
||||
|
||||
def _bucket_response_post(self, request, body, bucket_name):
|
||||
response_headers = {}
|
||||
if not request.headers.get("Content-Length"):
|
||||
return 411, {}, "Content-Length required"
|
||||
|
||||
|
|
@ -798,11 +798,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
else:
|
||||
# HTTPretty, build new form object
|
||||
body = body.decode()
|
||||
|
||||
form = {}
|
||||
for kv in body.split("&"):
|
||||
k, v = kv.split("=")
|
||||
form[k] = v
|
||||
form = dict(parse_qsl(body))
|
||||
|
||||
key = form["key"]
|
||||
if "file" in form:
|
||||
|
|
@ -810,13 +806,23 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
else:
|
||||
f = request.files["file"].stream.read()
|
||||
|
||||
if "success_action_redirect" in form:
|
||||
response_headers["Location"] = form["success_action_redirect"]
|
||||
|
||||
if "success_action_status" in form:
|
||||
status_code = form["success_action_status"]
|
||||
elif "success_action_redirect" in form:
|
||||
status_code = 303
|
||||
else:
|
||||
status_code = 204
|
||||
|
||||
new_key = self.backend.set_key(bucket_name, key, f)
|
||||
|
||||
# Metadata
|
||||
metadata = metadata_from_headers(form)
|
||||
new_key.set_metadata(metadata)
|
||||
|
||||
return 200, {}, ""
|
||||
return status_code, response_headers, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_path(request):
|
||||
|
|
@ -1091,8 +1097,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
template = self.response_template(S3_OBJECT_ACL_RESPONSE)
|
||||
return 200, response_headers, template.render(obj=key)
|
||||
if "tagging" in query:
|
||||
tags = self.backend.get_key_tags(key)["Tags"]
|
||||
template = self.response_template(S3_OBJECT_TAGGING_RESPONSE)
|
||||
return 200, response_headers, template.render(obj=key)
|
||||
return 200, response_headers, template.render(tags=tags)
|
||||
|
||||
response_headers.update(key.metadata)
|
||||
response_headers.update(key.response_dict)
|
||||
|
|
@ -1164,8 +1171,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
version_id = query["versionId"][0]
|
||||
else:
|
||||
version_id = None
|
||||
key = self.backend.get_key(bucket_name, key_name, version_id=version_id)
|
||||
tagging = self._tagging_from_xml(body)
|
||||
self.backend.set_key_tagging(bucket_name, key_name, tagging, version_id)
|
||||
self.backend.set_key_tags(key, tagging, key_name)
|
||||
return 200, response_headers, ""
|
||||
|
||||
if "x-amz-copy-source" in request.headers:
|
||||
|
|
@ -1206,7 +1214,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
tdirective = request.headers.get("x-amz-tagging-directive")
|
||||
if tdirective == "REPLACE":
|
||||
tagging = self._tagging_from_headers(request.headers)
|
||||
new_key.set_tagging(tagging)
|
||||
self.backend.set_key_tags(new_key, tagging)
|
||||
template = self.response_template(S3_OBJECT_COPY_RESPONSE)
|
||||
response_headers.update(new_key.response_dict)
|
||||
return 200, response_headers, template.render(key=new_key)
|
||||
|
|
@ -1230,11 +1238,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
new_key.website_redirect_location = request.headers.get(
|
||||
"x-amz-website-redirect-location"
|
||||
)
|
||||
new_key.set_tagging(tagging)
|
||||
self.backend.set_key_tags(new_key, tagging)
|
||||
|
||||
template = self.response_template(S3_OBJECT_RESPONSE)
|
||||
response_headers.update(new_key.response_dict)
|
||||
return 200, response_headers, template.render(key=new_key)
|
||||
return 200, response_headers, ""
|
||||
|
||||
def _key_response_head(self, bucket_name, query, key_name, headers):
|
||||
response_headers = {}
|
||||
|
|
@ -1359,55 +1366,45 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
return None
|
||||
|
||||
def _tagging_from_headers(self, headers):
|
||||
tags = {}
|
||||
if headers.get("x-amz-tagging"):
|
||||
parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True)
|
||||
tags = []
|
||||
for tag in parsed_header.items():
|
||||
tags.append(FakeTag(tag[0], tag[1][0]))
|
||||
|
||||
tag_set = FakeTagSet(tags)
|
||||
tagging = FakeTagging(tag_set)
|
||||
return tagging
|
||||
else:
|
||||
return FakeTagging()
|
||||
tags[tag[0]] = tag[1][0]
|
||||
return tags
|
||||
|
||||
def _tagging_from_xml(self, xml):
|
||||
parsed_xml = xmltodict.parse(xml, force_list={"Tag": True})
|
||||
|
||||
tags = []
|
||||
tags = {}
|
||||
for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]:
|
||||
tags.append(FakeTag(tag["Key"], tag["Value"]))
|
||||
tags[tag["Key"]] = tag["Value"]
|
||||
|
||||
tag_set = FakeTagSet(tags)
|
||||
tagging = FakeTagging(tag_set)
|
||||
return tagging
|
||||
return tags
|
||||
|
||||
def _bucket_tagging_from_xml(self, xml):
|
||||
parsed_xml = xmltodict.parse(xml)
|
||||
|
||||
tags = []
|
||||
tags = {}
|
||||
# Optional if no tags are being sent:
|
||||
if parsed_xml["Tagging"].get("TagSet"):
|
||||
# If there is only 1 tag, then it's not a list:
|
||||
if not isinstance(parsed_xml["Tagging"]["TagSet"]["Tag"], list):
|
||||
tags.append(
|
||||
FakeTag(
|
||||
parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"],
|
||||
parsed_xml["Tagging"]["TagSet"]["Tag"]["Value"],
|
||||
)
|
||||
)
|
||||
tags[parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"]] = parsed_xml[
|
||||
"Tagging"
|
||||
]["TagSet"]["Tag"]["Value"]
|
||||
else:
|
||||
for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]:
|
||||
tags.append(FakeTag(tag["Key"], tag["Value"]))
|
||||
if tag["Key"] in tags:
|
||||
raise DuplicateTagKeys()
|
||||
tags[tag["Key"]] = tag["Value"]
|
||||
|
||||
# Verify that "aws:" is not in the tags. If so, then this is a problem:
|
||||
for tag in tags:
|
||||
if tag.key.startswith("aws:"):
|
||||
for key, _ in tags.items():
|
||||
if key.startswith("aws:"):
|
||||
raise NoSystemTags()
|
||||
|
||||
tag_set = FakeTagSet(tags)
|
||||
tagging = FakeTagging(tag_set)
|
||||
return tagging
|
||||
return tags
|
||||
|
||||
def _cors_from_xml(self, xml):
|
||||
parsed_xml = xmltodict.parse(xml)
|
||||
|
|
@ -1552,8 +1549,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
|
|||
return 204, {}, ""
|
||||
version_id = query.get("versionId", [None])[0]
|
||||
self.backend.delete_key(bucket_name, key_name, version_id=version_id)
|
||||
template = self.response_template(S3_DELETE_OBJECT_SUCCESS)
|
||||
return 204, {}, template.render()
|
||||
return 204, {}, ""
|
||||
|
||||
def _complete_multipart_body(self, body):
|
||||
ps = minidom.parseString(body).getElementsByTagName("Part")
|
||||
|
|
@ -1728,10 +1724,10 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """<?xml version="1.0" encoding="UTF-8"?>
|
|||
{% if rule.filter.prefix != None %}
|
||||
<Prefix>{{ rule.filter.prefix }}</Prefix>
|
||||
{% endif %}
|
||||
{% if rule.filter.tag %}
|
||||
{% if rule.filter.tag_key %}
|
||||
<Tag>
|
||||
<Key>{{ rule.filter.tag.key }}</Key>
|
||||
<Value>{{ rule.filter.tag.value }}</Value>
|
||||
<Key>{{ rule.filter.tag_key }}</Key>
|
||||
<Value>{{ rule.filter.tag_value }}</Value>
|
||||
</Tag>
|
||||
{% endif %}
|
||||
{% if rule.filter.and_filter %}
|
||||
|
|
@ -1739,10 +1735,10 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """<?xml version="1.0" encoding="UTF-8"?>
|
|||
{% if rule.filter.and_filter.prefix != None %}
|
||||
<Prefix>{{ rule.filter.and_filter.prefix }}</Prefix>
|
||||
{% endif %}
|
||||
{% for tag in rule.filter.and_filter.tags %}
|
||||
{% for key, value in rule.filter.and_filter.tags.items() %}
|
||||
<Tag>
|
||||
<Key>{{ tag.key }}</Key>
|
||||
<Value>{{ tag.value }}</Value>
|
||||
<Key>{{ key }}</Key>
|
||||
<Value>{{ value }}</Value>
|
||||
</Tag>
|
||||
{% endfor %}
|
||||
</And>
|
||||
|
|
@ -1868,20 +1864,6 @@ S3_DELETE_KEYS_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
|
|||
{% endfor %}
|
||||
</DeleteResult>"""
|
||||
|
||||
S3_DELETE_OBJECT_SUCCESS = """<DeleteObjectResponse xmlns="http://s3.amazonaws.com/doc/2006-03-01">
|
||||
<DeleteObjectResponse>
|
||||
<Code>200</Code>
|
||||
<Description>OK</Description>
|
||||
</DeleteObjectResponse>
|
||||
</DeleteObjectResponse>"""
|
||||
|
||||
S3_OBJECT_RESPONSE = """<PutObjectResponse xmlns="http://s3.amazonaws.com/doc/2006-03-01">
|
||||
<PutObjectResponse>
|
||||
<ETag>{{ key.etag }}</ETag>
|
||||
<LastModified>{{ key.last_modified_ISO8601 }}</LastModified>
|
||||
</PutObjectResponse>
|
||||
</PutObjectResponse>"""
|
||||
|
||||
S3_OBJECT_ACL_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<AccessControlPolicy xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
|
||||
<Owner>
|
||||
|
|
@ -1917,22 +1899,10 @@ S3_OBJECT_TAGGING_RESPONSE = """\
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Tagging xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
|
||||
<TagSet>
|
||||
{% for tag in obj.tagging.tag_set.tags %}
|
||||
{% for tag in tags %}
|
||||
<Tag>
|
||||
<Key>{{ tag.key }}</Key>
|
||||
<Value>{{ tag.value }}</Value>
|
||||
</Tag>
|
||||
{% endfor %}
|
||||
</TagSet>
|
||||
</Tagging>"""
|
||||
|
||||
S3_BUCKET_TAGGING_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Tagging>
|
||||
<TagSet>
|
||||
{% for tag in bucket.tagging.tag_set.tags %}
|
||||
<Tag>
|
||||
<Key>{{ tag.key }}</Key>
|
||||
<Value>{{ tag.value }}</Value>
|
||||
<Key>{{ tag.Key }}</Key>
|
||||
<Value>{{ tag.Value }}</Value>
|
||||
</Tag>
|
||||
{% endfor %}
|
||||
</TagSet>
|
||||
|
|
|
|||
|
|
@ -15,5 +15,5 @@ url_paths = {
|
|||
# path-based bucket + key
|
||||
"{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)": S3ResponseInstance.key_or_control_response,
|
||||
# subdomain bucket + key with empty first part of path
|
||||
"{0}//(?P<key_name>.*)$": S3ResponseInstance.key_or_control_response,
|
||||
"{0}/(?P<key_name>/.*)$": S3ResponseInstance.key_or_control_response,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -146,6 +146,12 @@ class _VersionedKeyStore(dict):
|
|||
for key in self:
|
||||
yield key, self.getlist(key)
|
||||
|
||||
def item_size(self):
|
||||
size = 0
|
||||
for val in self.values():
|
||||
size += sys.getsizeof(val)
|
||||
return size
|
||||
|
||||
items = iteritems = _iteritems
|
||||
lists = iterlists = _iterlists
|
||||
values = itervalues = _itervalues
|
||||
|
|
|
|||
|
|
@ -107,6 +107,34 @@ class SecretsManagerBackend(BaseBackend):
|
|||
|
||||
return response
|
||||
|
||||
def update_secret(
|
||||
self, secret_id, secret_string=None, secret_binary=None, **kwargs
|
||||
):
|
||||
|
||||
# error if secret does not exist
|
||||
if secret_id not in self.secrets.keys():
|
||||
raise SecretNotFoundException()
|
||||
|
||||
if "deleted_date" in self.secrets[secret_id]:
|
||||
raise InvalidRequestException(
|
||||
"An error occurred (InvalidRequestException) when calling the UpdateSecret operation: "
|
||||
"You can't perform this operation on the secret because it was marked for deletion."
|
||||
)
|
||||
|
||||
version_id = self._add_secret(
|
||||
secret_id, secret_string=secret_string, secret_binary=secret_binary
|
||||
)
|
||||
|
||||
response = json.dumps(
|
||||
{
|
||||
"ARN": secret_arn(self.region, secret_id),
|
||||
"Name": secret_id,
|
||||
"VersionId": version_id,
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def create_secret(
|
||||
self, name, secret_string=None, secret_binary=None, tags=[], **kwargs
|
||||
):
|
||||
|
|
|
|||
|
|
@ -29,6 +29,16 @@ class SecretsManagerResponse(BaseResponse):
|
|||
tags=tags,
|
||||
)
|
||||
|
||||
def update_secret(self):
|
||||
secret_id = self._get_param("SecretId")
|
||||
secret_string = self._get_param("SecretString")
|
||||
secret_binary = self._get_param("SecretBinary")
|
||||
return secretsmanager_backends[self.region].update_secret(
|
||||
secret_id=secret_id,
|
||||
secret_string=secret_string,
|
||||
secret_binary=secret_binary,
|
||||
)
|
||||
|
||||
def get_random_password(self):
|
||||
password_length = self._get_param("PasswordLength", if_none=32)
|
||||
exclude_characters = self._get_param("ExcludeCharacters", if_none="")
|
||||
|
|
|
|||
|
|
@ -651,7 +651,7 @@ class SimpleSystemManagerBackend(BaseBackend):
|
|||
label.startswith("aws")
|
||||
or label.startswith("ssm")
|
||||
or label[:1].isdigit()
|
||||
or not re.match("^[a-zA-z0-9_\.\-]*$", label)
|
||||
or not re.match(r"^[a-zA-z0-9_\.\-]*$", label)
|
||||
):
|
||||
invalid_labels.append(label)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -5,15 +5,23 @@ class TaggingService:
|
|||
self.valueName = valueName
|
||||
self.tags = {}
|
||||
|
||||
def get_tag_dict_for_resource(self, arn):
|
||||
result = {}
|
||||
if self.has_tags(arn):
|
||||
for k, v in self.tags[arn].items():
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
def list_tags_for_resource(self, arn):
|
||||
result = []
|
||||
if arn in self.tags:
|
||||
if self.has_tags(arn):
|
||||
for k, v in self.tags[arn].items():
|
||||
result.append({self.keyName: k, self.valueName: v})
|
||||
return {self.tagName: result}
|
||||
|
||||
def delete_all_tags_for_resource(self, arn):
|
||||
del self.tags[arn]
|
||||
if self.has_tags(arn):
|
||||
del self.tags[arn]
|
||||
|
||||
def has_tags(self, arn):
|
||||
return arn in self.tags
|
||||
|
|
@ -27,6 +35,12 @@ class TaggingService:
|
|||
else:
|
||||
self.tags[arn][t[self.keyName]] = None
|
||||
|
||||
def copy_tags(self, from_arn, to_arn):
|
||||
if self.has_tags(from_arn):
|
||||
self.tag_resource(
|
||||
to_arn, self.list_tags_for_resource(from_arn)[self.tagName]
|
||||
)
|
||||
|
||||
def untag_resource_using_names(self, arn, tag_names):
|
||||
for name in tag_names:
|
||||
if name in self.tags.get(arn, {}):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue