Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
commit 96e5b1993d
507 changed files with 52541 additions and 47814 deletions

View file

@ -4,4 +4,6 @@ from .models import BaseModel, BaseBackend, moto_api_backend # noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count
set_initial_no_auth_action_count = (
ActionAuthenticatorMixin.set_initial_no_auth_action_count
)

View file

@ -26,7 +26,12 @@ from six import string_types
from moto.iam.models import ACCOUNT_ID, Policy
from moto.iam import iam_backend
from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError, AuthFailureError
from moto.core.exceptions import (
SignatureDoesNotMatchError,
AccessDeniedError,
InvalidClientTokenIdError,
AuthFailureError,
)
from moto.s3.exceptions import (
BucketAccessDeniedError,
S3AccessDeniedError,
@ -35,7 +40,7 @@ from moto.s3.exceptions import (
S3InvalidAccessKeyIdError,
BucketInvalidAccessKeyIdError,
BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError
S3SignatureDoesNotMatchError,
)
from moto.sts import sts_backend
@ -50,9 +55,8 @@ def create_access_key(access_key_id, headers):
class IAMUserAccessKey(object):
def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users('/', None, None)
iam_users = iam_backend.list_users("/", None, None)
for iam_user in iam_users:
for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id:
@ -67,8 +71,7 @@ class IAMUserAccessKey(object):
@property
def arn(self):
return "arn:aws:iam::{account_id}:user/{iam_user_name}".format(
account_id=ACCOUNT_ID,
iam_user_name=self._owner_user_name
account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name
)
def create_credentials(self):
@ -79,27 +82,34 @@ class IAMUserAccessKey(object):
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(self._owner_user_name, inline_policy_name)
inline_policy = iam_backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(self._owner_user_name)
attached_policies, _ = iam_backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name)
inline_user_group_policy = iam_backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(user_group.name)
attached_group_policies, _ = iam_backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies
return user_policies
class AssumedRoleAccessKey(object):
def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles:
if assumed_role.access_key_id == access_key_id:
@ -118,28 +128,33 @@ class AssumedRoleAccessKey(object):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self._owner_role_name,
session_name=self._session_name
session_name=self._session_name,
)
def create_credentials(self):
return Credentials(self._access_key_id, self._secret_access_key, self._session_token)
return Credentials(
self._access_key_id, self._secret_access_key, self._session_token
)
def collect_policies(self):
role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(self._owner_role_name, inline_policy_name)
_, inline_policy = iam_backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(self._owner_role_name)
attached_policies, _ = iam_backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies
return role_policies
class CreateAccessKeyFailure(Exception):
def __init__(self, reason, *args):
super(CreateAccessKeyFailure, self).__init__(*args)
self.reason = reason
@ -147,32 +162,54 @@ class CreateAccessKeyFailure(Exception):
@six.add_metaclass(ABCMeta)
class IAMRequestBase(object):
def __init__(self, method, path, data, headers):
log.debug("Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__, method=method, path=path, data=data, headers=headers))
log.debug(
"Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__,
method=method,
path=path,
data=data,
headers=headers,
)
)
self._method = method
self._path = path
self._data = data
self._headers = headers
credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization'])
credential_data = credential_scope.split('/')
credential_scope = self._get_string_between(
"Credential=", ",", self._headers["Authorization"]
)
credential_data = credential_scope.split("/")
self._region = credential_data[2]
self._service = credential_data[3]
self._action = self._service + ":" + (self._data["Action"][0] if isinstance(self._data["Action"], list) else self._data["Action"])
self._action = (
self._service
+ ":"
+ (
self._data["Action"][0]
if isinstance(self._data["Action"], list)
else self._data["Action"]
)
)
try:
self._access_key = create_access_key(access_key_id=credential_data[0], headers=headers)
self._access_key = create_access_key(
access_key_id=credential_data[0], headers=headers
)
except CreateAccessKeyFailure as e:
self._raise_invalid_access_key(e.reason)
def check_signature(self):
original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization'])
original_signature = self._get_string_between(
"Signature=", ",", self._headers["Authorization"]
)
calculated_signature = self._calculate_signature()
if original_signature != calculated_signature:
self._raise_signature_does_not_match()
def check_action_permitted(self):
if self._action == 'sts:GetCallerIdentity': # always allowed, even if there's an explicit Deny for it
if (
self._action == "sts:GetCallerIdentity"
): # always allowed, even if there's an explicit Deny for it
return True
policies = self._access_key.collect_policies()
@ -213,10 +250,14 @@ class IAMRequestBase(object):
return headers
def _create_aws_request(self):
signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';')
signed_headers = self._get_string_between(
"SignedHeaders=", ",", self._headers["Authorization"]
).split(";")
headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers)
request.context['timestamp'] = headers['X-Amz-Date']
request = AWSRequest(
method=self._method, url=self._path, data=self._data, headers=headers
)
request.context["timestamp"] = headers["X-Amz-Date"]
return request
@ -234,7 +275,6 @@ class IAMRequestBase(object):
class IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if self._service == "ec2":
raise AuthFailureError()
@ -251,14 +291,10 @@ class IAMRequest(IAMRequestBase):
return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
raise AccessDeniedError(
user_arn=self._access_key.arn,
action=self._action
)
raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action)
class S3IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if "BucketName" in self._data:
raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"])
@ -288,10 +324,13 @@ class S3IAMRequest(IAMRequestBase):
class IAMPolicy(object):
def __init__(self, policy):
if isinstance(policy, Policy):
default_version = next(policy_version for policy_version in policy.versions if policy_version.is_default)
default_version = next(
policy_version
for policy_version in policy.versions
if policy_version.is_default
)
policy_document = default_version.document
elif isinstance(policy, string_types):
policy_document = policy
@ -321,7 +360,6 @@ class IAMPolicy(object):
class IAMPolicyStatement(object):
def __init__(self, statement):
self._statement = statement

View file

@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
@ -13,7 +13,7 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error>
"""
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse>
<Errors>
<Error>
@ -26,7 +26,7 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</ErrorResponse>
"""
ERROR_JSON_RESPONSE = u"""{
ERROR_JSON_RESPONSE = """{
"message": "{{message}}",
"__type": "{{error_type}}"
}
@ -37,18 +37,19 @@ class RESTError(HTTPException):
code = 400
templates = {
'single_error': SINGLE_ERROR_RESPONSE,
'error': ERROR_RESPONSE,
'error_json': ERROR_JSON_RESPONSE,
"single_error": SINGLE_ERROR_RESPONSE,
"error": ERROR_RESPONSE,
"error_json": ERROR_JSON_RESPONSE,
}
def __init__(self, error_type, message, template='error', **kwargs):
def __init__(self, error_type, message, template="error", **kwargs):
super(RESTError, self).__init__()
env = Environment(loader=DictLoader(self.templates))
self.error_type = error_type
self.message = message
self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs)
error_type=error_type, message=message, **kwargs
)
class DryRunClientError(RESTError):
@ -56,12 +57,11 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs):
super(JsonRESTError, self).__init__(
error_type, message, template, **kwargs)
def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')]
return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs):
return self.description
@ -72,8 +72,9 @@ class SignatureDoesNotMatchError(RESTError):
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.")
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
)
class InvalidClientTokenIdError(RESTError):
@ -81,8 +82,9 @@ class InvalidClientTokenIdError(RESTError):
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId',
"The security token included in the request is invalid.")
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
class AccessDeniedError(RESTError):
@ -90,11 +92,11 @@ class AccessDeniedError(RESTError):
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
'AccessDenied',
"AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn,
operation=action
))
user_arn=user_arn, operation=action
),
)
class AuthFailureError(RESTError):
@ -102,13 +104,17 @@ class AuthFailureError(RESTError):
def __init__(self):
super(AuthFailureError, self).__init__(
'AuthFailure',
"AWS was not able to validate the provided access credentials")
"AuthFailure",
"AWS was not able to validate the provided access credentials",
)
class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""
code = 400
def __init__(self):
super(InvalidNextTokenException, self).__init__('InvalidNextTokenException', 'The nextToken provided is invalid')
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)

View file

@ -31,15 +31,19 @@ class BaseMockAWS(object):
self.backends_for_urls = {}
from moto.backends import BACKENDS
default_backends = {
"instance_metadata": BACKENDS['instance_metadata']['global'],
"moto_api": BACKENDS['moto_api']['global'],
"instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS["moto_api"]["global"],
}
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
FAKE_KEYS = {
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0:
@ -72,7 +76,7 @@ class BaseMockAWS(object):
self.__class__.nested_count -= 1
if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().')
raise RuntimeError("Called stop() before start().")
if self.__class__.nested_count == 0:
self.disable_patching()
@ -85,6 +89,7 @@ class BaseMockAWS(object):
finally:
self.stop()
return result
functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func
return wrapper
@ -122,7 +127,6 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS):
def reset(self):
HTTPretty.reset()
@ -144,18 +148,26 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD,
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT]
RESPONSES_METHODS = [
responses.GET,
responses.DELETE,
responses.HEAD,
responses.OPTIONS,
responses.PATCH,
responses.POST,
responses.PUT,
]
class CallbackResponse(responses.CallbackResponse):
'''
"""
Need to subclass so we can change a couple things
'''
"""
def get_response(self, request):
'''
"""
Need to override this so we can pass decode_content=False
'''
"""
headers = self.get_headers()
result = self.callback(request)
@ -177,17 +189,17 @@ class CallbackResponse(responses.CallbackResponse):
)
def _url_matches(self, url, other, match_querystring=False):
'''
"""
Need to override this so we can fix querystrings breaking regex matching
'''
"""
if not match_querystring:
other = other.split('?', 1)[0]
other = other.split("?", 1)[0]
if responses._is_string(url):
if responses._has_unicode(url):
url = responses._clean_unicode(url)
if not isinstance(other, six.text_type):
other = other.encode('ascii').decode('utf8')
other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other)
elif isinstance(url, responses.Pattern) and url.match(other):
return True
@ -195,22 +207,23 @@ class CallbackResponse(responses.CallbackResponse):
return False
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send')
botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
BOTOCORE_HTTP_METHODS = [
'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'
]
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
class MockRawResponse(BytesIO):
def __init__(self, input):
if isinstance(input, six.text_type):
input = input.encode('utf-8')
input = input.encode("utf-8")
super(MockRawResponse, self).__init__(input)
def stream(self, **kwargs):
@ -241,7 +254,7 @@ class BotocoreStubber(object):
found_index = None
matchers = self.methods.get(request.method)
base_url = request.url.split('?', 1)[0]
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url):
if found_index is None:
@ -254,8 +267,10 @@ class BotocoreStubber(object):
if response_callback is not None:
for header, value in request.headers.items():
if isinstance(value, six.binary_type):
request.headers[header] = value.decode('utf-8')
status, headers, body = response_callback(request, request.url, request.headers)
request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(
request, request.url, request.headers
)
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
@ -263,7 +278,7 @@ class BotocoreStubber(object):
botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber))
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def not_implemented_callback(request):
@ -287,7 +302,9 @@ class BotocoreEventMockAWS(BaseMockAWS):
pattern = re.compile(key)
botocore_stubber.register_response(method, pattern, value)
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
if not hasattr(responses_mock, "_patcher") or not hasattr(
responses_mock._patcher, "target"
):
responses_mock.start()
for method in RESPONSES_METHODS:
@ -336,9 +353,9 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS):
def reset(self):
import requests
requests.post("http://localhost:5000/moto-api/reset")
def enable_patching(self):
@ -350,13 +367,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock
def fake_boto3_client(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_resource(*args, **kwargs)
def fake_httplib_send_output(self, message_body=None, *args, **kwargs):
@ -364,7 +381,7 @@ class ServerModeMockAWS(BaseMockAWS):
bytes_buffer = []
for chunk in mixed_buffer:
if isinstance(chunk, six.text_type):
bytes_buffer.append(chunk.encode('utf-8'))
bytes_buffer.append(chunk.encode("utf-8"))
else:
bytes_buffer.append(chunk)
msg = b"\r\n".join(bytes_buffer)
@ -385,10 +402,12 @@ class ServerModeMockAWS(BaseMockAWS):
if message_body is not None:
self.send(message_body)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource)
self._client_patcher = mock.patch("boto3.client", fake_boto3_client)
self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource)
if six.PY2:
self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output)
self._httplib_patcher = mock.patch(
"httplib.HTTPConnection._send_output", fake_httplib_send_output
)
self._client_patcher.start()
self._resource_patcher.start()
@ -404,7 +423,6 @@ class ServerModeMockAWS(BaseMockAWS):
class Model(type):
def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {}
@ -419,9 +437,11 @@ class Model(type):
@staticmethod
def prop(model_name):
""" decorator to mark a class method as returning model values """
def dec(f):
f.__returns_model__ = model_name
return f
return dec
@ -431,7 +451,7 @@ model_data = defaultdict(dict)
class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == 'BaseModel':
if name == "BaseModel":
return cls
service = cls.__module__.split(".")[1]
@ -450,7 +470,6 @@ class BaseModel(object):
class BaseBackend(object):
def _reset_model_refs(self):
# Remove all references to the models stored
for service, models in model_data.items():
@ -466,8 +485,9 @@ class BaseBackend(object):
def _url_module(self):
backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=[
'url_bases', 'url_paths'])
backend_urls_module = __import__(
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module
@property
@ -523,9 +543,9 @@ class BaseBackend(object):
def decorator(self, func=None):
if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS({'global': self})
mocked_backend = ServerModeMockAWS({"global": self})
else:
mocked_backend = MockAWS({'global': self})
mocked_backend = MockAWS({"global": self})
if func:
return mocked_backend(func)
@ -534,9 +554,9 @@ class BaseBackend(object):
def deprecated_decorator(self, func=None):
if func:
return HttprettyMockAWS({'global': self})(func)
return HttprettyMockAWS({"global": self})(func)
else:
return HttprettyMockAWS({'global': self})
return HttprettyMockAWS({"global": self})
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
@ -544,12 +564,19 @@ class BaseBackend(object):
class ConfigQueryModel(object):
def __init__(self, backends):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends
def list_config_service_resources(self, resource_ids, resource_name, limit, next_token, backend_region=None, resource_region=None):
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference:
@ -593,7 +620,9 @@ class ConfigQueryModel(object):
"""
raise NotImplementedError()
def get_config_resource(self, resource_id, resource_name=None, backend_region=None, resource_region=None):
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
"""For AWS Config. This will query the backend for the specific resource type configuration.
This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests
@ -644,9 +673,9 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend):
def reset(self):
from moto.backends import BACKENDS
for name, backends in BACKENDS.items():
if name == "moto_api":
continue

View file

@ -40,7 +40,7 @@ def _decode_dict(d):
newkey = []
for k in key:
if isinstance(k, six.binary_type):
newkey.append(k.decode('utf-8'))
newkey.append(k.decode("utf-8"))
else:
newkey.append(k)
else:
@ -52,7 +52,7 @@ def _decode_dict(d):
newvalue = []
for v in value:
if isinstance(v, six.binary_type):
newvalue.append(v.decode('utf-8'))
newvalue.append(v.decode("utf-8"))
else:
newvalue.append(v)
else:
@ -90,7 +90,8 @@ class _TemplateEnvironmentMixin(object):
super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({})
self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape)
loader=self.loader, autoescape=self.should_autoescape
)
@property
def should_autoescape(self):
@ -104,13 +105,15 @@ class _TemplateEnvironmentMixin(object):
template_id = id(source)
if not self.contains_template(template_id):
collapsed = re.sub(
self.RIGHT_PATTERN,
">",
re.sub(self.LEFT_PATTERN, "<", source)
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
)
self.loader.update({template_id: collapsed})
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True,
lstrip_blocks=True)
self.environment = Environment(
loader=self.loader,
autoescape=self.should_autoescape,
trim_blocks=True,
lstrip_blocks=True,
)
return self.environment.get_template(template_id)
@ -119,8 +122,13 @@ class ActionAuthenticatorMixin(object):
request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT:
iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers)
if (
ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
method=self.method, path=self.path, data=self.data, headers=self.headers
)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
@ -137,10 +145,17 @@ class ActionAuthenticatorMixin(object):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode())
original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT']
response = requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(initial_no_auth_action_count).encode(),
)
original_initial_no_auth_action_count = response.json()[
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT"
]
else:
original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
original_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
@ -148,10 +163,15 @@ class ActionAuthenticatorMixin(object):
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode())
requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(original_initial_no_auth_action_count).encode(),
)
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = (
original_initial_no_auth_action_count
)
return result
functools.update_wrapper(wrapper, function)
@ -163,11 +183,13 @@ class ActionAuthenticatorMixin(object):
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1'
default_region = "us-east-1"
# to extract region, use [^.]
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com')
param_list_regex = re.compile(r'(.*)\.(\d+)\.')
access_key_regex = re.compile(r'AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]')
region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None
@classmethod
@ -176,7 +198,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def setup_class(self, request, full_url, headers):
querystring = {}
if hasattr(request, 'body'):
if hasattr(request, "body"):
# Boto
self.body = request.body
else:
@ -189,24 +211,29 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring = {}
for key, value in request.form.items():
querystring[key] = [value, ]
querystring[key] = [value]
raw_body = self.body
if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8')
self.body = self.body.decode("utf-8")
if not querystring:
querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True))
parse_qs(urlparse(full_url).query, keep_blank_values=True)
)
if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec:
if (
"json" in request.headers.get("content-type", [])
and self.aws_service_spec
):
decoded = json.loads(self.body)
target = request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target')
service, method = target.split('.')
target = request.headers.get("x-amz-target") or request.headers.get(
"X-Amz-Target"
)
service, method = target.split(".")
input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec)
flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items():
querystring[key] = [value]
elif self.body:
@ -231,17 +258,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.uri_match = None
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
if "host" not in self.headers:
self.headers["host"] = urlparse(full_url).netloc
self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url)
if match:
region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
region = request.headers['Authorization'].split(",")[
0].split("/")[2]
elif (
"Authorization" in request.headers
and "AWS4" in request.headers["Authorization"]
):
region = request.headers["Authorization"].split(",")[0].split("/")[2]
else:
region = self.default_region
return region
@ -250,16 +279,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"""
Returns the access key id used in this request as the current user id
"""
if 'Authorization' in self.headers:
match = self.access_key_regex.search(self.headers['Authorization'])
if "Authorization" in self.headers:
match = self.access_key_regex.search(self.headers["Authorization"])
if match:
return match.group(1)
if self.querystring.get('AWSAccessKeyId'):
return self.querystring.get('AWSAccessKeyId')
if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get("AWSAccessKeyId")
else:
# Should we raise an unauthorized exception instead?
return '111122223333'
return "111122223333"
def _dispatch(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -274,17 +303,22 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
-> '^/cars/.*/drivers/.*/drive$'
"""
def _convert(elem, is_last):
if not re.match('^{.*}$', elem):
return elem
name = elem.replace('{', '').replace('}', '')
if is_last:
return '(?P<%s>[^/]*)' % name
return '(?P<%s>.*)' % name
elems = uri.split('/')
def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name
elems = uri.split("/")
num_elems = len(elems)
regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]))
regexp = "^{}$".format(
"/".join(
[_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]
)
)
return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri):
@ -295,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# service response class should have 'SERVICE_NAME' class member,
# if you want to get action from method and url
if not hasattr(self, 'SERVICE_NAME'):
if not hasattr(self, "SERVICE_NAME"):
return None
service = self.SERVICE_NAME
conn = boto3.client(service, region_name=self.region)
# make cache if it does not exist yet
if not hasattr(self, 'method_urls'):
if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str))
op_names = conn._service_model.operation_names
for op_name in op_names:
op_model = conn._service_model.operation_model(op_name)
_method = op_model.http['method']
uri_regexp = self.uri_to_regexp(op_model.http['requestUri'])
_method = op_model.http["method"]
uri_regexp = self.uri_to_regexp(op_model.http["requestUri"])
self.method_urls[_method][uri_regexp] = op_model.name
regexp_and_names = self.method_urls[method]
for regexp, name in regexp_and_names.items():
@ -318,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return None
def _get_action(self):
action = self.querystring.get('Action', [""])[0]
action = self.querystring.get("Action", [""])[0]
if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get(
'x-amz-target') or self.headers.get('X-Amz-Target')
match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target")
if match:
action = match.split(".")[-1]
# get action from method and uri
@ -354,10 +387,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self._send_response(headers, response)
if not action:
return 404, headers, ''
return 404, headers, ""
raise NotImplementedError(
"The {0} action has not been implemented".format(action))
"The {0} action has not been implemented".format(action)
)
@staticmethod
def _send_response(headers, response):
@ -365,11 +399,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
headers["status"] = str(headers["status"])
return status, headers, body
def _get_param(self, param_name, if_none=None):
@ -403,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name)
if val is not None:
if val.lower() == 'true':
if val.lower() == "true":
return True
elif val.lower() == 'false':
elif val.lower() == "false":
return False
return if_none
@ -423,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if is_tracked(name) or not name.startswith(param_prefix):
continue
if len(name) > len(param_prefix) and \
not name[len(param_prefix):].startswith('.'):
if len(name) > len(param_prefix) and not name[
len(param_prefix) :
].startswith("."):
continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None
match = (
self.param_list_regex.search(name[len(param_prefix) :])
if len(name) > len(param_prefix)
else None
)
if match:
prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix)
@ -442,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if len(value_dict) > 1:
# strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()}
value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
else:
value_dict = list(value_dict.values())[0]
@ -461,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1
while True:
value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict and value_dict != '':
if not value_dict and value_dict != "":
break
values.append(value_dict)
@ -486,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
params = {}
for key, value in self.querystring.items():
if key.startswith(param_prefix):
params[camelcase_to_underscores(
key.replace(param_prefix, ""))] = value[0]
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
0
]
return params
def _get_list_prefix(self, param_prefix):
@ -520,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
new_items = {}
for key, value in self.querystring.items():
if key.startswith(index_prefix):
new_items[camelcase_to_underscores(
key.replace(index_prefix, ""))] = value[0]
new_items[
camelcase_to_underscores(key.replace(index_prefix, ""))
] = value[0]
if not new_items:
break
results.append(new_items)
param_index += 1
return results
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
results = {}
param_index = 1
while 1:
index_prefix = '{0}.{1}.'.format(param_prefix, param_index)
index_prefix = "{0}.{1}.".format(param_prefix, param_index)
k, v = None, None
for key, value in self.querystring.items():
@ -559,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
param_index = 1
while True:
key_name = 'tag.{0}._key'.format(param_index)
value_name = 'tag.{0}._value'.format(param_index)
key_name = "tag.{0}._key".format(param_index)
value_name = "tag.{0}._value".format(param_index)
try:
results[resource_type][tag[key_name]] = tag[value_name]
@ -570,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return results
def _get_object_map(self, prefix, name='Name', value='Value'):
def _get_object_map(self, prefix, name="Name", value="Value"):
"""
Given a query dict like
{
@ -598,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1
while True:
# Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name)
name_key = "{0}.{1}.{2}".format(prefix, index, name)
obj_name = self.querystring.get(name_key)
if not obj_name:
# Found all keys
break
obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format(
prefix, index, value)
value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value)
for k, v in self.querystring.items():
if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1)
@ -620,31 +663,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
@property
def request_json(self):
return 'JSON' in self.querystring.get('ContentType', [])
return "JSON" in self.querystring.get("ContentType", [])
def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action
raise DryRunClientError(
error_type="DryRunOperation", message=message)
if "true" in self.querystring.get("DryRun", ["false"]):
message = (
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
% action
)
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True
class MotoAPIResponse(BaseResponse):
def reset_response(self, request, full_url, headers):
if request.method == "POST":
from .models import moto_api_backend
moto_api_backend.reset()
return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
previous_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)})
return (
200,
{},
json.dumps(
{
"status": "ok",
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(
previous_initial_no_auth_action_count
),
}
),
)
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers):
@ -672,7 +730,8 @@ class MotoAPIResponse(BaseResponse):
def dashboard(self, request, full_url, headers):
from flask import render_template
return render_template('dashboard.html')
return render_template("dashboard.html")
class _RecursiveDictRef(object):
@ -683,7 +742,7 @@ class _RecursiveDictRef(object):
self.dic = {}
def __repr__(self):
return '{!r}'.format(self.dic)
return "{!r}".format(self.dic)
def __getattr__(self, key):
return self.dic.__getattr__(key)
@ -707,21 +766,21 @@ class AWSServiceSpec(object):
"""
def __init__(self, path):
self.path = resource_filename('botocore', path)
with io.open(self.path, 'r', encoding='utf-8') as f:
self.path = resource_filename("botocore", path)
with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f)
self.metadata = spec['metadata']
self.operations = spec['operations']
self.shapes = spec['shapes']
self.metadata = spec["metadata"]
self.operations = spec["operations"]
self.shapes = spec["shapes"]
def input_spec(self, operation):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'input' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "input" not in op:
return {}
shape = self.shapes[op['input']['shape']]
shape = self.shapes[op["input"]["shape"]]
return self._expand(shape)
def output_spec(self, operation):
@ -735,129 +794,133 @@ class AWSServiceSpec(object):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'output' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "output" not in op:
return {}
shape = self.shapes[op['output']['shape']]
shape = self.shapes[op["output"]["shape"]]
return self._expand(shape)
def _expand(self, shape):
def expand(dic, seen=None):
seen = seen or {}
if dic['type'] == 'structure':
if dic["type"] == "structure":
nodes = {}
for k, v in dic['members'].items():
for k, v in dic["members"].items():
seen_till_here = dict(seen)
if k in seen_till_here:
nodes[k] = seen_till_here[k]
continue
seen_till_here[k] = _RecursiveDictRef()
nodes[k] = expand(self.shapes[v['shape']], seen_till_here)
nodes[k] = expand(self.shapes[v["shape"]], seen_till_here)
seen_till_here[k].set_reference(k, nodes[k])
nodes['type'] = 'structure'
nodes["type"] = "structure"
return nodes
elif dic['type'] == 'list':
elif dic["type"] == "list":
seen_till_here = dict(seen)
shape = dic['member']['shape']
shape = dic["member"]["shape"]
if shape in seen_till_here:
return seen_till_here[shape]
seen_till_here[shape] = _RecursiveDictRef()
expanded = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, expanded)
return {'type': 'list', 'member': expanded}
return {"type": "list", "member": expanded}
elif dic['type'] == 'map':
elif dic["type"] == "map":
seen_till_here = dict(seen)
node = {'type': 'map'}
node = {"type": "map"}
if 'shape' in dic['key']:
shape = dic['key']['shape']
if "shape" in dic["key"]:
shape = dic["key"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['key'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['key'])
node["key"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["key"])
else:
node['key'] = dic['key']['type']
node["key"] = dic["key"]["type"]
if 'shape' in dic['value']:
shape = dic['value']['shape']
if "shape" in dic["value"]:
shape = dic["value"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['value'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['value'])
node["value"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["value"])
else:
node['value'] = dic['value']['type']
node["value"] = dic["value"]["type"]
return node
else:
return {'type': dic['type']}
return {"type": dic["type"]}
return expand(shape)
def to_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return 'true' if value else 'false'
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return "true" if value else "false"
elif vtype == "integer":
return str(value)
elif vtype == 'float':
elif vtype == "float":
return str(value)
elif vtype == 'double':
elif vtype == "double":
return str(value)
elif vtype == 'timestamp':
return datetime.datetime.utcfromtimestamp(
value).replace(tzinfo=pytz.utc).isoformat()
elif vtype == 'string':
elif vtype == "timestamp":
return (
datetime.datetime.utcfromtimestamp(value)
.replace(tzinfo=pytz.utc)
.isoformat()
)
elif vtype == "string":
return str(value)
elif value is None:
return 'null'
return "null"
else:
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return True if value == 'true' else False
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return True if value == "true" else False
elif vtype == "integer":
return int(value)
elif vtype == 'float':
elif vtype == "float":
return float(value)
elif vtype == 'double':
elif vtype == "double":
return float(value)
elif vtype == 'timestamp':
elif vtype == "timestamp":
return value
elif vtype == 'string':
elif vtype == "string":
return value
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec):
"""Convert a JSON request body into query params."""
if len(spec) == 1 and 'type' in spec:
if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)}
flat = {}
for key, value in dict_body.items():
node_type = spec[key]['type']
if node_type == 'list':
node_type = spec[key]["type"]
if node_type == "list":
for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx)
flat.update(flatten_json_request_body(
pref, v, spec[key]['member']))
elif node_type == 'map':
pref = key + ".member." + str(idx)
flat.update(flatten_json_request_body(pref, v, spec[key]["member"]))
elif node_type == "map":
for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx)
flat.update(flatten_json_request_body(
pref + '.key', k, spec[key]['key']))
flat.update(flatten_json_request_body(
pref + '.value', v, spec[key]['value']))
pref = key + ".entry." + str(idx)
flat.update(
flatten_json_request_body(pref + ".key", k, spec[key]["key"])
)
flat.update(
flatten_json_request_body(pref + ".value", v, spec[key]["value"])
)
else:
flat.update(flatten_json_request_body(key, value, spec[key]))
if prefix:
prefix = prefix + '.'
prefix = prefix + "."
return dict((prefix + k, v) for k, v in flat.items())
@ -880,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
od = OrderedDict()
for k, v in value.items():
if k.startswith('@'):
if k.startswith("@"):
continue
if k not in spec:
# this can happen when with an older version of
# botocore for which the node in XML template is not
# defined in service spec.
log.warning(
'Field %s is not defined by the botocore version in use', k)
log.warning("Field %s is not defined by the botocore version in use", k)
continue
if spec[k]['type'] == 'list':
if spec[k]["type"] == "list":
if v is None:
od[k] = []
elif len(spec[k]['member']) == 1:
if isinstance(v['member'], list):
od[k] = transform(v['member'], spec[k]['member'])
elif len(spec[k]["member"]) == 1:
if isinstance(v["member"], list):
od[k] = transform(v["member"], spec[k]["member"])
else:
od[k] = [transform(v['member'], spec[k]['member'])]
elif isinstance(v['member'], list):
od[k] = [transform(o, spec[k]['member'])
for o in v['member']]
elif isinstance(v['member'], OrderedDict):
od[k] = [transform(v['member'], spec[k]['member'])]
od[k] = [transform(v["member"], spec[k]["member"])]
elif isinstance(v["member"], list):
od[k] = [transform(o, spec[k]["member"]) for o in v["member"]]
elif isinstance(v["member"], OrderedDict):
od[k] = [transform(v["member"], spec[k]["member"])]
else:
raise ValueError('Malformatted input')
elif spec[k]['type'] == 'map':
raise ValueError("Malformatted input")
elif spec[k]["type"] == "map":
if v is None:
od[k] = {}
else:
items = ([v['entry']] if not isinstance(v['entry'], list) else
v['entry'])
items = (
[v["entry"]] if not isinstance(v["entry"], list) else v["entry"]
)
for item in items:
key = from_str(item['key'], spec[k]['key'])
val = from_str(item['value'], spec[k]['value'])
key = from_str(item["key"], spec[k]["key"])
val = from_str(item["value"], spec[k]["value"])
if k not in od:
od[k] = {}
od[k][key] = val
@ -928,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
dic = xmltodict.parse(xml)
output_spec = service_spec.output_spec(operation)
try:
for k in (result_node or (operation + 'Response', operation + 'Result')):
for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k]
except KeyError:
return None

View file

@ -1,15 +1,13 @@
from __future__ import unicode_literals
from .responses import MotoAPIResponse
url_bases = [
"https?://motoapi.amazonaws.com"
]
url_bases = ["https?://motoapi.amazonaws.com"]
response_instance = MotoAPIResponse()
url_paths = {
'{0}/moto-api/$': response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response,
'{0}/moto-api/reset-auth': response_instance.reset_auth_response,
"{0}/moto-api/$": response_instance.dashboard,
"{0}/moto-api/data.json": response_instance.model_data,
"{0}/moto-api/reset": response_instance.reset_response,
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
}

View file

@ -15,9 +15,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument):
''' Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute'''
result = ''
""" Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute"""
result = ""
prev_char_title = True
if not argument:
return argument
@ -41,18 +41,18 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument):
''' Converts a camelcase param like the_new_attribute to the equivalent
""" Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function '''
result = ''
NOT capitalized by this function """
result = ""
previous_was_underscore = False
for char in argument:
if char != '_':
if char != "_":
if previous_was_underscore:
result += char.upper()
else:
result += char
previous_was_underscore = char == '_'
previous_was_underscore = char == "_"
return result
@ -69,12 +69,18 @@ def method_names_from_class(clazz):
def get_random_hex(length=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f']
return ''.join(six.text_type(random.choice(chars)) for x in range(length))
chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return "".join(six.text_type(random.choice(chars)) for x in range(length))
def get_random_message_id():
return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12))
return "{0}-{1}-{2}-{3}-{4}".format(
get_random_hex(8),
get_random_hex(4),
get_random_hex(4),
get_random_hex(4),
get_random_hex(12),
)
def convert_regex_to_flask_path(url_path):
@ -97,7 +103,6 @@ def convert_regex_to_flask_path(url_path):
class convert_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -114,13 +119,12 @@ class convert_httpretty_response(object):
def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers)
status, headers, response = result
if 'server' not in headers:
if "server" not in headers:
headers["server"] = "amazon.com"
return status, headers, response
class convert_flask_to_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -145,13 +149,12 @@ class convert_flask_to_httpretty_response(object):
status, headers, content = 200, {}, result
response = Response(response=content, status=status, headers=headers)
if request.method == "HEAD" and 'content-length' in headers:
response.headers['Content-Length'] = headers['content-length']
if request.method == "HEAD" and "content-length" in headers:
response.headers["Content-Length"] = headers["content-length"]
return response
class convert_flask_to_responses_response(object):
def __init__(self, callback):
self.callback = callback
@ -176,14 +179,14 @@ class convert_flask_to_responses_response(object):
def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT'
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime):
@ -212,16 +215,16 @@ def gen_amz_crc32(response, headerdict=None):
crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc})
headerdict.update({"x-amz-crc32": crc})
return crc
def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id})
headerdict.update({"x-amzn-requestid": req_id})
return req_id
@ -239,13 +242,13 @@ def amz_crc32(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
headers["status"] = str(headers["status"])
try:
# Doesnt work on python2 for some odd unicode strings
@ -271,7 +274,7 @@ def amzn_request_id(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
@ -280,7 +283,7 @@ def amzn_request_id(f):
# Update request ID in XML
try:
body = re.sub(r'(?<=<RequestId>).*(?=<\/RequestId>)', request_id, body)
body = re.sub(r"(?<=<RequestId>).*(?=<\/RequestId>)", request_id, body)
except Exception: # Will just ignore if it cant work on bytes (which are str's on python2)
pass
@ -293,7 +296,7 @@ def path_url(url):
parsed_url = urlparse(url)
path = parsed_url.path
if not path:
path = '/'
path = "/"
if parsed_url.query:
path = path + '?' + parsed_url.query
path = path + "?" + parsed_url.query
return path