Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
commit 96e5b1993d
507 changed files with 52541 additions and 47814 deletions

View file

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import ssm_backends
from ..core.models import base_decorator
ssm_backend = ssm_backends['us-east-1']
ssm_backend = ssm_backends["us-east-1"]
mock_ssm = base_decorator(ssm_backends)

View file

@ -6,29 +6,25 @@ class InvalidFilterKey(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidFilterKey, self).__init__(
"InvalidFilterKey", message)
super(InvalidFilterKey, self).__init__("InvalidFilterKey", message)
class InvalidFilterOption(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidFilterOption, self).__init__(
"InvalidFilterOption", message)
super(InvalidFilterOption, self).__init__("InvalidFilterOption", message)
class InvalidFilterValue(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidFilterValue, self).__init__(
"InvalidFilterValue", message)
super(InvalidFilterValue, self).__init__("InvalidFilterValue", message)
class ValidationException(JsonRESTError):
code = 400
def __init__(self, message):
super(ValidationException, self).__init__(
"ValidationException", message)
super(ValidationException, self).__init__("ValidationException", message)

View file

@ -13,12 +13,26 @@ import time
import uuid
import itertools
from .exceptions import ValidationException, InvalidFilterValue, InvalidFilterOption, InvalidFilterKey
from .exceptions import (
ValidationException,
InvalidFilterValue,
InvalidFilterOption,
InvalidFilterKey,
)
class Parameter(BaseModel):
def __init__(self, name, value, type, description, allowed_pattern, keyid,
last_modified_date, version):
def __init__(
self,
name,
value,
type,
description,
allowed_pattern,
keyid,
last_modified_date,
version,
):
self.name = name
self.type = type
self.description = description
@ -27,48 +41,48 @@ class Parameter(BaseModel):
self.last_modified_date = last_modified_date
self.version = version
if self.type == 'SecureString':
if self.type == "SecureString":
if not self.keyid:
self.keyid = 'alias/aws/ssm'
self.keyid = "alias/aws/ssm"
self.value = self.encrypt(value)
else:
self.value = value
def encrypt(self, value):
return 'kms:{}:'.format(self.keyid) + value
return "kms:{}:".format(self.keyid) + value
def decrypt(self, value):
if self.type != 'SecureString':
if self.type != "SecureString":
return value
prefix = 'kms:{}:'.format(self.keyid or 'default')
prefix = "kms:{}:".format(self.keyid or "default")
if value.startswith(prefix):
return value[len(prefix):]
return value[len(prefix) :]
def response_object(self, decrypt=False):
r = {
'Name': self.name,
'Type': self.type,
'Value': self.decrypt(self.value) if decrypt else self.value,
'Version': self.version,
"Name": self.name,
"Type": self.type,
"Value": self.decrypt(self.value) if decrypt else self.value,
"Version": self.version,
}
return r
def describe_response_object(self, decrypt=False):
r = self.response_object(decrypt)
r['LastModifiedDate'] = int(self.last_modified_date)
r['LastModifiedUser'] = 'N/A'
r["LastModifiedDate"] = int(self.last_modified_date)
r["LastModifiedUser"] = "N/A"
if self.description:
r['Description'] = self.description
r["Description"] = self.description
if self.keyid:
r['KeyId'] = self.keyid
r["KeyId"] = self.keyid
if self.allowed_pattern:
r['AllowedPattern'] = self.allowed_pattern
r["AllowedPattern"] = self.allowed_pattern
return r
@ -77,11 +91,23 @@ MAX_TIMEOUT_SECONDS = 3600
class Command(BaseModel):
def __init__(self, comment='', document_name='', timeout_seconds=MAX_TIMEOUT_SECONDS,
instance_ids=None, max_concurrency='', max_errors='',
notification_config=None, output_s3_bucket_name='',
output_s3_key_prefix='', output_s3_region='', parameters=None,
service_role_arn='', targets=None, backend_region='us-east-1'):
def __init__(
self,
comment="",
document_name="",
timeout_seconds=MAX_TIMEOUT_SECONDS,
instance_ids=None,
max_concurrency="",
max_errors="",
notification_config=None,
output_s3_bucket_name="",
output_s3_key_prefix="",
output_s3_region="",
parameters=None,
service_role_arn="",
targets=None,
backend_region="us-east-1",
):
if instance_ids is None:
instance_ids = []
@ -99,12 +125,14 @@ class Command(BaseModel):
self.completed_count = len(instance_ids)
self.target_count = len(instance_ids)
self.command_id = str(uuid.uuid4())
self.status = 'Success'
self.status_details = 'Details placeholder'
self.status = "Success"
self.status_details = "Details placeholder"
self.requested_date_time = datetime.datetime.now()
self.requested_date_time_iso = self.requested_date_time.isoformat()
expires_after = self.requested_date_time + datetime.timedelta(0, timeout_seconds)
expires_after = self.requested_date_time + datetime.timedelta(
0, timeout_seconds
)
self.expires_after = expires_after.isoformat()
self.comment = comment
@ -122,9 +150,11 @@ class Command(BaseModel):
self.backend_region = backend_region
# Get instance ids from a cloud formation stack target.
stack_instance_ids = [self.get_instance_ids_by_stack_ids(target['Values']) for
target in self.targets if
target['Key'] == 'tag:aws:cloudformation:stack-name']
stack_instance_ids = [
self.get_instance_ids_by_stack_ids(target["Values"])
for target in self.targets
if target["Key"] == "tag:aws:cloudformation:stack-name"
]
self.instance_ids += list(itertools.chain.from_iterable(stack_instance_ids))
@ -132,7 +162,8 @@ class Command(BaseModel):
self.invocations = []
for instance_id in self.instance_ids:
self.invocations.append(
self.invocation_response(instance_id, "aws:runShellScript"))
self.invocation_response(instance_id, "aws:runShellScript")
)
def get_instance_ids_by_stack_ids(self, stack_ids):
instance_ids = []
@ -140,34 +171,36 @@ class Command(BaseModel):
for stack_id in stack_ids:
stack_resources = cloudformation_backend.list_stack_resources(stack_id)
instance_resources = [
instance.id for instance in stack_resources
if instance.type == "AWS::EC2::Instance"]
instance.id
for instance in stack_resources
if instance.type == "AWS::EC2::Instance"
]
instance_ids.extend(instance_resources)
return instance_ids
def response_object(self):
r = {
'CommandId': self.command_id,
'Comment': self.comment,
'CompletedCount': self.completed_count,
'DocumentName': self.document_name,
'ErrorCount': self.error_count,
'ExpiresAfter': self.expires_after,
'InstanceIds': self.instance_ids,
'MaxConcurrency': self.max_concurrency,
'MaxErrors': self.max_errors,
'NotificationConfig': self.notification_config,
'OutputS3Region': self.output_s3_region,
'OutputS3BucketName': self.output_s3_bucket_name,
'OutputS3KeyPrefix': self.output_s3_key_prefix,
'Parameters': self.parameters,
'RequestedDateTime': self.requested_date_time_iso,
'ServiceRole': self.service_role_arn,
'Status': self.status,
'StatusDetails': self.status_details,
'TargetCount': self.target_count,
'Targets': self.targets,
"CommandId": self.command_id,
"Comment": self.comment,
"CompletedCount": self.completed_count,
"DocumentName": self.document_name,
"ErrorCount": self.error_count,
"ExpiresAfter": self.expires_after,
"InstanceIds": self.instance_ids,
"MaxConcurrency": self.max_concurrency,
"MaxErrors": self.max_errors,
"NotificationConfig": self.notification_config,
"OutputS3Region": self.output_s3_region,
"OutputS3BucketName": self.output_s3_bucket_name,
"OutputS3KeyPrefix": self.output_s3_key_prefix,
"Parameters": self.parameters,
"RequestedDateTime": self.requested_date_time_iso,
"ServiceRole": self.service_role_arn,
"Status": self.status,
"StatusDetails": self.status_details,
"TargetCount": self.target_count,
"Targets": self.targets,
}
return r
@ -181,44 +214,50 @@ class Command(BaseModel):
end_time = self.requested_date_time + elapsed_time_delta
r = {
'CommandId': self.command_id,
'InstanceId': instance_id,
'Comment': self.comment,
'DocumentName': self.document_name,
'PluginName': plugin_name,
'ResponseCode': 0,
'ExecutionStartDateTime': self.requested_date_time_iso,
'ExecutionElapsedTime': elapsed_time_iso,
'ExecutionEndDateTime': end_time.isoformat(),
'Status': 'Success',
'StatusDetails': 'Success',
'StandardOutputContent': '',
'StandardOutputUrl': '',
'StandardErrorContent': '',
"CommandId": self.command_id,
"InstanceId": instance_id,
"Comment": self.comment,
"DocumentName": self.document_name,
"PluginName": plugin_name,
"ResponseCode": 0,
"ExecutionStartDateTime": self.requested_date_time_iso,
"ExecutionElapsedTime": elapsed_time_iso,
"ExecutionEndDateTime": end_time.isoformat(),
"Status": "Success",
"StatusDetails": "Success",
"StandardOutputContent": "",
"StandardOutputUrl": "",
"StandardErrorContent": "",
}
return r
def get_invocation(self, instance_id, plugin_name):
invocation = next(
(invocation for invocation in self.invocations
if invocation['InstanceId'] == instance_id), None)
(
invocation
for invocation in self.invocations
if invocation["InstanceId"] == instance_id
),
None,
)
if invocation is None:
raise RESTError(
'InvocationDoesNotExist',
'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation')
"InvocationDoesNotExist",
"An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation",
)
if plugin_name is not None and invocation['PluginName'] != plugin_name:
if plugin_name is not None and invocation["PluginName"] != plugin_name:
raise RESTError(
'InvocationDoesNotExist',
'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation')
"InvocationDoesNotExist",
"An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation",
)
return invocation
class SimpleSystemManagerBackend(BaseBackend):
def __init__(self):
self._parameters = {}
self._resource_tags = defaultdict(lambda: defaultdict(dict))
@ -248,7 +287,9 @@ class SimpleSystemManagerBackend(BaseBackend):
def describe_parameters(self, filters, parameter_filters):
if filters and parameter_filters:
raise ValidationException('You can use either Filters or ParameterFilters in a single request.')
raise ValidationException(
"You can use either Filters or ParameterFilters in a single request."
)
self._validate_parameter_filters(parameter_filters, by_path=False)
@ -260,22 +301,22 @@ class SimpleSystemManagerBackend(BaseBackend):
if filters:
for filter in filters:
if filter['Key'] == 'Name':
if filter["Key"] == "Name":
k = ssm_parameter.name
for v in filter['Values']:
for v in filter["Values"]:
if k.startswith(v):
result.append(ssm_parameter)
break
elif filter['Key'] == 'Type':
elif filter["Key"] == "Type":
k = ssm_parameter.type
for v in filter['Values']:
for v in filter["Values"]:
if k == v:
result.append(ssm_parameter)
break
elif filter['Key'] == 'KeyId':
elif filter["Key"] == "KeyId":
k = ssm_parameter.keyid
if k:
for v in filter['Values']:
for v in filter["Values"]:
if k == v:
result.append(ssm_parameter)
break
@ -287,125 +328,157 @@ class SimpleSystemManagerBackend(BaseBackend):
def _validate_parameter_filters(self, parameter_filters, by_path):
for index, filter_obj in enumerate(parameter_filters or []):
key = filter_obj['Key']
values = filter_obj.get('Values', [])
key = filter_obj["Key"]
values = filter_obj.get("Values", [])
if key == 'Path':
option = filter_obj.get('Option', 'OneLevel')
if key == "Path":
option = filter_obj.get("Option", "OneLevel")
else:
option = filter_obj.get('Option', 'Equals')
option = filter_obj.get("Option", "Equals")
if not re.match(r'^tag:.+|Name|Type|KeyId|Path|Label|Tier$', key):
self._errors.append(self._format_error(
key='parameterFilters.{index}.member.key'.format(index=(index + 1)),
value=key,
constraint='Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier',
))
if not re.match(r"^tag:.+|Name|Type|KeyId|Path|Label|Tier$", key):
self._errors.append(
self._format_error(
key="parameterFilters.{index}.member.key".format(
index=(index + 1)
),
value=key,
constraint="Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier",
)
)
if len(key) > 132:
self._errors.append(self._format_error(
key='parameterFilters.{index}.member.key'.format(index=(index + 1)),
value=key,
constraint='Member must have length less than or equal to 132',
))
self._errors.append(
self._format_error(
key="parameterFilters.{index}.member.key".format(
index=(index + 1)
),
value=key,
constraint="Member must have length less than or equal to 132",
)
)
if len(option) > 10:
self._errors.append(self._format_error(
key='parameterFilters.{index}.member.option'.format(index=(index + 1)),
value='over 10 chars',
constraint='Member must have length less than or equal to 10',
))
self._errors.append(
self._format_error(
key="parameterFilters.{index}.member.option".format(
index=(index + 1)
),
value="over 10 chars",
constraint="Member must have length less than or equal to 10",
)
)
if len(values) > 50:
self._errors.append(self._format_error(
key='parameterFilters.{index}.member.values'.format(index=(index + 1)),
value=values,
constraint='Member must have length less than or equal to 50',
))
self._errors.append(
self._format_error(
key="parameterFilters.{index}.member.values".format(
index=(index + 1)
),
value=values,
constraint="Member must have length less than or equal to 50",
)
)
if any(len(value) > 1024 for value in values):
self._errors.append(self._format_error(
key='parameterFilters.{index}.member.values'.format(index=(index + 1)),
value=values,
constraint='[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]',
))
self._errors.append(
self._format_error(
key="parameterFilters.{index}.member.values".format(
index=(index + 1)
),
value=values,
constraint="[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]",
)
)
self._raise_errors()
filter_keys = []
for filter_obj in (parameter_filters or []):
key = filter_obj['Key']
values = filter_obj.get('Values')
for filter_obj in parameter_filters or []:
key = filter_obj["Key"]
values = filter_obj.get("Values")
if key == 'Path':
option = filter_obj.get('Option', 'OneLevel')
if key == "Path":
option = filter_obj.get("Option", "OneLevel")
else:
option = filter_obj.get('Option', 'Equals')
option = filter_obj.get("Option", "Equals")
if not by_path and key == 'Label':
raise InvalidFilterKey('The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier].')
if not by_path and key == "Label":
raise InvalidFilterKey(
"The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier]."
)
if not values:
raise InvalidFilterValue('The following filter values are missing : null for filter key Name.')
raise InvalidFilterValue(
"The following filter values are missing : null for filter key Name."
)
if key in filter_keys:
raise InvalidFilterKey(
'The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter.'
"The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter."
)
if key == 'Path':
if option not in ['Recursive', 'OneLevel']:
if key == "Path":
if option not in ["Recursive", "OneLevel"]:
raise InvalidFilterOption(
'The following filter option is not valid: {option}. Valid options include: [Recursive, OneLevel].'.format(option=option)
"The following filter option is not valid: {option}. Valid options include: [Recursive, OneLevel].".format(
option=option
)
)
if any(value.lower().startswith(('/aws', '/ssm')) for value in values):
if any(value.lower().startswith(("/aws", "/ssm")) for value in values):
raise ValidationException(
'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). '
'When using global parameters, please specify within a global namespace.'
"When using global parameters, please specify within a global namespace."
)
for value in values:
if value.lower().startswith(('/aws', '/ssm')):
if value.lower().startswith(("/aws", "/ssm")):
raise ValidationException(
'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). '
'When using global parameters, please specify within a global namespace.'
"When using global parameters, please specify within a global namespace."
)
if ('//' in value or
not value.startswith('/') or
not re.match('^[a-zA-Z0-9_.-/]*$', value)):
if (
"//" in value
or not value.startswith("/")
or not re.match("^[a-zA-Z0-9_.-/]*$", value)
):
raise ValidationException(
'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". '
'It can\'t be prefixed with \"aws\" or \"ssm\" (case-insensitive). '
'It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). '
'It can\'t be prefixed with "aws" or "ssm" (case-insensitive). '
"It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). "
'Special characters are not allowed. All sub-paths, if specified, must use the forward slash symbol "/". '
'Valid example: /get/parameters2-/by1./path0_.'
"Valid example: /get/parameters2-/by1./path0_."
)
if key == 'Tier':
if key == "Tier":
for value in values:
if value not in ['Standard', 'Advanced', 'Intelligent-Tiering']:
if value not in ["Standard", "Advanced", "Intelligent-Tiering"]:
raise InvalidFilterOption(
'The following filter value is not valid: {value}. Valid values include: [Standard, Advanced, Intelligent-Tiering].'.format(value=value)
"The following filter value is not valid: {value}. Valid values include: [Standard, Advanced, Intelligent-Tiering].".format(
value=value
)
)
if key == 'Type':
if key == "Type":
for value in values:
if value not in ['String', 'StringList', 'SecureString']:
if value not in ["String", "StringList", "SecureString"]:
raise InvalidFilterOption(
'The following filter value is not valid: {value}. Valid values include: [String, StringList, SecureString].'.format(value=value)
"The following filter value is not valid: {value}. Valid values include: [String, StringList, SecureString].".format(
value=value
)
)
if key != 'Path' and option not in ['Equals', 'BeginsWith']:
if key != "Path" and option not in ["Equals", "BeginsWith"]:
raise InvalidFilterOption(
'The following filter option is not valid: {option}. Valid options include: [BeginsWith, Equals].'.format(option=option)
"The following filter option is not valid: {option}. Valid options include: [BeginsWith, Equals].".format(
option=option
)
)
filter_keys.append(key)
def _format_error(self, key, value, constraint):
return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format(
constraint=constraint,
key=key,
value=value,
constraint=constraint, key=key, value=value
)
def _raise_errors(self):
@ -415,9 +488,11 @@ class SimpleSystemManagerBackend(BaseBackend):
errors = "; ".join(self._errors)
self._errors = [] # reset collected errors
raise ValidationException('{count} validation error{plural} detected: {errors}'.format(
count=count, plural=plural, errors=errors,
))
raise ValidationException(
"{count} validation error{plural} detected: {errors}".format(
count=count, plural=plural, errors=errors
)
)
def get_all_parameters(self):
result = []
@ -437,11 +512,11 @@ class SimpleSystemManagerBackend(BaseBackend):
result = []
# path could be with or without a trailing /. we handle this
# difference here.
path = path.rstrip('/') + '/'
path = path.rstrip("/") + "/"
for param in self._parameters:
if path != '/' and not param.startswith(path):
if path != "/" and not param.startswith(path):
continue
if '/' in param[len(path) + 1:] and not recursive:
if "/" in param[len(path) + 1 :] and not recursive:
continue
if not self._match_filters(self._parameters[param], filters):
continue
@ -451,48 +526,51 @@ class SimpleSystemManagerBackend(BaseBackend):
def _match_filters(self, parameter, filters=None):
"""Return True if the given parameter matches all the filters"""
for filter_obj in (filters or []):
key = filter_obj['Key']
values = filter_obj.get('Values', [])
for filter_obj in filters or []:
key = filter_obj["Key"]
values = filter_obj.get("Values", [])
if key == 'Path':
option = filter_obj.get('Option', 'OneLevel')
if key == "Path":
option = filter_obj.get("Option", "OneLevel")
else:
option = filter_obj.get('Option', 'Equals')
option = filter_obj.get("Option", "Equals")
what = None
if key == 'KeyId':
if key == "KeyId":
what = parameter.keyid
elif key == 'Name':
what = '/' + parameter.name.lstrip('/')
values = ['/' + value.lstrip('/') for value in values]
elif key == 'Path':
what = '/' + parameter.name.lstrip('/')
values = ['/' + value.strip('/') for value in values]
elif key == 'Type':
elif key == "Name":
what = "/" + parameter.name.lstrip("/")
values = ["/" + value.lstrip("/") for value in values]
elif key == "Path":
what = "/" + parameter.name.lstrip("/")
values = ["/" + value.strip("/") for value in values]
elif key == "Type":
what = parameter.type
if what is None:
return False
elif (option == 'BeginsWith' and
not any(what.startswith(value) for value in values)):
elif option == "BeginsWith" and not any(
what.startswith(value) for value in values
):
return False
elif (option == 'Equals' and
not any(what == value for value in values)):
elif option == "Equals" and not any(what == value for value in values):
return False
elif option == 'OneLevel':
if any(value == '/' and len(what.split('/')) == 2 for value in values):
elif option == "OneLevel":
if any(value == "/" and len(what.split("/")) == 2 for value in values):
continue
elif any(value != '/' and
what.startswith(value + '/') and
len(what.split('/')) - 1 == len(value.split('/')) for value in values):
elif any(
value != "/"
and what.startswith(value + "/")
and len(what.split("/")) - 1 == len(value.split("/"))
for value in values
):
continue
else:
return False
elif option == 'Recursive':
if any(value == '/' for value in values):
elif option == "Recursive":
if any(value == "/" for value in values):
continue
elif any(what.startswith(value + '/') for value in values):
elif any(what.startswith(value + "/") for value in values):
continue
else:
return False
@ -504,8 +582,9 @@ class SimpleSystemManagerBackend(BaseBackend):
return self._parameters[name]
return None
def put_parameter(self, name, description, value, type, allowed_pattern,
keyid, overwrite):
def put_parameter(
self, name, description, value, type, allowed_pattern, keyid, overwrite
):
previous_parameter = self._parameters.get(name)
version = 1
@ -516,8 +595,16 @@ class SimpleSystemManagerBackend(BaseBackend):
return
last_modified_date = time.time()
self._parameters[name] = Parameter(name, value, type, description,
allowed_pattern, keyid, last_modified_date, version)
self._parameters[name] = Parameter(
name,
value,
type,
description,
allowed_pattern,
keyid,
last_modified_date,
version,
)
return version
def add_tags_to_resource(self, resource_type, resource_id, tags):
@ -535,29 +622,31 @@ class SimpleSystemManagerBackend(BaseBackend):
def send_command(self, **kwargs):
command = Command(
comment=kwargs.get('Comment', ''),
document_name=kwargs.get('DocumentName'),
timeout_seconds=kwargs.get('TimeoutSeconds', 3600),
instance_ids=kwargs.get('InstanceIds', []),
max_concurrency=kwargs.get('MaxConcurrency', '50'),
max_errors=kwargs.get('MaxErrors', '0'),
notification_config=kwargs.get('NotificationConfig', {
'NotificationArn': 'string',
'NotificationEvents': ['Success'],
'NotificationType': 'Command'
}),
output_s3_bucket_name=kwargs.get('OutputS3BucketName', ''),
output_s3_key_prefix=kwargs.get('OutputS3KeyPrefix', ''),
output_s3_region=kwargs.get('OutputS3Region', ''),
parameters=kwargs.get('Parameters', {}),
service_role_arn=kwargs.get('ServiceRoleArn', ''),
targets=kwargs.get('Targets', []),
backend_region=self._region)
comment=kwargs.get("Comment", ""),
document_name=kwargs.get("DocumentName"),
timeout_seconds=kwargs.get("TimeoutSeconds", 3600),
instance_ids=kwargs.get("InstanceIds", []),
max_concurrency=kwargs.get("MaxConcurrency", "50"),
max_errors=kwargs.get("MaxErrors", "0"),
notification_config=kwargs.get(
"NotificationConfig",
{
"NotificationArn": "string",
"NotificationEvents": ["Success"],
"NotificationType": "Command",
},
),
output_s3_bucket_name=kwargs.get("OutputS3BucketName", ""),
output_s3_key_prefix=kwargs.get("OutputS3KeyPrefix", ""),
output_s3_region=kwargs.get("OutputS3Region", ""),
parameters=kwargs.get("Parameters", {}),
service_role_arn=kwargs.get("ServiceRoleArn", ""),
targets=kwargs.get("Targets", []),
backend_region=self._region,
)
self._commands.append(command)
return {
'Command': command.response_object()
}
return {"Command": command.response_object()}
def list_commands(self, **kwargs):
"""
@ -565,39 +654,38 @@ class SimpleSystemManagerBackend(BaseBackend):
"""
commands = self._commands
command_id = kwargs.get('CommandId', None)
command_id = kwargs.get("CommandId", None)
if command_id:
commands = [self.get_command_by_id(command_id)]
instance_id = kwargs.get('InstanceId', None)
instance_id = kwargs.get("InstanceId", None)
if instance_id:
commands = self.get_commands_by_instance_id(instance_id)
return {
'Commands': [command.response_object() for command in commands]
}
return {"Commands": [command.response_object() for command in commands]}
def get_command_by_id(self, id):
command = next(
(command for command in self._commands if command.command_id == id), None)
(command for command in self._commands if command.command_id == id), None
)
if command is None:
raise RESTError('InvalidCommandId', 'Invalid command id.')
raise RESTError("InvalidCommandId", "Invalid command id.")
return command
def get_commands_by_instance_id(self, instance_id):
return [
command for command in self._commands
if instance_id in command.instance_ids]
command for command in self._commands if instance_id in command.instance_ids
]
def get_command_invocation(self, **kwargs):
"""
https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetCommandInvocation.html
"""
command_id = kwargs.get('CommandId')
instance_id = kwargs.get('InstanceId')
plugin_name = kwargs.get('PluginName', None)
command_id = kwargs.get("CommandId")
instance_id = kwargs.get("InstanceId")
plugin_name = kwargs.get("PluginName", None)
command = self.get_command_by_id(command_id)
return command.get_invocation(instance_id, plugin_name)

