Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
commit 96e5b1993d
507 changed files with 52541 additions and 47814 deletions

View file

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import rds_backends
from ..core.models import base_decorator, deprecated_base_decorator
rds_backend = rds_backends['us-east-1']
rds_backend = rds_backends["us-east-1"]
mock_rds = base_decorator(rds_backends)
mock_rds_deprecated = deprecated_base_decorator(rds_backends)

View file

@ -5,38 +5,34 @@ from werkzeug.exceptions import BadRequest
class RDSClientError(BadRequest):
def __init__(self, code, message):
super(RDSClientError, self).__init__()
self.description = json.dumps({
"Error": {
"Code": code,
"Message": message,
'Type': 'Sender',
},
'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1',
})
self.description = json.dumps(
{
"Error": {"Code": code, "Message": message, "Type": "Sender"},
"RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1",
}
)
class DBInstanceNotFoundError(RDSClientError):
def __init__(self, database_identifier):
super(DBInstanceNotFoundError, self).__init__(
'DBInstanceNotFound',
"Database {0} not found.".format(database_identifier))
"DBInstanceNotFound", "Database {0} not found.".format(database_identifier)
)
class DBSecurityGroupNotFoundError(RDSClientError):
def __init__(self, security_group_name):
super(DBSecurityGroupNotFoundError, self).__init__(
'DBSecurityGroupNotFound',
"Security Group {0} not found.".format(security_group_name))
"DBSecurityGroupNotFound",
"Security Group {0} not found.".format(security_group_name),
)
class DBSubnetGroupNotFoundError(RDSClientError):
def __init__(self, subnet_group_name):
super(DBSubnetGroupNotFoundError, self).__init__(
'DBSubnetGroupNotFound',
"Subnet Group {0} not found.".format(subnet_group_name))
"DBSubnetGroupNotFound",
"Subnet Group {0} not found.".format(subnet_group_name),
)

View file

