Back to Black

This commit is contained in:
Matěj Cepl 2020-10-06 08:46:05 +02:00
commit 5697ff87a8
112 changed files with 1803 additions and 977 deletions

View file

@ -3,35 +3,64 @@ import json
# Taken from free tier list when creating an instance
instances = [
'ami-760aaa0f', 'ami-bb9a6bc2', 'ami-35e92e4c', 'ami-785db401', 'ami-b7e93bce', 'ami-dca37ea5', 'ami-999844e0',
'ami-9b32e8e2', 'ami-f8e54081', 'ami-bceb39c5', 'ami-03cf127a', 'ami-1ecc1e67', 'ami-c2ff2dbb', 'ami-12c6146b',
'ami-d1cb19a8', 'ami-61db0918', 'ami-56ec3e2f', 'ami-84ee3cfd', 'ami-86ee3cff', 'ami-f0e83a89', 'ami-1f12c066',
'ami-afee3cd6', 'ami-1812c061', 'ami-77ed3f0e', 'ami-3bf32142', 'ami-6ef02217', 'ami-f4cf1d8d', 'ami-3df32144',
'ami-c6f321bf', 'ami-24f3215d', 'ami-fa7cdd89', 'ami-1e749f67', 'ami-a9cc1ed0', 'ami-8104a4f8'
"ami-760aaa0f",
"ami-bb9a6bc2",
"ami-35e92e4c",
"ami-785db401",
"ami-b7e93bce",
"ami-dca37ea5",
"ami-999844e0",
"ami-9b32e8e2",
"ami-f8e54081",
"ami-bceb39c5",
"ami-03cf127a",
"ami-1ecc1e67",
"ami-c2ff2dbb",
"ami-12c6146b",
"ami-d1cb19a8",
"ami-61db0918",
"ami-56ec3e2f",
"ami-84ee3cfd",
"ami-86ee3cff",
"ami-f0e83a89",
"ami-1f12c066",
"ami-afee3cd6",
"ami-1812c061",
"ami-77ed3f0e",
"ami-3bf32142",
"ami-6ef02217",
"ami-f4cf1d8d",
"ami-3df32144",
"ami-c6f321bf",
"ami-24f3215d",
"ami-fa7cdd89",
"ami-1e749f67",
"ami-a9cc1ed0",
"ami-8104a4f8",
]
client = boto3.client('ec2', region_name='eu-west-1')
client = boto3.client("ec2", region_name="eu-west-1")
test = client.describe_images(ImageIds=instances)
result = []
for image in test['Images']:
for image in test["Images"]:
try:
tmp = {
'ami_id': image['ImageId'],
'name': image['Name'],
'description': image['Description'],
'owner_id': image['OwnerId'],
'public': image['Public'],
'virtualization_type': image['VirtualizationType'],
'architecture': image['Architecture'],
'state': image['State'],
'platform': image.get('Platform'),
'image_type': image['ImageType'],
'hypervisor': image['Hypervisor'],
'root_device_name': image['RootDeviceName'],
'root_device_type': image['RootDeviceType'],
'sriov': image.get('SriovNetSupport', 'simple')
"ami_id": image["ImageId"],
"name": image["Name"],
"description": image["Description"],
"owner_id": image["OwnerId"],
"public": image["Public"],
"virtualization_type": image["VirtualizationType"],
"architecture": image["Architecture"],
"state": image["State"],
"platform": image.get("Platform"),
"image_type": image["ImageType"],
"hypervisor": image["Hypervisor"],
"root_device_name": image["RootDeviceName"],
"root_device_type": image["RootDeviceType"],
"sriov": image.get("SriovNetSupport", "simple"),
}
result.append(tmp)
except Exception as err:

View file

@ -7,12 +7,18 @@ import boto3
script_dir = os.path.dirname(os.path.abspath(__file__))
alternative_service_names = {'lambda': 'awslambda', 'dynamodb': 'dynamodb2'}
alternative_service_names = {"lambda": "awslambda", "dynamodb": "dynamodb2"}
def get_moto_implementation(service_name):
service_name = service_name.replace("-", "") if "-" in service_name else service_name
alt_service_name = alternative_service_names[service_name] if service_name in alternative_service_names else service_name
service_name = (
service_name.replace("-", "") if "-" in service_name else service_name
)
alt_service_name = (
alternative_service_names[service_name]
if service_name in alternative_service_names
else service_name
)
if hasattr(moto, "mock_{}".format(alt_service_name)):
mock = getattr(moto, "mock_{}".format(alt_service_name))
elif hasattr(moto, "mock_{}".format(service_name)):
@ -31,11 +37,13 @@ def calculate_implementation_coverage():
coverage = {}
for service_name in service_names:
moto_client = get_moto_implementation(service_name)
real_client = boto3.client(service_name, region_name='us-east-1')
real_client = boto3.client(service_name, region_name="us-east-1")
implemented = []
not_implemented = []
operation_names = [xform_name(op) for op in real_client.meta.service_model.operation_names]
operation_names = [
xform_name(op) for op in real_client.meta.service_model.operation_names
]
for op in operation_names:
if moto_client and op in dir(moto_client):
implemented.append(op)
@ -43,20 +51,22 @@ def calculate_implementation_coverage():
not_implemented.append(op)
coverage[service_name] = {
'implemented': implemented,
'not_implemented': not_implemented,
"implemented": implemented,
"not_implemented": not_implemented,
}
return coverage
def print_implementation_coverage(coverage):
for service_name in sorted(coverage):
implemented = coverage.get(service_name)['implemented']
not_implemented = coverage.get(service_name)['not_implemented']
implemented = coverage.get(service_name)["implemented"]
not_implemented = coverage.get(service_name)["not_implemented"]
operations = sorted(implemented + not_implemented)
if implemented and not_implemented:
percentage_implemented = int(100.0 * len(implemented) / (len(implemented) + len(not_implemented)))
percentage_implemented = int(
100.0 * len(implemented) / (len(implemented) + len(not_implemented))
)
elif implemented:
percentage_implemented = 100
else:
@ -84,12 +94,14 @@ def write_implementation_coverage_to_file(coverage):
print("Writing to {}".format(implementation_coverage_file))
with open(implementation_coverage_file, "w+") as file:
for service_name in sorted(coverage):
implemented = coverage.get(service_name)['implemented']
not_implemented = coverage.get(service_name)['not_implemented']
implemented = coverage.get(service_name)["implemented"]
not_implemented = coverage.get(service_name)["not_implemented"]
operations = sorted(implemented + not_implemented)
if implemented and not_implemented:
percentage_implemented = int(100.0 * len(implemented) / (len(implemented) + len(not_implemented)))
percentage_implemented = int(
100.0 * len(implemented) / (len(implemented) + len(not_implemented))
)
elif implemented:
percentage_implemented = 100
else:
@ -98,7 +110,9 @@ def write_implementation_coverage_to_file(coverage):
file.write("\n")
file.write("## {}\n".format(service_name))
file.write("<details>\n")
file.write("<summary>{}% implemented</summary>\n\n".format(percentage_implemented))
file.write(
"<summary>{}% implemented</summary>\n\n".format(percentage_implemented)
)
for op in operations:
if op in implemented:
file.write("- [X] {}\n".format(op))
@ -107,7 +121,7 @@ def write_implementation_coverage_to_file(coverage):
file.write("</details>\n")
if __name__ == '__main__':
if __name__ == "__main__":
cov = calculate_implementation_coverage()
write_implementation_coverage_to_file(cov)
print_implementation_coverage(cov)

View file

@ -17,9 +17,7 @@ from lxml import etree
import click
import jinja2
from prompt_toolkit import (
prompt
)
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.shortcuts import print_formatted_text
@ -29,35 +27,35 @@ import boto3
from moto.core.responses import BaseResponse
from moto.core import BaseBackend
from implementation_coverage import (
get_moto_implementation
)
from implementation_coverage import get_moto_implementation
from inflection import singularize
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), './template')
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "./template")
INPUT_IGNORED_IN_BACKEND = ['Marker', 'PageSize']
OUTPUT_IGNORED_IN_BACKEND = ['NextMarker']
INPUT_IGNORED_IN_BACKEND = ["Marker", "PageSize"]
OUTPUT_IGNORED_IN_BACKEND = ["NextMarker"]
def print_progress(title, body, color):
click.secho(u'\t{}\t'.format(title), fg=color, nl=False)
click.secho(u"\t{}\t".format(title), fg=color, nl=False)
click.echo(body)
def select_service_and_operation():
service_names = Session().get_available_services()
service_completer = WordCompleter(service_names)
service_name = prompt(u'Select service: ', completer=service_completer)
service_name = prompt(u"Select service: ", completer=service_completer)
if service_name not in service_names:
click.secho(u'{} is not valid service'.format(service_name), fg='red')
click.secho(u"{} is not valid service".format(service_name), fg="red")
raise click.Abort()
moto_client = get_moto_implementation(service_name)
real_client = boto3.client(service_name, region_name='us-east-1')
real_client = boto3.client(service_name, region_name="us-east-1")
implemented = []
not_implemented = []
operation_names = [xform_name(op) for op in real_client.meta.service_model.operation_names]
operation_names = [
xform_name(op) for op in real_client.meta.service_model.operation_names
]
for op in operation_names:
if moto_client and op in dir(moto_client):
implemented.append(op)
@ -65,133 +63,148 @@ def select_service_and_operation():
not_implemented.append(op)
operation_completer = WordCompleter(operation_names)
click.echo('==Current Implementation Status==')
click.echo("==Current Implementation Status==")
for operation_name in operation_names:
check = 'X' if operation_name in implemented else ' '
click.secho('[{}] {}'.format(check, operation_name))
click.echo('=================================')
operation_name = prompt(u'Select Operation: ', completer=operation_completer)
check = "X" if operation_name in implemented else " "
click.secho("[{}] {}".format(check, operation_name))
click.echo("=================================")
operation_name = prompt(u"Select Operation: ", completer=operation_completer)
if operation_name not in operation_names:
click.secho('{} is not valid operation'.format(operation_name), fg='red')
click.secho("{} is not valid operation".format(operation_name), fg="red")
raise click.Abort()
if operation_name in implemented:
click.secho('{} is already implemented'.format(operation_name), fg='red')
click.secho("{} is already implemented".format(operation_name), fg="red")
raise click.Abort()
return service_name, operation_name
def get_escaped_service(service):
return service.replace('-', '')
return service.replace("-", "")
def get_lib_dir(service):
return os.path.join('moto', get_escaped_service(service))
return os.path.join("moto", get_escaped_service(service))
def get_test_dir(service):
return os.path.join('tests', 'test_{}'.format(get_escaped_service(service)))
return os.path.join("tests", "test_{}".format(get_escaped_service(service)))
def render_template(tmpl_dir, tmpl_filename, context, service, alt_filename=None):
is_test = True if 'test' in tmpl_dir else False
rendered = jinja2.Environment(
loader=jinja2.FileSystemLoader(tmpl_dir)
).get_template(tmpl_filename).render(context)
is_test = True if "test" in tmpl_dir else False
rendered = (
jinja2.Environment(loader=jinja2.FileSystemLoader(tmpl_dir))
.get_template(tmpl_filename)
.render(context)
)
dirname = get_test_dir(service) if is_test else get_lib_dir(service)
filename = alt_filename or os.path.splitext(tmpl_filename)[0]
filepath = os.path.join(dirname, filename)
if os.path.exists(filepath):
print_progress('skip creating', filepath, 'yellow')
print_progress("skip creating", filepath, "yellow")
else:
print_progress('creating', filepath, 'green')
with open(filepath, 'w') as f:
print_progress("creating", filepath, "green")
with open(filepath, "w") as f:
f.write(rendered)
def append_mock_to_init_py(service):
path = os.path.join(os.path.dirname(__file__), '..', 'moto', '__init__.py')
path = os.path.join(os.path.dirname(__file__), "..", "moto", "__init__.py")
with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()]
lines = [_.replace("\n", "") for _ in f.readlines()]
if any(_ for _ in lines if re.match('^mock_{}.*lazy_load(.*)$'.format(service), _)):
if any(_ for _ in lines if re.match("^mock_{}.*lazy_load(.*)$".format(service), _)):
return
filtered_lines = [_ for _ in lines if re.match('^mock_.*lazy_load(.*)$', _)]
filtered_lines = [_ for _ in lines if re.match("^mock_.*lazy_load(.*)$", _)]
last_import_line_index = lines.index(filtered_lines[-1])
new_line = 'mock_{} = lazy_load(".{}", "mock_{}")'.format(get_escaped_service(service), get_escaped_service(service), get_escaped_service(service))
new_line = 'mock_{} = lazy_load(".{}", "mock_{}")'.format(
get_escaped_service(service),
get_escaped_service(service),
get_escaped_service(service),
)
lines.insert(last_import_line_index + 1, new_line)
body = '\n'.join(lines) + '\n'
with open(path, 'w') as f:
body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
def append_mock_dict_to_backends_py(service):
path = os.path.join(os.path.dirname(__file__), '..', 'moto', 'backends.py')
path = os.path.join(os.path.dirname(__file__), "..", "moto", "backends.py")
with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()]
lines = [_.replace("\n", "") for _ in f.readlines()]
if any(_ for _ in lines if re.match(".*\"{}\": {}_backends.*".format(service, service), _)):
if any(
_
for _ in lines
if re.match('.*"{}": {}_backends.*'.format(service, service), _)
):
return
filtered_lines = [_ for _ in lines if re.match(".*\".*\":.*_backends.*", _)]
filtered_lines = [_ for _ in lines if re.match('.*".*":.*_backends.*', _)]
last_elem_line_index = lines.index(filtered_lines[-1])
new_line = " \"{}\": (\"{}\", \"{}_backends\"),".format(service, get_escaped_service(service), get_escaped_service(service))
new_line = ' "{}": ("{}", "{}_backends"),'.format(
service, get_escaped_service(service), get_escaped_service(service)
)
prev_line = lines[last_elem_line_index]
if not prev_line.endswith('{') and not prev_line.endswith(','):
lines[last_elem_line_index] += ','
if not prev_line.endswith("{") and not prev_line.endswith(","):
lines[last_elem_line_index] += ","
lines.insert(last_elem_line_index + 1, new_line)
body = '\n'.join(lines) + '\n'
with open(path, 'w') as f:
body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
def initialize_service(service, operation, api_protocol):
"""create lib and test dirs if not exist
"""
"""create lib and test dirs if not exist"""
lib_dir = get_lib_dir(service)
test_dir = get_test_dir(service)
print_progress('Initializing service', service, 'green')
print_progress("Initializing service", service, "green")
client = boto3.client(service)
service_class = client.__class__.__name__
endpoint_prefix = client._service_model.endpoint_prefix
tmpl_context = {
'service': service,
'service_class': service_class,
'endpoint_prefix': endpoint_prefix,
'api_protocol': api_protocol,
'escaped_service': get_escaped_service(service)
"service": service,
"service_class": service_class,
"endpoint_prefix": endpoint_prefix,
"api_protocol": api_protocol,
"escaped_service": get_escaped_service(service),
}
# initialize service directory
if os.path.exists(lib_dir):
print_progress('skip creating', lib_dir, 'yellow')
print_progress("skip creating", lib_dir, "yellow")
else:
print_progress('creating', lib_dir, 'green')
print_progress("creating", lib_dir, "green")
os.makedirs(lib_dir)
tmpl_dir = os.path.join(TEMPLATE_DIR, 'lib')
tmpl_dir = os.path.join(TEMPLATE_DIR, "lib")
for tmpl_filename in os.listdir(tmpl_dir):
render_template(
tmpl_dir, tmpl_filename, tmpl_context, service
)
render_template(tmpl_dir, tmpl_filename, tmpl_context, service)
# initialize test directory
if os.path.exists(test_dir):
print_progress('skip creating', test_dir, 'yellow')
print_progress("skip creating", test_dir, "yellow")
else:
print_progress('creating', test_dir, 'green')
print_progress("creating", test_dir, "green")
os.makedirs(test_dir)
tmpl_dir = os.path.join(TEMPLATE_DIR, 'test')
tmpl_dir = os.path.join(TEMPLATE_DIR, "test")
for tmpl_filename in os.listdir(tmpl_dir):
alt_filename = 'test_{}.py'.format(get_escaped_service(service)) if tmpl_filename == 'test_service.py.j2' else None
render_template(
tmpl_dir, tmpl_filename, tmpl_context, service, alt_filename
alt_filename = (
"test_{}.py".format(get_escaped_service(service))
if tmpl_filename == "test_service.py.j2"
else None
)
render_template(tmpl_dir, tmpl_filename, tmpl_context, service, alt_filename)
# append mock to init files
append_mock_to_init_py(service)
@ -199,22 +212,24 @@ def initialize_service(service, operation, api_protocol):
def to_upper_camel_case(s):
return ''.join([_.title() for _ in s.split('_')])
return "".join([_.title() for _ in s.split("_")])
def to_lower_camel_case(s):
words = s.split('_')
return ''.join(words[:1] + [_.title() for _ in words[1:]])
words = s.split("_")
return "".join(words[:1] + [_.title() for _ in words[1:]])
def to_snake_case(s):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', s)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def get_operation_name_in_keys(operation_name, operation_keys):
index = [_.lower() for _ in operation_keys].index(operation_name.lower())
return operation_keys[index]
def get_function_in_responses(service, operation, protocol):
"""refers to definition of API in botocore, and autogenerates function
You can see example of elbv2 from link below.
@ -224,44 +239,56 @@ def get_function_in_responses(service, operation, protocol):
aws_operation_name = get_operation_name_in_keys(
to_upper_camel_case(operation),
list(client._service_model._service_description['operations'].keys())
list(client._service_model._service_description["operations"].keys()),
)
op_model = client._service_model.operation_model(aws_operation_name)
if not hasattr(op_model.output_shape, 'members'):
if not hasattr(op_model.output_shape, "members"):
outputs = {}
else:
outputs = op_model.output_shape.members
inputs = op_model.input_shape.members
input_names = [to_snake_case(_) for _ in inputs.keys() if _ not in INPUT_IGNORED_IN_BACKEND]
output_names = [to_snake_case(_) for _ in outputs.keys() if _ not in OUTPUT_IGNORED_IN_BACKEND]
body = '\ndef {}(self):\n'.format(operation)
input_names = [
to_snake_case(_) for _ in inputs.keys() if _ not in INPUT_IGNORED_IN_BACKEND
]
output_names = [
to_snake_case(_) for _ in outputs.keys() if _ not in OUTPUT_IGNORED_IN_BACKEND
]
body = "\ndef {}(self):\n".format(operation)
for input_name, input_type in inputs.items():
type_name = input_type.type_name
if type_name == 'integer':
if type_name == "integer":
arg_line_tmpl = ' {} = self._get_int_param("{}")\n'
elif type_name == 'list':
elif type_name == "list":
arg_line_tmpl = ' {} = self._get_list_prefix("{}.member")\n'
else:
arg_line_tmpl = ' {} = self._get_param("{}")\n'
body += arg_line_tmpl.format(to_snake_case(input_name), input_name)
if output_names:
body += ' {} = self.{}_backend.{}(\n'.format(', '.join(output_names), get_escaped_service(service), operation)
else:
body += ' self.{}_backend.{}(\n'.format(get_escaped_service(service), operation)
for input_name in input_names:
body += ' {}={},\n'.format(input_name, input_name)
body += ' )\n'
if protocol == 'query':
body += ' template = self.response_template({}_TEMPLATE)\n'.format(operation.upper())
body += ' return template.render({})\n'.format(
', '.join(['{}={}'.format(_, _) for _ in output_names])
body += " {} = self.{}_backend.{}(\n".format(
", ".join(output_names), get_escaped_service(service), operation
)
else:
body += " self.{}_backend.{}(\n".format(
get_escaped_service(service), operation
)
for input_name in input_names:
body += " {}={},\n".format(input_name, input_name)
body += " )\n"
if protocol == "query":
body += " template = self.response_template({}_TEMPLATE)\n".format(
operation.upper()
)
body += " return template.render({})\n".format(
", ".join(["{}={}".format(_, _) for _ in output_names])
)
elif protocol in ["json", "rest-json"]:
body += " # TODO: adjust response\n"
body += " return json.dumps(dict({}))\n".format(
", ".join(["{}={}".format(to_lower_camel_case(_), _) for _ in output_names])
)
elif protocol in ['json', 'rest-json']:
body += ' # TODO: adjust response\n'
body += ' return json.dumps(dict({}))\n'.format(', '.join(['{}={}'.format(to_lower_camel_case(_), _) for _ in output_names]))
return body
@ -273,44 +300,55 @@ def get_function_in_models(service, operation):
client = boto3.client(service)
aws_operation_name = get_operation_name_in_keys(
to_upper_camel_case(operation),
list(client._service_model._service_description['operations'].keys())
list(client._service_model._service_description["operations"].keys()),
)
op_model = client._service_model.operation_model(aws_operation_name)
inputs = op_model.input_shape.members
if not hasattr(op_model.output_shape, 'members'):
if not hasattr(op_model.output_shape, "members"):
outputs = {}
else:
outputs = op_model.output_shape.members
input_names = [to_snake_case(_) for _ in inputs.keys() if _ not in INPUT_IGNORED_IN_BACKEND]
output_names = [to_snake_case(_) for _ in outputs.keys() if _ not in OUTPUT_IGNORED_IN_BACKEND]
input_names = [
to_snake_case(_) for _ in inputs.keys() if _ not in INPUT_IGNORED_IN_BACKEND
]
output_names = [
to_snake_case(_) for _ in outputs.keys() if _ not in OUTPUT_IGNORED_IN_BACKEND
]
if input_names:
body = 'def {}(self, {}):\n'.format(operation, ', '.join(input_names))
body = "def {}(self, {}):\n".format(operation, ", ".join(input_names))
else:
body = 'def {}(self)\n'
body += ' # implement here\n'
body += ' return {}\n\n'.format(', '.join(output_names))
body = "def {}(self)\n"
body += " # implement here\n"
body += " return {}\n\n".format(", ".join(output_names))
return body
def _get_subtree(name, shape, replace_list, name_prefix=[]):
class_name = shape.__class__.__name__
if class_name in ('StringShape', 'Shape'):
if class_name in ("StringShape", "Shape"):
t = etree.Element(name)
if name_prefix:
t.text = '{{ %s.%s }}' % (name_prefix[-1], to_snake_case(name))
t.text = "{{ %s.%s }}" % (name_prefix[-1], to_snake_case(name))
else:
t.text = '{{ %s }}' % to_snake_case(name)
t.text = "{{ %s }}" % to_snake_case(name)
return t
elif class_name in ('ListShape', ):
elif class_name in ("ListShape",):
replace_list.append((name, name_prefix))
t = etree.Element(name)
t_member = etree.Element('member')
t_member = etree.Element("member")
t.append(t_member)
for nested_name, nested_shape in shape.member.members.items():
t_member.append(_get_subtree(nested_name, nested_shape, replace_list, name_prefix + [singularize(name.lower())]))
t_member.append(
_get_subtree(
nested_name,
nested_shape,
replace_list,
name_prefix + [singularize(name.lower())],
)
)
return t
raise ValueError('Not supported Shape')
raise ValueError("Not supported Shape")
def get_response_query_template(service, operation):
@ -323,22 +361,22 @@ def get_response_query_template(service, operation):
client = boto3.client(service)
aws_operation_name = get_operation_name_in_keys(
to_upper_camel_case(operation),
list(client._service_model._service_description['operations'].keys())
list(client._service_model._service_description["operations"].keys()),
)
op_model = client._service_model.operation_model(aws_operation_name)
result_wrapper = op_model.output_shape.serialization['resultWrapper']
response_wrapper = result_wrapper.replace('Result', 'Response')
result_wrapper = op_model.output_shape.serialization["resultWrapper"]
response_wrapper = result_wrapper.replace("Result", "Response")
metadata = op_model.metadata
xml_namespace = metadata['xmlNamespace']
xml_namespace = metadata["xmlNamespace"]
# build xml tree
t_root = etree.Element(response_wrapper, xmlns=xml_namespace)
t_root = etree.Element(response_wrapper, xmlns=xml_namespace)
# build metadata
t_metadata = etree.Element('ResponseMetadata')
t_request_id = etree.Element('RequestId')
t_request_id.text = '1549581b-12b7-11e3-895e-1334aEXAMPLE'
t_metadata = etree.Element("ResponseMetadata")
t_request_id = etree.Element("RequestId")
t_request_id.text = "1549581b-12b7-11e3-895e-1334aEXAMPLE"
t_metadata.append(t_request_id)
t_root.append(t_metadata)
@ -349,68 +387,73 @@ def get_response_query_template(service, operation):
for output_name, output_shape in outputs.items():
t_result.append(_get_subtree(output_name, output_shape, replace_list))
t_root.append(t_result)
xml_body = etree.tostring(t_root, pretty_print=True).decode('utf-8')
xml_body = etree.tostring(t_root, pretty_print=True).decode("utf-8")
xml_body_lines = xml_body.splitlines()
for replace in replace_list:
name = replace[0]
prefix = replace[1]
singular_name = singularize(name)
start_tag = '<%s>' % name
iter_name = '{}.{}'.format(prefix[-1], name.lower())if prefix else name.lower()
loop_start = '{%% for %s in %s %%}' % (singular_name.lower(), iter_name)
end_tag = '</%s>' % name
loop_end = '{{ endfor }}'
start_tag = "<%s>" % name
iter_name = "{}.{}".format(prefix[-1], name.lower()) if prefix else name.lower()
loop_start = "{%% for %s in %s %%}" % (singular_name.lower(), iter_name)
end_tag = "</%s>" % name
loop_end = "{{ endfor }}"
start_tag_indexes = [i for i, l in enumerate(xml_body_lines) if start_tag in l]
if len(start_tag_indexes) != 1:
raise Exception('tag %s not found in response body' % start_tag)
raise Exception("tag %s not found in response body" % start_tag)
start_tag_index = start_tag_indexes[0]
xml_body_lines.insert(start_tag_index + 1, loop_start)
end_tag_indexes = [i for i, l in enumerate(xml_body_lines) if end_tag in l]
if len(end_tag_indexes) != 1:
raise Exception('tag %s not found in response body' % end_tag)
raise Exception("tag %s not found in response body" % end_tag)
end_tag_index = end_tag_indexes[0]
xml_body_lines.insert(end_tag_index, loop_end)
xml_body = '\n'.join(xml_body_lines)
xml_body = "\n".join(xml_body_lines)
body = '\n{}_TEMPLATE = """{}"""'.format(operation.upper(), xml_body)
return body
def insert_code_to_class(path, base_class, new_code):
with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()]
mod_path = os.path.splitext(path)[0].replace('/', '.')
lines = [_.replace("\n", "") for _ in f.readlines()]
mod_path = os.path.splitext(path)[0].replace("/", ".")
mod = importlib.import_module(mod_path)
clsmembers = inspect.getmembers(mod, inspect.isclass)
_response_cls = [_[1] for _ in clsmembers if issubclass(_[1], base_class) and _[1] != base_class]
_response_cls = [
_[1] for _ in clsmembers if issubclass(_[1], base_class) and _[1] != base_class
]
if len(_response_cls) != 1:
raise Exception('unknown error, number of clsmembers is not 1')
raise Exception("unknown error, number of clsmembers is not 1")
response_cls = _response_cls[0]
code_lines, line_no = inspect.getsourcelines(response_cls)
end_line_no = line_no + len(code_lines)
func_lines = [' ' * 4 + _ for _ in new_code.splitlines()]
func_lines = [" " * 4 + _ for _ in new_code.splitlines()]
lines = lines[:end_line_no] + func_lines + lines[end_line_no:]
body = '\n'.join(lines) + '\n'
with open(path, 'w') as f:
body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
def insert_url(service, operation, api_protocol):
client = boto3.client(service)
service_class = client.__class__.__name__
aws_operation_name = get_operation_name_in_keys(
to_upper_camel_case(operation),
list(client._service_model._service_description['operations'].keys())
list(client._service_model._service_description["operations"].keys()),
)
uri = client._service_model.operation_model(aws_operation_name).http['requestUri']
uri = client._service_model.operation_model(aws_operation_name).http["requestUri"]
path = os.path.join(os.path.dirname(__file__), '..', 'moto', get_escaped_service(service), 'urls.py')
path = os.path.join(
os.path.dirname(__file__), "..", "moto", get_escaped_service(service), "urls.py"
)
with open(path) as f:
lines = [_.replace('\n', '') for _ in f.readlines()]
lines = [_.replace("\n", "") for _ in f.readlines()]
if any(_ for _ in lines if re.match(uri, _)):
return
@ -418,50 +461,49 @@ def insert_url(service, operation, api_protocol):
url_paths_found = False
last_elem_line_index = -1
for i, line in enumerate(lines):
if line.startswith('url_paths'):
if line.startswith("url_paths"):
url_paths_found = True
if url_paths_found and line.startswith('}'):
if url_paths_found and line.startswith("}"):
last_elem_line_index = i - 1
prev_line = lines[last_elem_line_index]
if not prev_line.endswith('{') and not prev_line.endswith(','):
lines[last_elem_line_index] += ','
if not prev_line.endswith("{") and not prev_line.endswith(","):
lines[last_elem_line_index] += ","
# generate url pattern
if api_protocol == 'rest-json':
if api_protocol == "rest-json":
new_line = " '{0}/.*$': response.dispatch,"
else:
new_line = " '{0}%s$': %sResponse.dispatch," % (
uri, service_class
)
new_line = " '{0}%s$': %sResponse.dispatch," % (uri, service_class)
if new_line in lines:
return
lines.insert(last_elem_line_index + 1, new_line)
body = '\n'.join(lines) + '\n'
with open(path, 'w') as f:
body = "\n".join(lines) + "\n"
with open(path, "w") as f:
f.write(body)
def insert_codes(service, operation, api_protocol):
func_in_responses = get_function_in_responses(service, operation, api_protocol)
func_in_models = get_function_in_models(service, operation)
# edit responses.py
responses_path = 'moto/{}/responses.py'.format(get_escaped_service(service))
print_progress('inserting code', responses_path, 'green')
responses_path = "moto/{}/responses.py".format(get_escaped_service(service))
print_progress("inserting code", responses_path, "green")
insert_code_to_class(responses_path, BaseResponse, func_in_responses)
# insert template
if api_protocol == 'query':
if api_protocol == "query":
template = get_response_query_template(service, operation)
with open(responses_path) as f:
lines = [_[:-1] for _ in f.readlines()]
lines += template.splitlines()
with open(responses_path, 'w') as f:
f.write('\n'.join(lines))
with open(responses_path, "w") as f:
f.write("\n".join(lines))
# edit models.py
models_path = 'moto/{}/models.py'.format(get_escaped_service(service))
print_progress('inserting code', models_path, 'green')
models_path = "moto/{}/models.py".format(get_escaped_service(service))
print_progress("inserting code", models_path, "green")
insert_code_to_class(models_path, BaseBackend, func_in_models)
# edit urls.py
@ -471,15 +513,20 @@ def insert_codes(service, operation, api_protocol):
@click.command()
def main():
service, operation = select_service_and_operation()
api_protocol = boto3.client(service)._service_model.metadata['protocol']
api_protocol = boto3.client(service)._service_model.metadata["protocol"]
initialize_service(service, operation, api_protocol)
if api_protocol in ['query', 'json', 'rest-json']:
if api_protocol in ["query", "json", "rest-json"]:
insert_codes(service, operation, api_protocol)
else:
print_progress('skip inserting code', 'api protocol "{}" is not supported'.format(api_protocol), 'yellow')
print_progress(
"skip inserting code",
'api protocol "{}" is not supported'.format(api_protocol),
"yellow",
)
click.echo('You will still need to add the mock into "__init__.py"'.format(service))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -23,42 +23,53 @@ def json_serial(obj):
raise TypeError("Type not serializable")
client = boto3.client('iam')
client = boto3.client("iam")
policies = {}
paginator = client.get_paginator('list_policies')
paginator = client.get_paginator("list_policies")
try:
response_iterator = paginator.paginate(Scope='AWS')
response_iterator = paginator.paginate(Scope="AWS")
for response in response_iterator:
for policy in response['Policies']:
policies[policy['PolicyName']] = policy
for policy in response["Policies"]:
policies[policy["PolicyName"]] = policy
except NoCredentialsError:
print("USAGE:")
print("Put your AWS credentials into ~/.aws/credentials and run:")
print(__file__)
print("")
print("Or specify them on the command line:")
print("AWS_ACCESS_KEY_ID=your_personal_access_key AWS_SECRET_ACCESS_KEY=your_personal_secret {}".format(__file__))
print(
"AWS_ACCESS_KEY_ID=your_personal_access_key AWS_SECRET_ACCESS_KEY=your_personal_secret {}".format(
__file__
)
)
print("")
sys.exit(1)
for policy_name in policies:
response = client.get_policy_version(
PolicyArn=policies[policy_name]['Arn'],
VersionId=policies[policy_name]['DefaultVersionId'])
for key in response['PolicyVersion']:
if key != "CreateDate": # the policy's CreateDate should not be overwritten by its version's CreateDate
policies[policy_name][key] = response['PolicyVersion'][key]
PolicyArn=policies[policy_name]["Arn"],
VersionId=policies[policy_name]["DefaultVersionId"],
)
for key in response["PolicyVersion"]:
if (
key != "CreateDate"
): # the policy's CreateDate should not be overwritten by its version's CreateDate
policies[policy_name][key] = response["PolicyVersion"][key]
with open(output_file, 'w') as f:
triple_quote = '\"\"\"'
with open(output_file, "w") as f:
triple_quote = '"""'
f.write("# Imported via `make aws_managed_policies`\n")
f.write('aws_managed_policies_data = {}\n'.format(triple_quote))
f.write(json.dumps(policies,
sort_keys=True,
indent=4,
separators=(',', ': '),
default=json_serial))
f.write('{}\n'.format(triple_quote))
f.write("aws_managed_policies_data = {}\n".format(triple_quote))
f.write(
json.dumps(
policies,
sort_keys=True,
indent=4,
separators=(",", ": "),
default=json_serial,
)
)
f.write("{}\n".format(triple_quote))