View file

@ -6,7 +6,6 @@ from .models import ssm_backends
class SimpleSystemManagerResponse(BaseResponse):
@property
def ssm_backend(self):
return ssm_backends[self.region]
@ -22,171 +21,151 @@ class SimpleSystemManagerResponse(BaseResponse):
return self.request_params.get(param, default)
def delete_parameter(self):
name = self._get_param('Name')
name = self._get_param("Name")
self.ssm_backend.delete_parameter(name)
return json.dumps({})
def delete_parameters(self):
names = self._get_param('Names')
names = self._get_param("Names")
result = self.ssm_backend.delete_parameters(names)
response = {
'DeletedParameters': [],
'InvalidParameters': []
}
response = {"DeletedParameters": [], "InvalidParameters": []}
for name in names:
if name in result:
response['DeletedParameters'].append(name)
response["DeletedParameters"].append(name)
else:
response['InvalidParameters'].append(name)
response["InvalidParameters"].append(name)
return json.dumps(response)
def get_parameter(self):
name = self._get_param('Name')
with_decryption = self._get_param('WithDecryption')
name = self._get_param("Name")
with_decryption = self._get_param("WithDecryption")
result = self.ssm_backend.get_parameter(name, with_decryption)
if result is None:
error = {
'__type': 'ParameterNotFound',
'message': 'Parameter {0} not found.'.format(name)
"__type": "ParameterNotFound",
"message": "Parameter {0} not found.".format(name),
}
return json.dumps(error), dict(status=400)
response = {
'Parameter': result.response_object(with_decryption)
}
response = {"Parameter": result.response_object(with_decryption)}
return json.dumps(response)
def get_parameters(self):
names = self._get_param('Names')
with_decryption = self._get_param('WithDecryption')
names = self._get_param("Names")
with_decryption = self._get_param("WithDecryption")
result = self.ssm_backend.get_parameters(names, with_decryption)
response = {
'Parameters': [],
'InvalidParameters': [],
}
response = {"Parameters": [], "InvalidParameters": []}
for parameter in result:
param_data = parameter.response_object(with_decryption)
response['Parameters'].append(param_data)
response["Parameters"].append(param_data)
param_names = [param.name for param in result]
for name in names:
if name not in param_names:
response['InvalidParameters'].append(name)
response["InvalidParameters"].append(name)
return json.dumps(response)
def get_parameters_by_path(self):
path = self._get_param('Path')
with_decryption = self._get_param('WithDecryption')
recursive = self._get_param('Recursive', False)
filters = self._get_param('ParameterFilters')
path = self._get_param("Path")
with_decryption = self._get_param("WithDecryption")
recursive = self._get_param("Recursive", False)
filters = self._get_param("ParameterFilters")
result = self.ssm_backend.get_parameters_by_path(
path, with_decryption, recursive, filters
)
response = {
'Parameters': [],
}
response = {"Parameters": []}
for parameter in result:
param_data = parameter.response_object(with_decryption)
response['Parameters'].append(param_data)
response["Parameters"].append(param_data)
return json.dumps(response)
def describe_parameters(self):
page_size = 10
filters = self._get_param('Filters')
parameter_filters = self._get_param('ParameterFilters')
token = self._get_param('NextToken')
if hasattr(token, 'strip'):
filters = self._get_param("Filters")
parameter_filters = self._get_param("ParameterFilters")
token = self._get_param("NextToken")
if hasattr(token, "strip"):
token = token.strip()
if not token:
token = '0'
token = "0"
token = int(token)
result = self.ssm_backend.describe_parameters(
filters, parameter_filters
)
result = self.ssm_backend.describe_parameters(filters, parameter_filters)
response = {
'Parameters': [],
}
response = {"Parameters": []}
end = token + page_size
for parameter in result[token:]:
response['Parameters'].append(parameter.describe_response_object(False))
response["Parameters"].append(parameter.describe_response_object(False))
token = token + 1
if len(response['Parameters']) == page_size:
response['NextToken'] = str(end)
if len(response["Parameters"]) == page_size:
response["NextToken"] = str(end)
break
return json.dumps(response)
def put_parameter(self):
name = self._get_param('Name')
description = self._get_param('Description')
value = self._get_param('Value')
type_ = self._get_param('Type')
allowed_pattern = self._get_param('AllowedPattern')
keyid = self._get_param('KeyId')
overwrite = self._get_param('Overwrite', False)
name = self._get_param("Name")
description = self._get_param("Description")
value = self._get_param("Value")
type_ = self._get_param("Type")
allowed_pattern = self._get_param("AllowedPattern")
keyid = self._get_param("KeyId")
overwrite = self._get_param("Overwrite", False)
result = self.ssm_backend.put_parameter(
name, description, value, type_, allowed_pattern, keyid, overwrite)
name, description, value, type_, allowed_pattern, keyid, overwrite
)
if result is None:
error = {
'__type': 'ParameterAlreadyExists',
'message': 'Parameter {0} already exists.'.format(name)
"__type": "ParameterAlreadyExists",
"message": "Parameter {0} already exists.".format(name),
}
return json.dumps(error), dict(status=400)
response = {'Version': result}
response = {"Version": result}
return json.dumps(response)
def add_tags_to_resource(self):
resource_id = self._get_param('ResourceId')
resource_type = self._get_param('ResourceType')
tags = {t['Key']: t['Value'] for t in self._get_param('Tags')}
self.ssm_backend.add_tags_to_resource(
resource_id, resource_type, tags)
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")}
self.ssm_backend.add_tags_to_resource(resource_id, resource_type, tags)
return json.dumps({})
def remove_tags_from_resource(self):
resource_id = self._get_param('ResourceId')
resource_type = self._get_param('ResourceType')
keys = self._get_param('TagKeys')
self.ssm_backend.remove_tags_from_resource(
resource_id, resource_type, keys)
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
keys = self._get_param("TagKeys")
self.ssm_backend.remove_tags_from_resource(resource_id, resource_type, keys)
return json.dumps({})
def list_tags_for_resource(self):
resource_id = self._get_param('ResourceId')
resource_type = self._get_param('ResourceType')
tags = self.ssm_backend.list_tags_for_resource(
resource_id, resource_type)
tag_list = [{'Key': k, 'Value': v} for (k, v) in tags.items()]
response = {'TagList': tag_list}
resource_id = self._get_param("ResourceId")
resource_type = self._get_param("ResourceType")
tags = self.ssm_backend.list_tags_for_resource(resource_id, resource_type)
tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()]
response = {"TagList": tag_list}
return json.dumps(response)
def send_command(self):
return json.dumps(
self.ssm_backend.send_command(**self.request_params)
)
return json.dumps(self.ssm_backend.send_command(**self.request_params))
def list_commands(self):
return json.dumps(
self.ssm_backend.list_commands(**self.request_params)
)
return json.dumps(self.ssm_backend.list_commands(**self.request_params))
def get_command_invocation(self):
return json.dumps(

View file

@ -1,11 +1,6 @@
from __future__ import unicode_literals
from .responses import SimpleSystemManagerResponse
url_bases = [
"https?://ssm.(.+).amazonaws.com",
"https?://ssm.(.+).amazonaws.com.cn",
]
url_bases = ["https?://ssm.(.+).amazonaws.com", "https?://ssm.(.+).amazonaws.com.cn"]
url_paths = {
'{0}/$': SimpleSystemManagerResponse.dispatch,
}
url_paths = {"{0}/$": SimpleSystemManagerResponse.dispatch}