Merge remote master

This commit is contained in:
Ilya Shmygol 2019-12-11 16:17:21 +01:00
commit 3a42079ec7
45 changed files with 1642 additions and 192 deletions

View file

@ -0,0 +1,31 @@
from botocore.client import ClientError
class LambdaClientError(ClientError):
def __init__(self, error, message):
error_response = {"Error": {"Code": error, "Message": message}}
super(LambdaClientError, self).__init__(error_response, None)
class CrossAccountNotAllowed(LambdaClientError):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
)
class InvalidParameterValueException(LambdaClientError):
def __init__(self, message):
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidRoleFormat(LambdaClientError):
pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+"
def __init__(self, role):
message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format(
role, InvalidRoleFormat.pattern
)
super(InvalidRoleFormat, self).__init__("ValidationException", message)

View file

@ -26,11 +26,18 @@ import requests.adapters
import boto.awslambda
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend
from moto.iam.exceptions import IAMNotFoundException
from moto.core.utils import unix_time_millis
from moto.s3.models import s3_backend
from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
from .exceptions import (
CrossAccountNotAllowed,
InvalidRoleFormat,
InvalidParameterValueException,
)
from .utils import make_function_arn, make_function_ver_arn
from moto.sqs import sqs_backends
from moto.dynamodb2 import dynamodb_backends2
@ -214,9 +221,8 @@ class LambdaFunction(BaseModel):
key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"])
except MissingBucket:
if do_validate_s3():
raise ValueError(
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist",
raise InvalidParameterValueException(
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist"
)
except MissingKey:
if do_validate_s3():
@ -357,6 +363,8 @@ class LambdaFunction(BaseModel):
self.code_bytes = key.value
self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.code["S3Bucket"] = updated_spec["S3Bucket"]
self.code["S3Key"] = updated_spec["S3Key"]
return self.get_configuration()
@ -520,6 +528,15 @@ class LambdaFunction(BaseModel):
return make_function_arn(self.region, ACCOUNT_ID, self.function_name)
raise UnformattedGetAttTemplateException()
@classmethod
def update_from_cloudformation_json(
cls, new_resource_name, cloudformation_json, original_resource, region_name
):
updated_props = cloudformation_json["Properties"]
original_resource.update_configuration(updated_props)
original_resource.update_function_code(updated_props["Code"])
return original_resource
@staticmethod
def _create_zipfile_from_plaintext_code(code):
zip_output = io.BytesIO()
@ -529,6 +546,9 @@ class LambdaFunction(BaseModel):
zip_output.seek(0)
return zip_output.read()
def delete(self, region):
lambda_backends[region].delete_function(self.function_name)
class EventSourceMapping(BaseModel):
def __init__(self, spec):
@ -668,6 +688,19 @@ class LambdaStorage(object):
:param fn: Function
:type fn: LambdaFunction
"""
valid_role = re.match(InvalidRoleFormat.pattern, fn.role)
if valid_role:
account = valid_role.group(2)
if account != ACCOUNT_ID:
raise CrossAccountNotAllowed()
try:
iam_backend.get_role_by_arn(fn.role)
except IAMNotFoundException:
raise InvalidParameterValueException(
"The role defined for the function cannot be assumed by Lambda."
)
else:
raise InvalidRoleFormat(fn.role)
if fn.function_name in self._functions:
self._functions[fn.function_name]["latest"] = fn
else:

View file

@ -211,30 +211,14 @@ class LambdaResponse(BaseResponse):
return 200, {}, json.dumps(result)
def _create_function(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_function(self.json_body)
except ValueError as e:
return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
fn = self.lambda_backend.create_function(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _create_event_source_mapping(self, request, full_url, headers):
try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e:
return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(

View file

@ -5,6 +5,7 @@ from moto.core.exceptions import RESTError
import boto.ec2.cloudwatch
from datetime import datetime, timedelta
from dateutil.tz import tzutc
from uuid import uuid4
from .utils import make_arn_for_dashboard
DEFAULT_ACCOUNT_ID = 123456789012
@ -193,6 +194,7 @@ class CloudWatchBackend(BaseBackend):
self.alarms = {}
self.dashboards = {}
self.metric_data = []
self.paged_metric_data = {}
def put_metric_alarm(
self,
@ -377,6 +379,36 @@ class CloudWatchBackend(BaseBackend):
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
"PaginationException", "Request parameter NextToken is invalid"
)
else:
metrics = self.paged_metric_data[next_token]
del self.paged_metric_data[next_token] # Cant reuse same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
return metrics
def _get_paginated(self, metrics):
if len(metrics) > 500:
next_token = str(uuid4())
self.paged_metric_data[next_token] = metrics[500:]
return next_token, metrics[0:500]
else:
return None, metrics
class LogGroup(BaseModel):
def __init__(self, spec):

View file

@ -120,9 +120,14 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def list_metrics(self):
metrics = self.cloudwatch_backend.get_all_metrics()
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
)
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics)
return template.render(metrics=metrics, next_token=next_token)
@amzn_request_id
def delete_dashboards(self):
@ -340,9 +345,11 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</member>
{% endfor %}
</Metrics>
{% if next_token is not none %}
<NextToken>
96e88479-4662-450b-8a13-239ded6ce9fe
{{ next_token }}
</NextToken>
{% endif %}
</ListMetricsResult>
</ListMetricsResponse>"""

View file

@ -307,7 +307,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "")
name = elem.replace("{", "").replace("}", "").replace("+", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name

View file

@ -8,6 +8,7 @@ import random
import re
import six
import string
from botocore.exceptions import ClientError
from six.moves.urllib.parse import urlparse
@ -141,7 +142,10 @@ class convert_flask_to_httpretty_response(object):
def __call__(self, args=None, **kwargs):
from flask import request, Response
result = self.callback(request, request.url, {})
try:
result = self.callback(request, request.url, {})
except ClientError as exc:
result = 400, {}, exc.response["Error"]["Message"]
# result is a status, headers, response tuple
if len(result) == 3:
status, headers, content = result

View file

@ -153,7 +153,7 @@ class DataSyncResponse(BaseResponse):
task_execution_arn = self._get_param("TaskExecutionArn")
task_execution = self.datasync_backend._get_task_execution(task_execution_arn)
result = json.dumps(
{"TaskExecutionArn": task_execution.arn, "Status": task_execution.status,}
{"TaskExecutionArn": task_execution.arn, "Status": task_execution.status}
)
if task_execution.status == "SUCCESS":
self.datasync_backend.tasks[task_execution.task_arn].status = "AVAILABLE"

View file

@ -4,6 +4,4 @@ from .responses import DataSyncResponse
url_bases = ["https?://(.*?)(datasync)(.*?).amazonaws.com"]
url_paths = {
"{0}/$": DataSyncResponse.dispatch,
}
url_paths = {"{0}/$": DataSyncResponse.dispatch}

View file

@ -63,6 +63,16 @@ class DynamoType(object):
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def get(self, key):
if not key:
return self
else:
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.value:
self.value[key_head] = DynamoType({"NONE": None})
return self.value[key_head].get(key_tail)
def set(self, key, new_value, index=None):
if index:
index = int(index)
@ -174,8 +184,13 @@ class DynamoType(object):
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map() and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, six.string_types) and self.is_map():
if "." in key and key.split(".")[0] in self.value:
return self.value[key.split(".")[0]].child_attr(
".".join(key.split(".")[1:])
)
elif "." not in key and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
@ -383,11 +398,19 @@ class Item(BaseModel):
# created with only this value if it doesn't exist yet
# New value must be of same set type as previous value
elif dyn_value.is_set():
existing = self.attrs.get(key, DynamoType({dyn_value.type: {}}))
if not existing.same_type(dyn_value):
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.attrs:
self.attrs[key_head] = DynamoType({dyn_value.type: {}})
existing = self.attrs.get(key_head)
existing = existing.get(key_tail)
if existing.value and not existing.same_type(dyn_value):
raise TypeError()
new_set = set(existing.value).union(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
new_set = set(existing.value or []).union(dyn_value.value)
existing.set(
key=None,
new_value=DynamoType({dyn_value.type: list(new_set)}),
)
else: # Number and Sets are the only supported types for ADD
raise TypeError
@ -402,12 +425,18 @@ class Item(BaseModel):
if not dyn_value.is_set():
raise TypeError
existing = self.attrs.get(key, None)
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
existing = self.attrs.get(key_head)
existing = existing.get(key_tail)
if existing:
if not existing.same_type(dyn_value):
raise TypeError
new_set = set(existing.value).difference(dyn_value.value)
self.attrs[key] = DynamoType({existing.type: list(new_set)})
existing.set(
key=None,
new_value=DynamoType({existing.type: list(new_set)}),
)
else:
raise NotImplementedError(
"{} update action not yet supported".format(action)
@ -418,7 +447,14 @@ class Item(BaseModel):
list_append_re = re.match("list_append\\((.+),(.+)\\)", value)
if list_append_re:
new_value = expression_attribute_values[list_append_re.group(2).strip()]
old_list = self.attrs[list_append_re.group(1)]
old_list_key = list_append_re.group(1)
# Get the existing value
old_list = self.attrs[old_list_key.split(".")[0]]
if "." in old_list_key:
# Value is nested inside a map - find the appropriate child attr
old_list = old_list.child_attr(
".".join(old_list_key.split(".")[1:])
)
if not old_list.is_list():
raise ParamValidationError
old_list.value.extend(new_value["L"])

View file

@ -1644,23 +1644,27 @@ class RegionsAndZonesBackend(object):
class SecurityRule(object):
def __init__(self, ip_protocol, from_port, to_port, ip_ranges, source_groups):
self.ip_protocol = ip_protocol
self.from_port = from_port
self.to_port = to_port
self.ip_ranges = ip_ranges or []
self.source_groups = source_groups
@property
def unique_representation(self):
return "{0}-{1}-{2}-{3}-{4}".format(
self.ip_protocol,
self.from_port,
self.to_port,
self.ip_ranges,
self.source_groups,
)
if ip_protocol != "-1":
self.from_port = from_port
self.to_port = to_port
def __eq__(self, other):
return self.unique_representation == other.unique_representation
if self.ip_protocol != other.ip_protocol:
return False
if self.ip_ranges != other.ip_ranges:
return False
if self.source_groups != other.source_groups:
return False
if self.ip_protocol != "-1":
if self.from_port != other.from_port:
return False
if self.to_port != other.to_port:
return False
return True
class SecurityGroup(TaggedEC2Resource):
@ -1670,7 +1674,7 @@ class SecurityGroup(TaggedEC2Resource):
self.name = name
self.description = description
self.ingress_rules = []
self.egress_rules = [SecurityRule(-1, None, None, ["0.0.0.0/0"], [])]
self.egress_rules = [SecurityRule("-1", None, None, ["0.0.0.0/0"], [])]
self.enis = {}
self.vpc_id = vpc_id
self.owner_id = OWNER_ID

View file

@ -567,16 +567,14 @@ class EC2ContainerServiceBackend(BaseBackend):
return task_definition
def list_task_definitions(self):
"""
Filtering not implemented
"""
def list_task_definitions(self, family_prefix):
task_arns = []
for task_definition_list in self.task_definitions.values():
task_arns.extend(
[
task_definition.arn
for task_definition in task_definition_list.values()
if family_prefix is None or task_definition.family == family_prefix
]
)
return task_arns

View file

@ -68,7 +68,8 @@ class EC2ContainerServiceResponse(BaseResponse):
return json.dumps({"taskDefinition": task_definition.response_object})
def list_task_definitions(self):
task_definition_arns = self.ecs_backend.list_task_definitions()
family_prefix = self._get_param("familyPrefix")
task_definition_arns = self.ecs_backend.list_task_definitions(family_prefix)
return json.dumps(
{
"taskDefinitionArns": task_definition_arns

View file

@ -316,8 +316,7 @@ class EventsBackend(BaseBackend):
if not event_bus:
raise JsonRESTError(
"ResourceNotFoundException",
"Event bus {} does not exist.".format(name),
"ResourceNotFoundException", "Event bus {} does not exist.".format(name)
)
return event_bus

View file

@ -261,10 +261,7 @@ class EventsHandler(BaseResponse):
name = self._get_param("Name")
event_bus = self.events_backend.describe_event_bus(name)
response = {
"Name": event_bus.name,
"Arn": event_bus.arn,
}
response = {"Name": event_bus.name, "Arn": event_bus.arn}
if event_bus.policy:
response["Policy"] = event_bus.policy
@ -285,10 +282,7 @@ class EventsHandler(BaseResponse):
response = []
for event_bus in self.events_backend.list_event_buses(name_prefix):
event_bus_response = {
"Name": event_bus.name,
"Arn": event_bus.arn,
}
event_bus_response = {"Name": event_bus.name, "Arn": event_bus.arn}
if event_bus.policy:
event_bus_response["Policy"] = event_bus.policy

View file

@ -371,7 +371,7 @@ class Role(BaseModel):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"')
return self.arn
raise UnformattedGetAttTemplateException()
def get_tags(self):

View file

@ -55,7 +55,7 @@ class FakeThingType(BaseModel):
self.thing_type_properties = thing_type_properties
self.thing_type_id = str(uuid.uuid4()) # I don't know the rule of id
t = time.time()
self.metadata = {"deprecated": False, "creationData": int(t * 1000) / 1000.0}
self.metadata = {"deprecated": False, "creationDate": int(t * 1000) / 1000.0}
self.arn = "arn:aws:iot:%s:1:thingtype/%s" % (self.region_name, thing_type_name)
def to_dict(self):
@ -69,7 +69,12 @@ class FakeThingType(BaseModel):
class FakeThingGroup(BaseModel):
def __init__(
self, thing_group_name, parent_group_name, thing_group_properties, region_name
self,
thing_group_name,
parent_group_name,
thing_group_properties,
region_name,
thing_groups,
):
self.region_name = region_name
self.thing_group_name = thing_group_name
@ -78,7 +83,32 @@ class FakeThingGroup(BaseModel):
self.parent_group_name = parent_group_name
self.thing_group_properties = thing_group_properties or {}
t = time.time()
self.metadata = {"creationData": int(t * 1000) / 1000.0}
self.metadata = {"creationDate": int(t * 1000) / 1000.0}
if parent_group_name:
self.metadata["parentGroupName"] = parent_group_name
# initilize rootToParentThingGroups
if "rootToParentThingGroups" not in self.metadata:
self.metadata["rootToParentThingGroups"] = []
# search for parent arn
for thing_group_arn, thing_group in thing_groups.items():
if thing_group.thing_group_name == parent_group_name:
parent_thing_group_structure = thing_group
break
# if parent arn found (should always be found)
if parent_thing_group_structure:
# copy parent's rootToParentThingGroups
if "rootToParentThingGroups" in parent_thing_group_structure.metadata:
self.metadata["rootToParentThingGroups"].extend(
parent_thing_group_structure.metadata["rootToParentThingGroups"]
)
self.metadata["rootToParentThingGroups"].extend(
[
{
"groupName": parent_group_name,
"groupArn": parent_thing_group_structure.arn,
}
]
)
self.arn = "arn:aws:iot:%s:1:thinggroup/%s" % (
self.region_name,
thing_group_name,
@ -639,6 +669,7 @@ class IoTBackend(BaseBackend):
parent_group_name,
thing_group_properties,
self.region_name,
self.thing_groups,
)
self.thing_groups[thing_group.arn] = thing_group
return thing_group.thing_group_name, thing_group.arn, thing_group.thing_group_id

View file

@ -1,7 +1,11 @@
from moto.core import BaseBackend
import boto.logs
from moto.core.utils import unix_time_millis
from .exceptions import ResourceNotFoundException, ResourceAlreadyExistsException
from .exceptions import (
ResourceNotFoundException,
ResourceAlreadyExistsException,
InvalidParameterException,
)
class LogEvent:
@ -118,41 +122,66 @@ class LogStream:
return True
def get_paging_token_from_index(index, back=False):
if index is not None:
return "b/{:056d}".format(index) if back else "f/{:056d}".format(index)
return 0
def get_index_from_paging_token(token):
def get_index_and_direction_from_token(token):
if token is not None:
return int(token[2:])
return 0
try:
return token[0], int(token[2:])
except Exception:
raise InvalidParameterException(
"The specified nextToken is invalid."
)
return None, 0
events = sorted(
filter(filter_func, self.events),
key=lambda event: event.timestamp,
reverse=start_from_head,
filter(filter_func, self.events), key=lambda event: event.timestamp,
)
next_index = get_index_from_paging_token(next_token)
back_index = next_index
direction, index = get_index_and_direction_from_token(next_token)
limit_index = limit - 1
final_index = len(events) - 1
if direction is None:
if start_from_head:
start_index = 0
end_index = start_index + limit_index
else:
end_index = final_index
start_index = end_index - limit_index
elif direction == "f":
start_index = index + 1
end_index = start_index + limit_index
elif direction == "b":
end_index = index - 1
start_index = end_index - limit_index
else:
raise InvalidParameterException("The specified nextToken is invalid.")
if start_index < 0:
start_index = 0
elif start_index > final_index:
return (
[],
"b/{:056d}".format(final_index),
"f/{:056d}".format(final_index),
)
if end_index > final_index:
end_index = final_index
elif end_index < 0:
return (
[],
"b/{:056d}".format(0),
"f/{:056d}".format(0),
)
events_page = [
event.to_response_dict()
for event in events[next_index : next_index + limit]
event.to_response_dict() for event in events[start_index : end_index + 1]
]
if next_index + limit < len(self.events):
next_index += limit
else:
next_index = len(self.events)
back_index -= limit
if back_index <= 0:
back_index = 0
return (
events_page,
get_paging_token_from_index(back_index, True),
get_paging_token_from_index(next_index),
"b/{:056d}".format(start_index),
"f/{:056d}".format(end_index),
)
def filter_log_events(

View file

@ -0,0 +1,12 @@
from __future__ import unicode_literals
from moto.core.exceptions import JsonRESTError
class InvalidInputException(JsonRESTError):
code = 400
def __init__(self):
super(InvalidInputException, self).__init__(
"InvalidInputException",
"You provided a value that does not match the required pattern.",
)

View file

@ -8,6 +8,7 @@ from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.core.utils import unix_time
from moto.organizations import utils
from moto.organizations.exceptions import InvalidInputException
class FakeOrganization(BaseModel):
@ -57,6 +58,7 @@ class FakeAccount(BaseModel):
self.joined_method = "CREATED"
self.parent_id = organization.root_id
self.attached_policies = []
self.tags = {}
@property
def arn(self):
@ -442,5 +444,32 @@ class OrganizationsBackend(BaseBackend):
]
return dict(Targets=objects)
def tag_resource(self, **kwargs):
account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None)
if account is None:
raise InvalidInputException
new_tags = {tag["Key"]: tag["Value"] for tag in kwargs["Tags"]}
account.tags.update(new_tags)
def list_tags_for_resource(self, **kwargs):
account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None)
if account is None:
raise InvalidInputException
tags = [{"Key": key, "Value": value} for key, value in account.tags.items()]
return dict(Tags=tags)
def untag_resource(self, **kwargs):
account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None)
if account is None:
raise InvalidInputException
for key in kwargs["TagKeys"]:
account.tags.pop(key, None)
organizations_backend = OrganizationsBackend()

View file

@ -119,3 +119,18 @@ class OrganizationsResponse(BaseResponse):
return json.dumps(
self.organizations_backend.list_targets_for_policy(**self.request_params)
)
def tag_resource(self):
return json.dumps(
self.organizations_backend.tag_resource(**self.request_params)
)
def list_tags_for_resource(self):
return json.dumps(
self.organizations_backend.list_tags_for_resource(**self.request_params)
)
def untag_resource(self):
return json.dumps(
self.organizations_backend.untag_resource(**self.request_params)
)

View file

@ -17,7 +17,7 @@ from .exceptions import (
InvalidRequestException,
ClientError,
)
from .utils import random_password, secret_arn
from .utils import random_password, secret_arn, get_secret_name_from_arn
class SecretsManager(BaseModel):
@ -25,11 +25,25 @@ class SecretsManager(BaseModel):
self.region = region_name
class SecretsStore(dict):
def __setitem__(self, key, value):
new_key = get_secret_name_from_arn(key)
super(SecretsStore, self).__setitem__(new_key, value)
def __getitem__(self, key):
new_key = get_secret_name_from_arn(key)
return super(SecretsStore, self).__getitem__(new_key)
def __contains__(self, key):
new_key = get_secret_name_from_arn(key)
return dict.__contains__(self, new_key)
class SecretsManagerBackend(BaseBackend):
def __init__(self, region_name=None, **kwargs):
super(SecretsManagerBackend, self).__init__()
self.region = region_name
self.secrets = {}
self.secrets = SecretsStore()
def reset(self):
region_name = self.region
@ -44,7 +58,6 @@ class SecretsManagerBackend(BaseBackend):
return (dt - epoch).total_seconds()
def get_secret_value(self, secret_id, version_id, version_stage):
if not self._is_valid_identifier(secret_id):
raise SecretNotFoundException()
@ -453,6 +466,30 @@ class SecretsManagerBackend(BaseBackend):
return arn, name
@staticmethod
def get_resource_policy(secret_id):
resource_policy = {
"Version": "2012-10-17",
"Statement": {
"Effect": "Allow",
"Principal": {
"AWS": [
"arn:aws:iam::111122223333:root",
"arn:aws:iam::444455556666:root",
]
},
"Action": ["secretsmanager:GetSecretValue"],
"Resource": "*",
},
}
return json.dumps(
{
"ARN": secret_id,
"Name": secret_id,
"ResourcePolicy": json.dumps(resource_policy),
}
)
available_regions = boto3.session.Session().get_available_regions("secretsmanager")
secretsmanager_backends = {

View file

@ -114,3 +114,9 @@ class SecretsManagerResponse(BaseResponse):
secret_id=secret_id
)
return json.dumps(dict(ARN=arn, Name=name))
def get_resource_policy(self):
secret_id = self._get_param("SecretId")
return secretsmanager_backends[self.region].get_resource_policy(
secret_id=secret_id
)

View file

@ -72,6 +72,19 @@ def secret_arn(region, secret_id):
)
def get_secret_name_from_arn(secret_id):
# can fetch by both arn and by name
# but we are storing via name
# so we need to change the arn to name
# if it starts with arn then the secret id is arn
if secret_id.startswith("arn:aws:secretsmanager:"):
# split the arn by colon
# then get the last value which is the name appended with a random string
# then remove the random string
secret_id = "-".join(secret_id.split(":")[-1].split("-")[:-1])
return secret_id
def _exclude_characters(password, exclude_characters):
for c in exclude_characters:
if c in string.punctuation: