Implementing RDS Snapshots

This commit is contained in:
Jack Danger Canty 2017-06-20 13:51:25 -07:00
commit 63f01039c3
4 changed files with 221 additions and 2 deletions

View file

@ -1,6 +1,7 @@
from __future__ import unicode_literals
import copy
import datetime
from collections import defaultdict
import boto.rds2
@ -10,9 +11,11 @@ from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import get_random_hex
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.ec2.models import ec2_backends
from .exceptions import (RDSClientError,
DBInstanceNotFoundError,
DBSnapshotNotFoundError,
DBSecurityGroupNotFoundError,
DBSubnetGroupNotFoundError,
DBParameterGroupNotFoundError)
@ -205,7 +208,7 @@ class Database(BaseModel):
{% endif %}
{% if database.iops %}
<Iops>{{ database.iops }}</Iops>
<StorageType>io1</StorageType>
<StorageType>standard</StorageType>
{% else %}
<StorageType>{{ database.storage_type }}</StorageType>
{% endif %}
@ -399,6 +402,53 @@ class Database(BaseModel):
backend.delete_database(self.db_instance_identifier)
class Snapshot(BaseModel):
def __init__(self, database, snapshot_id, tags):
self.database = database
self.snapshot_id = snapshot_id
self.tags = tags
self.created_at = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@property
def snapshot_arn(self):
return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format(self.database.region, self.snapshot_id)
def to_xml(self):
template = Template("""<DBSnapshot>
<DBSnapshotIdentifier>{{ snapshot.snapshot_id }}</DBSnapshotIdentifier>
<DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
<SnapshotCreateTime>{{ snapshot.created_at }}</SnapshotCreateTime>
<Engine>{{ database.engine }}</Engine>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<Status>available</Status>
<Port>{{ database.port }}</Port>
<AvailabilityZone>{{ database.availability_zone }}</AvailabilityZone>
<VpcId>{{ database.db_subnet_group.vpc_id }}</VpcId>
<InstanceCreateTime>{{ snapshot.created_at }}</InstanceCreateTime>
<MasterUsername>{{ database.master_username }}</MasterUsername>
<EngineVersion>{{ database.engine_version }}</EngineVersion>
<LicenseModel>general-public-license</LicenseModel>
<SnapshotType>manual</SnapshotType>
{% if database.iops %}
<Iops>{{ database.iops }}</Iops>
<StorageType>io1</StorageType>
{% else %}
<StorageType>{{ database.storage_type }}</StorageType>
{% endif %}
<OptionGroupName>{{ database.option_group_name }}</OptionGroupName>
<PercentProgress>{{ 100 }}</PercentProgress>
<SourceRegion>{{ database.region }}</SourceRegion>
<SourceDBSnapshotIdentifier></SourceDBSnapshotIdentifier>
<TdeCredentialArn></TdeCredentialArn>
<Encrypted>{{ database.storage_encrypted }}</Encrypted>
<KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
<DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn>
<Timezone></Timezone>
<IAMDatabaseAuthenticationEnabled>false</IAMDatabaseAuthenticationEnabled>
</DBSnapshot>""")
return template.render(snapshot=self, database=self.database)
class SecurityGroup(BaseModel):
def __init__(self, group_name, description, tags):
@ -607,6 +657,7 @@ class RDS2Backend(BaseBackend):
self.arn_regex = re_compile(
r'^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$')
self.databases = OrderedDict()
self.snapshots = OrderedDict()
self.db_parameter_groups = {}
self.option_groups = {}
self.security_groups = {}
@ -624,6 +675,20 @@ class RDS2Backend(BaseBackend):
self.databases[database_id] = database
return database
def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags):
database = self.databases.get(db_instance_identifier)
if not database:
raise DBInstanceNotFoundError(db_instance_identifier)
snapshot = Snapshot(database, db_snapshot_identifier, tags)
self.snapshots[db_snapshot_identifier] = snapshot
return snapshot
def delete_snapshot(self, db_snapshot_identifier):
if db_snapshot_identifier not in self.snapshots:
raise DBSnapshotNotFoundError()
return self.snapshots.pop(db_snapshot_identifier)
def create_database_replica(self, db_kwargs):
database_id = db_kwargs['db_instance_identifier']
source_database_id = db_kwargs['source_db_identifier']
@ -646,6 +711,20 @@ class RDS2Backend(BaseBackend):
raise DBInstanceNotFoundError(db_instance_identifier)
return self.databases.values()
def describe_snapshots(self, db_instance_identifier, db_snapshot_identifier):
if db_instance_identifier:
for snapshot in self.snapshots.values():
if snapshot.database.db_instance_identifier == db_instance_identifier:
return [snapshot]
raise DBSnapshotNotFoundError()
if db_snapshot_identifier:
if db_snapshot_identifier in self.snapshots:
return [self.snapshots[db_snapshot_identifier]]
raise DBSnapshotNotFoundError()
return self.snapshots.values()
def modify_database(self, db_instance_identifier, db_kwargs):
database = self.describe_databases(db_instance_identifier)[0]
database.update(db_kwargs)