@ -11,33 +11,34 @@ from moto.rds2.models import rds2_backends
class Database(BaseModel):
def get_cfn_attribute(self, attribute_name):
if attribute_name == 'Endpoint.Address':
if attribute_name == "Endpoint.Address":
return self.address
elif attribute_name == 'Endpoint.Port':
elif attribute_name == "Endpoint.Port":
return self.port
raise UnformattedGetAttTemplateException()
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
db_instance_identifier = properties.get('DBInstanceIdentifier')
db_instance_identifier = properties.get("DBInstanceIdentifier")
if not db_instance_identifier:
db_instance_identifier = resource_name.lower() + get_random_hex(12)
db_security_groups = properties.get('DBSecurityGroups')
db_security_groups = properties.get("DBSecurityGroups")
if not db_security_groups:
db_security_groups = []
security_groups = [group.group_name for group in db_security_groups]
db_subnet_group = properties.get("DBSubnetGroupName")
db_subnet_group_name = db_subnet_group.subnet_name if db_subnet_group else None
db_kwargs = {
"auto_minor_version_upgrade": properties.get('AutoMinorVersionUpgrade'),
"allocated_storage": properties.get('AllocatedStorage'),
"auto_minor_version_upgrade": properties.get("AutoMinorVersionUpgrade"),
"allocated_storage": properties.get("AllocatedStorage"),
"availability_zone": properties.get("AvailabilityZone"),
"backup_retention_period": properties.get("BackupRetentionPeriod"),
"db_instance_class": properties.get('DBInstanceClass'),
"db_instance_class": properties.get("DBInstanceClass"),
"db_instance_identifier": db_instance_identifier,
"db_name": properties.get("DBName"),
"db_subnet_group_name": db_subnet_group_name,
@ -45,10 +46,10 @@ class Database(BaseModel):
"engine_version": properties.get("EngineVersion"),
"iops": properties.get("Iops"),
"kms_key_id": properties.get("KmsKeyId"),
"master_password": properties.get('MasterUserPassword'),
"master_username": properties.get('MasterUsername'),
"master_password": properties.get("MasterUserPassword"),
"master_username": properties.get("MasterUsername"),
"multi_az": properties.get("MultiAZ"),
"port": properties.get('Port', 3306),
"port": properties.get("Port", 3306),
"publicly_accessible": properties.get("PubliclyAccessible"),
"copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"),
"region": region_name,
@ -69,7 +70,8 @@ class Database(BaseModel):
return database
def to_xml(self):
template = Template("""<DBInstance>
template = Template(
"""<DBInstance>
<BackupRetentionPeriod>{{ database.backup_retention_period }}</BackupRetentionPeriod>
<DBInstanceStatus>{{ database.status }}</DBInstanceStatus>
<MultiAZ>{{ database.multi_az }}</MultiAZ>
@ -152,7 +154,8 @@ class Database(BaseModel):
<Port>{{ database.port }}</Port>
</Endpoint>
<DBInstanceArn>{{ database.db_instance_arn }}</DBInstanceArn>
</DBInstance>""")
</DBInstance>"""
)
return template.render(database=self)
def delete(self, region_name):
@ -161,7 +164,6 @@ class Database(BaseModel):
class SecurityGroup(BaseModel):
def __init__(self, group_name, description):
self.group_name = group_name
self.description = description
@ -170,7 +172,8 @@ class SecurityGroup(BaseModel):
self.ec2_security_groups = []
def to_xml(self):
template = Template("""<DBSecurityGroup>
template = Template(
"""<DBSecurityGroup>
<EC2SecurityGroups>
{% for security_group in security_group.ec2_security_groups %}
<EC2SecurityGroup>
@ -193,7 +196,8 @@ class SecurityGroup(BaseModel):
</IPRanges>
<OwnerId>{{ security_group.ownder_id }}</OwnerId>
<DBSecurityGroupName>{{ security_group.group_name }}</DBSecurityGroupName>
</DBSecurityGroup>""")
</DBSecurityGroup>"""
)
return template.render(security_group=self)
def authorize_cidr(self, cidr_ip):
@ -203,20 +207,19 @@ class SecurityGroup(BaseModel):
self.ec2_security_groups.append(security_group)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
group_name = resource_name.lower() + get_random_hex(12)
description = properties['GroupDescription']
security_group_ingress_rules = properties.get(
'DBSecurityGroupIngress', [])
tags = properties.get('Tags')
description = properties["GroupDescription"]
security_group_ingress_rules = properties.get("DBSecurityGroupIngress", [])
tags = properties.get("Tags")
ec2_backend = ec2_backends[region_name]
rds_backend = rds_backends[region_name]
security_group = rds_backend.create_security_group(
group_name,
description,
tags,
group_name, description, tags
)
for security_group_ingress in security_group_ingress_rules:
@ -224,12 +227,10 @@ class SecurityGroup(BaseModel):
if ingress_type == "CIDRIP":
security_group.authorize_cidr(ingress_value)
elif ingress_type == "EC2SecurityGroupName":
subnet = ec2_backend.get_security_group_from_name(
ingress_value)
subnet = ec2_backend.get_security_group_from_name(ingress_value)
security_group.authorize_security_group(subnet)
elif ingress_type == "EC2SecurityGroupId":
subnet = ec2_backend.get_security_group_from_id(
ingress_value)
subnet = ec2_backend.get_security_group_from_id(ingress_value)
security_group.authorize_security_group(subnet)
return security_group
@ -239,7 +240,6 @@ class SecurityGroup(BaseModel):
class SubnetGroup(BaseModel):
def __init__(self, subnet_name, description, subnets):
self.subnet_name = subnet_name
self.description = description
@ -249,7 +249,8 @@ class SubnetGroup(BaseModel):
self.vpc_id = self.subnets[0].vpc_id
def to_xml(self):
template = Template("""<DBSubnetGroup>
template = Template(
"""<DBSubnetGroup>
<VpcId>{{ subnet_group.vpc_id }}</VpcId>
<SubnetGroupStatus>{{ subnet_group.status }}</SubnetGroupStatus>
<DBSubnetGroupDescription>{{ subnet_group.description }}</DBSubnetGroupDescription>
@ -266,27 +267,26 @@ class SubnetGroup(BaseModel):
</Subnet>
{% endfor %}
</Subnets>
</DBSubnetGroup>""")
</DBSubnetGroup>"""
)
return template.render(subnet_group=self)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
subnet_name = resource_name.lower() + get_random_hex(12)
description = properties['DBSubnetGroupDescription']
subnet_ids = properties['SubnetIds']
tags = properties.get('Tags')
description = properties["DBSubnetGroupDescription"]
subnet_ids = properties["SubnetIds"]
tags = properties.get("Tags")
ec2_backend = ec2_backends[region_name]
subnets = [ec2_backend.get_subnet(subnet_id)
for subnet_id in subnet_ids]
subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids]
rds_backend = rds_backends[region_name]
subnet_group = rds_backend.create_subnet_group(
subnet_name,
description,
subnets,
tags,
subnet_name, description, subnets, tags
)
return subnet_group
@ -296,7 +296,6 @@ class SubnetGroup(BaseModel):
class RDSBackend(BaseBackend):
def __init__(self, region):
self.region = region
@ -314,5 +313,6 @@ class RDSBackend(BaseBackend):
return rds2_backends[self.region]
rds_backends = dict((region.name, RDSBackend(region.name))
for region in boto.rds.regions())
rds_backends = dict(
(region.name, RDSBackend(region.name)) for region in boto.rds.regions()
)

View file

@ -6,19 +6,18 @@ from .models import rds_backends
class RDSResponse(BaseResponse):
@property
def backend(self):
return rds_backends[self.region]
def _get_db_kwargs(self):
args = {
"auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'),
"allocated_storage": self._get_int_param('AllocatedStorage'),
"auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"),
"allocated_storage": self._get_int_param("AllocatedStorage"),
"availability_zone": self._get_param("AvailabilityZone"),
"backup_retention_period": self._get_param("BackupRetentionPeriod"),
"db_instance_class": self._get_param('DBInstanceClass'),
"db_instance_identifier": self._get_param('DBInstanceIdentifier'),
"db_instance_class": self._get_param("DBInstanceClass"),
"db_instance_identifier": self._get_param("DBInstanceIdentifier"),
"db_name": self._get_param("DBName"),
# DBParameterGroupName
"db_subnet_group_name": self._get_param("DBSubnetGroupName"),
@ -26,48 +25,48 @@ class RDSResponse(BaseResponse):
"engine_version": self._get_param("EngineVersion"),
"iops": self._get_int_param("Iops"),
"kms_key_id": self._get_param("KmsKeyId"),
"master_password": self._get_param('MasterUserPassword'),
"master_username": self._get_param('MasterUsername'),
"master_password": self._get_param("MasterUserPassword"),
"master_username": self._get_param("MasterUsername"),
"multi_az": self._get_bool_param("MultiAZ"),
# OptionGroupName
"port": self._get_param('Port'),
"port": self._get_param("Port"),
# PreferredBackupWindow
# PreferredMaintenanceWindow
"publicly_accessible": self._get_param("PubliclyAccessible"),
"region": self.region,
"security_groups": self._get_multi_param('DBSecurityGroups.member'),
"security_groups": self._get_multi_param("DBSecurityGroups.member"),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType"),
# VpcSecurityGroupIds.member.N
"tags": list(),
}
args['tags'] = self.unpack_complex_list_params(
'Tags.Tag', ('Key', 'Value'))
args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
return args
def _get_db_replica_kwargs(self):
return {
"auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'),
"auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"),
"availability_zone": self._get_param("AvailabilityZone"),
"db_instance_class": self._get_param('DBInstanceClass'),
"db_instance_identifier": self._get_param('DBInstanceIdentifier'),
"db_instance_class": self._get_param("DBInstanceClass"),
"db_instance_identifier": self._get_param("DBInstanceIdentifier"),
"db_subnet_group_name": self._get_param("DBSubnetGroupName"),
"iops": self._get_int_param("Iops"),
# OptionGroupName
"port": self._get_param('Port'),
"port": self._get_param("Port"),
"publicly_accessible": self._get_param("PubliclyAccessible"),
"source_db_identifier": self._get_param('SourceDBInstanceIdentifier'),
"source_db_identifier": self._get_param("SourceDBInstanceIdentifier"),
"storage_type": self._get_param("StorageType"),
}
def unpack_complex_list_params(self, label, names):
unpacked_list = list()
count = 1
while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])):
while self._get_param("{0}.{1}.{2}".format(label, count, names[0])):
param = dict()
for i in range(len(names)):
param[names[i]] = self._get_param(
'{0}.{1}.{2}'.format(label, count, names[i]))
"{0}.{1}.{2}".format(label, count, names[i])
)
unpacked_list.append(param)
count += 1
return unpacked_list
@ -87,16 +86,18 @@ class RDSResponse(BaseResponse):
return template.render(database=database)
def describe_db_instances(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_instance_identifier = self._get_param("DBInstanceIdentifier")
all_instances = list(self.backend.describe_databases(db_instance_identifier))
marker = self._get_param('Marker')
marker = self._get_param("Marker")
all_ids = [instance.db_instance_identifier for instance in all_instances]
if marker:
start = all_ids.index(marker) + 1
else:
start = 0
page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
instances_resp = all_instances[start:start + page_size]
page_size = self._get_int_param(
"MaxRecords", 50
) # the default is 100, but using 50 to make testing easier
instances_resp = all_instances[start : start + page_size]
next_marker = None
if len(all_instances) > start + page_size:
next_marker = instances_resp[-1].db_instance_identifier
@ -105,73 +106,74 @@ class RDSResponse(BaseResponse):
return template.render(databases=instances_resp, marker=next_marker)
def modify_db_instance(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_instance_identifier = self._get_param("DBInstanceIdentifier")
db_kwargs = self._get_db_kwargs()
new_db_instance_identifier = self._get_param('NewDBInstanceIdentifier')
new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier")
if new_db_instance_identifier:
db_kwargs['new_db_instance_identifier'] = new_db_instance_identifier
database = self.backend.modify_database(
db_instance_identifier, db_kwargs)
db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier
database = self.backend.modify_database(db_instance_identifier, db_kwargs)
template = self.response_template(MODIFY_DATABASE_TEMPLATE)
return template.render(database=database)
def delete_db_instance(self):
db_instance_identifier = self._get_param('DBInstanceIdentifier')
db_instance_identifier = self._get_param("DBInstanceIdentifier")
database = self.backend.delete_database(db_instance_identifier)
template = self.response_template(DELETE_DATABASE_TEMPLATE)
return template.render(database=database)
def create_db_security_group(self):
group_name = self._get_param('DBSecurityGroupName')
description = self._get_param('DBSecurityGroupDescription')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
group_name = self._get_param("DBSecurityGroupName")
description = self._get_param("DBSecurityGroupDescription")
tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
security_group = self.backend.create_security_group(
group_name, description, tags)
group_name, description, tags
)
template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE)
return template.render(security_group=security_group)
def describe_db_security_groups(self):
security_group_name = self._get_param('DBSecurityGroupName')
security_groups = self.backend.describe_security_groups(
security_group_name)
security_group_name = self._get_param("DBSecurityGroupName")
security_groups = self.backend.describe_security_groups(security_group_name)
template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE)
return template.render(security_groups=security_groups)
def delete_db_security_group(self):
security_group_name = self._get_param('DBSecurityGroupName')
security_group = self.backend.delete_security_group(
security_group_name)
security_group_name = self._get_param("DBSecurityGroupName")
security_group = self.backend.delete_security_group(security_group_name)
template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE)
return template.render(security_group=security_group)
def authorize_db_security_group_ingress(self):
security_group_name = self._get_param('DBSecurityGroupName')
cidr_ip = self._get_param('CIDRIP')
security_group_name = self._get_param("DBSecurityGroupName")
cidr_ip = self._get_param("CIDRIP")
security_group = self.backend.authorize_security_group(
security_group_name, cidr_ip)
security_group_name, cidr_ip
)
template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE)
return template.render(security_group=security_group)
def create_db_subnet_group(self):
subnet_name = self._get_param('DBSubnetGroupName')
description = self._get_param('DBSubnetGroupDescription')
subnet_ids = self._get_multi_param('SubnetIds.member')
subnets = [ec2_backends[self.region].get_subnet(
subnet_id) for subnet_id in subnet_ids]
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
subnet_name = self._get_param("DBSubnetGroupName")
description = self._get_param("DBSubnetGroupDescription")
subnet_ids = self._get_multi_param("SubnetIds.member")
subnets = [
ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids
]
tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value"))
subnet_group = self.backend.create_subnet_group(
subnet_name, description, subnets, tags)
subnet_name, description, subnets, tags
)
template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE)
return template.render(subnet_group=subnet_group)
def describe_db_subnet_groups(self):
subnet_name = self._get_param('DBSubnetGroupName')
subnet_name = self._get_param("DBSubnetGroupName")
subnet_groups = self.backend.describe_subnet_groups(subnet_name)
template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE)
return template.render(subnet_groups=subnet_groups)
def delete_db_subnet_group(self):
subnet_name = self._get_param('DBSubnetGroupName')
subnet_name = self._get_param("DBSubnetGroupName")
subnet_group = self.backend.delete_subnet_group(subnet_name)
template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE)
return template.render(subnet_group=subnet_group)

View file

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import RDSResponse
url_bases = [
"https?://rds(\..+)?.amazonaws.com",
]
url_bases = ["https?://rds(\..+)?.amazonaws.com"]
url_paths = {
'{0}/$': RDSResponse.dispatch,
}
url_paths = {"{0}/$": RDSResponse.dispatch}