Merge branch 'master' into batch

This commit is contained in:
Terry Cain 2017-09-29 22:11:15 +01:00
commit ea10c4dfb6
No known key found for this signature in database
GPG key ID: 14D90844E4E9B9F3
47 changed files with 2457 additions and 273 deletions

View file

@ -23,10 +23,11 @@ from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
@ -39,6 +40,7 @@ from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
try:

View file

@ -1,34 +1,150 @@
from __future__ import unicode_literals
import base64
from collections import defaultdict
import datetime
import docker.errors
import hashlib
import io
import logging
import os
import json
import sys
import re
import zipfile
try:
from StringIO import StringIO
except:
from io import StringIO
import uuid
import functools
import tarfile
import calendar
import threading
import traceback
import requests.adapters
import boto.awslambda
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time_millis
from moto.s3.models import s3_backend
from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
logger = logging.getLogger(__name__)
try:
from tempfile import TemporaryDirectory
except ImportError:
from backports.tempfile import TemporaryDirectory
_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*')
_orig_adapter_send = requests.adapters.HTTPAdapter.send
def zip2tar(zip_bytes):
with TemporaryDirectory() as td:
tarname = os.path.join(td, 'data.tar')
timeshift = int((datetime.datetime.now() -
datetime.datetime.utcnow()).total_seconds())
with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \
tarfile.TarFile(tarname, 'w') as tarf:
for zipinfo in zipf.infolist():
if zipinfo.filename[-1] == '/': # is_dir() is py3.6+
continue
tarinfo = tarfile.TarInfo(name=zipinfo.filename)
tarinfo.size = zipinfo.file_size
tarinfo.mtime = calendar.timegm(zipinfo.date_time) - timeshift
infile = zipf.open(zipinfo.filename)
tarf.addfile(tarinfo, infile)
with open(tarname, 'rb') as f:
tar_data = f.read()
return tar_data
class _VolumeRefCount:
__slots__ = "refcount", "volume"
def __init__(self, refcount, volume):
self.refcount = refcount
self.volume = volume
class _DockerDataVolumeContext:
_data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount}
_lock = threading.Lock()
def __init__(self, lambda_func):
self._lambda_func = lambda_func
self._vol_ref = None
@property
def name(self):
return self._vol_ref.volume.name
def __enter__(self):
# See if volume is already known
with self.__class__._lock:
self._vol_ref = self.__class__._data_vol_map[self._lambda_func.code_sha_256]
self._vol_ref.refcount += 1
if self._vol_ref.refcount > 1:
return self
# See if the volume already exists
for vol in self._lambda_func.docker_client.volumes.list():
if vol.name == self._lambda_func.code_sha_256:
self._vol_ref.volume = vol
return self
# It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256)
container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes={self.name: '/tmp/data'}, detach=True)
try:
tar_bytes = zip2tar(self._lambda_func.code_bytes)
container.put_archive('/tmp/data', tar_bytes)
finally:
container.remove(force=True)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
with self.__class__._lock:
self._vol_ref.refcount -= 1
if self._vol_ref.refcount == 0:
try:
self._vol_ref.volume.remove()
except docker.errors.APIError as e:
if e.status_code != 409:
raise
raise # multiple processes trying to use same volume?
class LambdaFunction(BaseModel):
def __init__(self, spec, validate_s3=True):
def __init__(self, spec, region, validate_s3=True):
# required
self.region = region
self.code = spec['Code']
self.function_name = spec['FunctionName']
self.handler = spec['Handler']
self.role = spec['Role']
self.run_time = spec['Runtime']
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get('Environment', {}).get('Variables', {})
self.docker_client = docker.from_env()
# Unfortunately mocking replaces this method w/o fallback enabled, so we
# need to replace it if we detect it's been mocked
if requests.adapters.HTTPAdapter.send != _orig_adapter_send:
_orig_get_adapter = self.docker_client.api.get_adapter
def replace_adapter_send(*args, **kwargs):
adapter = _orig_get_adapter(*args, **kwargs)
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
# optional
self.description = spec.get('Description', '')
@ -36,13 +152,18 @@ class LambdaFunction(BaseModel):
self.publish = spec.get('Publish', False) # this is ignored currently
self.timeout = spec.get('Timeout', 3)
self.logs_group_name = '/aws/lambda/{}'.format(self.function_name)
self.logs_backend.ensure_log_group(self.logs_group_name, [])
# this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
# auto-generated
self.version = '$LATEST'
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
self.last_modified = datetime.datetime.utcnow().strftime(
'%Y-%m-%d %H:%M:%S')
if 'ZipFile' in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 -
# argh!
@ -52,12 +173,13 @@ class LambdaFunction(BaseModel):
except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile'])
zbuffer = io.BytesIO()
zbuffer.write(to_unzip_code)
zip_file = zipfile.ZipFile(zbuffer, 'r', zipfile.ZIP_DEFLATED)
self.code = zip_file.read("".join(zip_file.namelist()))
self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID'])
else:
# validate s3 bucket and key
key = None
@ -76,10 +198,12 @@ class LambdaFunction(BaseModel):
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.")
if key:
self.code_bytes = key.value
self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = 'arn:aws:lambda:123456789012:function:{0}'.format(
self.function_name)
self.function_arn = 'arn:aws:lambda:{}:123456789012:function:{}'.format(
self.region, self.function_name)
self.tags = dict()
@ -94,7 +218,7 @@ class LambdaFunction(BaseModel):
return json.dumps(self.get_configuration())
def get_configuration(self):
return {
config = {
"CodeSha256": self.code_sha_256,
"CodeSize": self.code_size,
"Description": self.description,
@ -110,70 +234,105 @@ class LambdaFunction(BaseModel):
"VpcConfig": self.vpc_config,
}
def get_code(self):
if isinstance(self.code, dict):
return {
"Code": {
"Location": "s3://lambda-functions.aws.amazon.com/{0}".format(self.code['S3Key']),
"RepositoryType": "S3"
},
"Configuration": self.get_configuration(),
}
else:
return {
"Configuration": self.get_configuration(),
if self.environment_vars:
config['Environment'] = {
'Variables': self.environment_vars
}
def convert(self, s):
return config
def get_code(self):
return {
"Code": {
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']),
"RepositoryType": "S3"
},
"Configuration": self.get_configuration(),
}
@staticmethod
def convert(s):
try:
return str(s, encoding='utf-8')
except:
return s
def is_json(self, test_str):
@staticmethod
def is_json(test_str):
try:
response = json.loads(test_str)
except:
response = test_str
return response
def _invoke_lambda(self, code, event={}, context={}):
# TO DO: context not yet implemented
try:
mycode = "\n".join(['import json',
self.convert(self.code),
self.convert('print(json.dumps(lambda_handler(%s, %s)))' % (self.is_json(self.convert(event)), context))])
def _invoke_lambda(self, code, event=None, context=None):
# TODO: context not yet implemented
if event is None:
event = dict()
if context is None:
context = {}
except Exception as ex:
print("Exception %s", ex)
errored = False
try:
original_stdout = sys.stdout
original_stderr = sys.stderr
codeOut = StringIO()
codeErr = StringIO()
sys.stdout = codeOut
sys.stderr = codeErr
exec(mycode)
exec_err = codeErr.getvalue()
exec_out = codeOut.getvalue()
result = self.convert(exec_out.strip())
if exec_err:
result = "\n".join([exec_out.strip(), self.convert(exec_err)])
except Exception as ex:
errored = True
result = '%s\n\n\nException %s' % (mycode, ex)
finally:
codeErr.close()
codeOut.close()
sys.stdout = original_stdout
sys.stderr = original_stderr
return self.convert(result), errored
# TODO: I believe we can keep the container running and feed events as needed
# also need to hook it up to the other services so it can make kws/s3 etc calls
# Should get invoke_id /RequestId from invovation
env_vars = {
"AWS_LAMBDA_FUNCTION_TIMEOUT": self.timeout,
"AWS_LAMBDA_FUNCTION_NAME": self.function_name,
"AWS_LAMBDA_FUNCTION_MEMORY_SIZE": self.memory_size,
"AWS_LAMBDA_FUNCTION_VERSION": self.version,
"AWS_REGION": self.region,
}
env_vars.update(self.environment_vars)
container = output = exit_code = None
with _DockerDataVolumeContext(self) as data_vol:
try:
run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {}
container = self.docker_client.containers.run(
"lambci/lambda:{}".format(self.run_time),
[self.handler, json.dumps(event)], remove=False,
mem_limit="{}m".format(self.memory_size),
volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs)
finally:
if container:
exit_code = container.wait()
output = container.logs(stdout=False, stderr=True)
output += container.logs(stdout=True, stderr=False)
container.remove()
output = output.decode('utf-8')
# Send output to "logs" backend
invoke_id = uuid.uuid4().hex
log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format(
date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id
)
self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name)
log_events = [{'timestamp': unix_time_millis(), "message": line}
for line in output.splitlines()]
self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None)
if exit_code != 0:
raise Exception(
'lambda invoke failed output: {}'.format(output))
# strip out RequestId lines
output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)])
return output, False
except BaseException as e:
traceback.print_exc()
return "error running lambda: {}".format(e), True
def invoke(self, body, request_headers, response_headers):
payload = dict()
if body:
body = json.loads(body)
# Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse":
@ -189,7 +348,8 @@ class LambdaFunction(BaseModel):
return result
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
# required
@ -212,17 +372,19 @@ class LambdaFunction(BaseModel):
# this snippet converts this plaintext code to a proper base64-encoded ZIP file.
if 'ZipFile' in properties['Code']:
spec['Code']['ZipFile'] = base64.b64encode(
cls._create_zipfile_from_plaintext_code(spec['Code']['ZipFile']))
cls._create_zipfile_from_plaintext_code(
spec['Code']['ZipFile']))
backend = lambda_backends[region_name]
fn = backend.create_function(spec)
return fn
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
from moto.cloudformation.exceptions import \
UnformattedGetAttTemplateException
if attribute_name == 'Arn':
region = 'us-east-1'
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(region, self.function_name)
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(
self.region, self.function_name)
raise UnformattedGetAttTemplateException()
@staticmethod
@ -236,7 +398,6 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_name = spec['FunctionName']
@ -246,10 +407,12 @@ class EventSourceMapping(BaseModel):
# optional
self.batch_size = spec.get('BatchSize', 100)
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp', None)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'FunctionName': properties['FunctionName'],
@ -264,12 +427,12 @@ class EventSourceMapping(BaseModel):
class LambdaVersion(BaseModel):
def __init__(self, spec):
self.version = spec['Version']
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
spec = {
'Version': properties.get('Version')
@ -278,9 +441,14 @@ class LambdaVersion(BaseModel):
class LambdaBackend(BaseBackend):
def __init__(self):
def __init__(self, region_name):
self._functions = {}
self.region_name = region_name
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def has_function(self, function_name):
return function_name in self._functions
@ -289,7 +457,7 @@ class LambdaBackend(BaseBackend):
return self.get_function_by_arn(function_arn) is not None
def create_function(self, spec):
fn = LambdaFunction(spec)
fn = LambdaFunction(spec, self.region_name)
self._functions[fn.function_name] = fn
return fn
@ -308,6 +476,42 @@ class LambdaBackend(BaseBackend):
def list_functions(self):
return self._functions.values()
def send_message(self, function_name, message):
event = {
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": message,
"MessageAttributes": {
"Test": {
"Type": "String",
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
},
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
self._functions[function_name].invoke(json.dumps(event), {}, {})
pass
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
@ -328,10 +532,8 @@ def do_validate_s3():
return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true']
lambda_backends = {}
for region in boto.awslambda.regions():
lambda_backends[region.name] = LambdaBackend()
# Handle us forgotten regions, unless Lambda truly only runs out of US and
for region in ['ap-southeast-2']:
lambda_backends[region] = LambdaBackend()
lambda_backends = {_region.name: LambdaBackend(_region.name)
for _region in boto.awslambda.regions()}
lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2')

View file

@ -9,8 +9,8 @@ response = LambdaResponse()
url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag
}

View file

@ -23,7 +23,9 @@ from moto.iam import iam_backends
from moto.instance_metadata import instance_metadata_backends
from moto.kinesis import kinesis_backends
from moto.kms import kms_backends
from moto.logs import logs_backends
from moto.opsworks import opsworks_backends
from moto.polly import polly_backends
from moto.rds2 import rds2_backends
from moto.redshift import redshift_backends
from moto.route53 import route53_backends
@ -56,9 +58,11 @@ BACKENDS = {
'iam': iam_backends,
'moto_api': moto_api_backends,
'instance_metadata': instance_metadata_backends,
'opsworks': opsworks_backends,
'logs': logs_backends,
'kinesis': kinesis_backends,
'kms': kms_backends,
'opsworks': opsworks_backends,
'polly': polly_backends,
'redshift': redshift_backends,
'rds': rds2_backends,
's3': s3_backends,

View file

@ -2,6 +2,8 @@ from __future__ import unicode_literals
import copy
import itertools
import json
import os
import re
import six
@ -109,6 +111,9 @@ from .utils import (
is_tag_filter,
)
RESOURCES_DIR = os.path.join(os.path.dirname(__file__), 'resources')
INSTANCE_TYPES = json.load(open(os.path.join(RESOURCES_DIR, 'instance_types.json'), 'r'))
def utc_date_and_time():
return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z')
@ -3662,6 +3667,5 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, AmiBackend,
return True
ec2_backends = {}
for region in RegionsAndZonesBackend.regions:
ec2_backends[region.name] = EC2Backend(region.name)
ec2_backends = {region.name: EC2Backend(region.name)
for region in RegionsAndZonesBackend.regions}

File diff suppressed because one or more lines are too long

View file

@ -5,7 +5,12 @@ from moto.core.responses import BaseResponse
class General(BaseResponse):
def get_console_output(self):
instance_id = self._get_multi_param('InstanceId')[0]
instance_id = self._get_param('InstanceId')
if not instance_id:
# For compatibility with boto.
# See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599
instance_id = self._get_multi_param('InstanceId')[0]
instance = self.ec2_backend.get_instance(instance_id)
template = self.response_template(GET_CONSOLE_OUTPUT_RESULT)
return template.render(instance=instance)

View file

@ -18,8 +18,8 @@ class EC2ContainerServiceResponse(BaseResponse):
except ValueError:
return {}
def _get_param(self, param):
return self.request_params.get(param, None)
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def create_cluster(self):
cluster_name = self._get_param('clusterName')

5
moto/logs/__init__.py Normal file
View file

@ -0,0 +1,5 @@
from .models import logs_backends
from ..core.models import base_decorator, deprecated_base_decorator
mock_logs = base_decorator(logs_backends)
mock_logs_deprecated = deprecated_base_decorator(logs_backends)

228
moto/logs/models.py Normal file
View file

@ -0,0 +1,228 @@
from moto.core import BaseBackend
import boto.logs
from moto.core.utils import unix_time_millis
class LogEvent:
_event_id = 0
def __init__(self, ingestion_time, log_event):
self.ingestionTime = ingestion_time
self.timestamp = log_event["timestamp"]
self.message = log_event['message']
self.eventId = self.__class__._event_id
self.__class__._event_id += 1
def to_filter_dict(self):
return {
"eventId": self.eventId,
"ingestionTime": self.ingestionTime,
# "logStreamName":
"message": self.message,
"timestamp": self.timestamp
}
class LogStream:
_log_ids = 0
def __init__(self, region, log_group, name):
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name)
self.creationTime = unix_time_millis()
self.firstEventTimestamp = None
self.lastEventTimestamp = None
self.lastIngestionTime = None
self.logStreamName = name
self.storedBytes = 0
self.uploadSequenceToken = 0 # I'm guessing this is token needed for sequenceToken by put_events
self.events = []
self.__class__._log_ids += 1
def to_describe_dict(self):
return {
"arn": self.arn,
"creationTime": self.creationTime,
"firstEventTimestamp": self.firstEventTimestamp,
"lastEventTimestamp": self.lastEventTimestamp,
"lastIngestionTime": self.lastIngestionTime,
"logStreamName": self.logStreamName,
"storedBytes": self.storedBytes,
"uploadSequenceToken": str(self.uploadSequenceToken),
}
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.lastIngestionTime = unix_time_millis()
# TODO: make this match AWS if possible
self.storedBytes += sum([len(log_event["message"]) for log_event in log_events])
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
self.uploadSequenceToken += 1
return self.uploadSequenceToken
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head)
back_token = next_token
if next_token is None:
next_token = 0
events_page = events[next_token: next_token + limit]
next_token += limit
if next_token >= len(self.events):
next_token = None
return events_page, back_token, next_token
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
def filter_func(event):
if start_time and event.timestamp < start_time:
return False
if end_time and event.timestamp > end_time:
return False
return True
events = []
for event in sorted(filter(filter_func, self.events), key=lambda x: x.timestamp):
event_obj = event.to_filter_dict()
event_obj['logStreamName'] = self.logStreamName
events.append(event_obj)
return events
class LogGroup:
def __init__(self, region, name, tags):
self.name = name
self.region = region
self.tags = tags
self.streams = dict() # {name: LogStream}
def create_log_stream(self, log_stream_name):
assert log_stream_name not in self.streams
self.streams[log_stream_name] = LogStream(self.region, self.name, log_stream_name)
def delete_log_stream(self, log_stream_name):
assert log_stream_name in self.streams
del self.streams[log_stream_name]
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
log_streams = [stream.to_describe_dict() for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)]
def sorter(stream):
return stream.name if order_by == 'logStreamName' else stream.lastEventTimestamp
if next_token is None:
next_token = 0
log_streams = sorted(log_streams, key=sorter, reverse=descending)
new_token = next_token + limit
log_streams_page = log_streams[next_token: new_token]
if new_token >= len(log_streams):
new_token = None
return log_streams_page, new_token
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_stream_name in self.streams
stream = self.streams[log_stream_name]
return stream.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert not filter_pattern # TODO: impl
streams = [stream for name, stream in self.streams.items() if not log_stream_names or name in log_stream_names]
events = []
for stream in streams:
events += stream.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
if interleaved:
events = sorted(events, key=lambda event: event.timestamp)
if next_token is None:
next_token = 0
events_page = events[next_token: next_token + limit]
next_token += limit
if next_token >= len(events):
next_token = None
searched_streams = [{"logStreamName": stream.logStreamName, "searchedCompletely": True} for stream in streams]
return events_page, next_token, searched_streams
class LogsBackend(BaseBackend):
def __init__(self, region_name):
self.region_name = region_name
self.groups = dict() # { logGroupName: LogGroup}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def create_log_group(self, log_group_name, tags):
assert log_group_name not in self.groups
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def ensure_log_group(self, log_group_name, tags):
if log_group_name in self.groups:
return
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
def delete_log_group(self, log_group_name):
assert log_group_name in self.groups
del self.groups[log_group_name]
def create_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.create_log_stream(log_stream_name)
def delete_log_stream(self, log_group_name, log_stream_name):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.delete_log_stream(log_stream_name)
def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.describe_log_streams(descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by)
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: add support for sequence_tokens
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved):
assert log_group_name in self.groups
log_group = self.groups[log_group_name]
return log_group.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
logs_backends = {region.name: LogsBackend(region.name) for region in boto.logs.regions()}

114
moto/logs/responses.py Normal file
View file

@ -0,0 +1,114 @@
from moto.core.responses import BaseResponse
from .models import logs_backends
import json
# See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html
class LogsResponse(BaseResponse):
@property
def logs_backend(self):
return logs_backends[self.region]
@property
def request_params(self):
try:
return json.loads(self.body)
except ValueError:
return {}
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def create_log_group(self):
log_group_name = self._get_param('logGroupName')
tags = self._get_param('tags')
assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern
self.logs_backend.create_log_group(log_group_name, tags)
return ''
def delete_log_group(self):
log_group_name = self._get_param('logGroupName')
self.logs_backend.delete_log_group(log_group_name)
return ''
def create_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.create_log_stream(log_group_name, log_stream_name)
return ''
def delete_log_stream(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
self.logs_backend.delete_log_stream(log_group_name, log_stream_name)
return ''
def describe_log_streams(self):
log_group_name = self._get_param('logGroupName')
log_stream_name_prefix = self._get_param('logStreamNamePrefix')
descending = self._get_param('descending', False)
limit = self._get_param('limit', 50)
assert limit <= 50
next_token = self._get_param('nextToken')
order_by = self._get_param('orderBy', 'LogStreamName')
assert order_by in {'LogStreamName', 'LastEventTime'}
if order_by == 'LastEventTime':
assert not log_stream_name_prefix
streams, next_token = self.logs_backend.describe_log_streams(
descending, limit, log_group_name, log_stream_name_prefix,
next_token, order_by)
return json.dumps({
"logStreams": streams,
"nextToken": next_token
})
def put_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
log_events = self._get_param('logEvents')
sequence_token = self._get_param('sequenceToken')
next_sequence_token = self.logs_backend.put_log_events(log_group_name, log_stream_name, log_events, sequence_token)
return json.dumps({'nextSequenceToken': next_sequence_token})
def get_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_name = self._get_param('logStreamName')
start_time = self._get_param('startTime')
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
start_from_head = self._get_param('startFromHead')
events, next_backward_token, next_foward_token = \
self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head)
return json.dumps({
"events": events,
"nextBackwardToken": next_backward_token,
"nextForwardToken": next_foward_token
})
def filter_log_events(self):
log_group_name = self._get_param('logGroupName')
log_stream_names = self._get_param('logStreamNames', [])
start_time = self._get_param('startTime')
# impl, see: http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html
filter_pattern = self._get_param('filterPattern')
interleaved = self._get_param('interleaved', False)
end_time = self._get_param("endTime")
limit = self._get_param('limit', 10000)
assert limit <= 10000
next_token = self._get_param('nextToken')
events, next_token, searched_streams = self.logs_backend.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved)
return json.dumps({
"events": events,
"nextToken": next_token,
"searchedLogStreams": searched_streams
})

9
moto/logs/urls.py Normal file
View file

@ -0,0 +1,9 @@
from .responses import LogsResponse
url_bases = [
"https?://logs.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': LogsResponse.dispatch,
}

6
moto/polly/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import polly_backends
from ..core.models import base_decorator
polly_backend = polly_backends['us-east-1']
mock_polly = base_decorator(polly_backends)

114
moto/polly/models.py Normal file
View file

@ -0,0 +1,114 @@
from __future__ import unicode_literals
from xml.etree import ElementTree as ET
import datetime
import boto3
from moto.core import BaseBackend, BaseModel
from .resources import VOICE_DATA
from .utils import make_arn_for_lexicon
DEFAULT_ACCOUNT_ID = 123456789012
class Lexicon(BaseModel):
def __init__(self, name, content, region_name):
self.name = name
self.content = content
self.size = 0
self.alphabet = None
self.last_modified = None
self.language_code = None
self.lexemes_count = 0
self.arn = make_arn_for_lexicon(DEFAULT_ACCOUNT_ID, name, region_name)
self.update()
def update(self, content=None):
if content is not None:
self.content = content
# Probably a very naive approach, but it'll do for now.
try:
root = ET.fromstring(self.content)
self.size = len(self.content)
self.last_modified = int((datetime.datetime.now() -
datetime.datetime(1970, 1, 1)).total_seconds())
self.lexemes_count = len(root.findall('.'))
for key, value in root.attrib.items():
if key.endswith('alphabet'):
self.alphabet = value
elif key.endswith('lang'):
self.language_code = value
except Exception as err:
raise ValueError('Failure parsing XML: {0}'.format(err))
def to_dict(self):
return {
'Attributes': {
'Alphabet': self.alphabet,
'LanguageCode': self.language_code,
'LastModified': self.last_modified,
'LexemesCount': self.lexemes_count,
'LexiconArn': self.arn,
'Size': self.size
}
}
def __repr__(self):
return '<Lexicon {0}>'.format(self.name)
class PollyBackend(BaseBackend):
def __init__(self, region_name=None):
super(PollyBackend, self).__init__()
self.region_name = region_name
self._lexicons = {}
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def describe_voices(self, language_code, next_token):
if language_code is None:
return VOICE_DATA
return [item for item in VOICE_DATA if item['LanguageCode'] == language_code]
def delete_lexicon(self, name):
# implement here
del self._lexicons[name]
def get_lexicon(self, name):
# Raises KeyError
return self._lexicons[name]
def list_lexicons(self, next_token):
result = []
for name, lexicon in self._lexicons.items():
lexicon_dict = lexicon.to_dict()
lexicon_dict['Name'] = name
result.append(lexicon_dict)
return result
def put_lexicon(self, name, content):
# If lexicon content is bad, it will raise ValueError
if name in self._lexicons:
# Regenerated all the stats from the XML
# but keeps the ARN
self._lexicons.update(content)
else:
lexicon = Lexicon(name, content, region_name=self.region_name)
self._lexicons[name] = lexicon
available_regions = boto3.session.Session().get_available_regions("polly")
polly_backends = {region: PollyBackend(region_name=region) for region in available_regions}

63
moto/polly/resources.py Normal file
View file

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
VOICE_DATA = [
{'Id': 'Joanna', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Joanna'},
{'Id': 'Mizuki', 'LanguageCode': 'ja-JP', 'LanguageName': 'Japanese', 'Gender': 'Female', 'Name': 'Mizuki'},
{'Id': 'Filiz', 'LanguageCode': 'tr-TR', 'LanguageName': 'Turkish', 'Gender': 'Female', 'Name': 'Filiz'},
{'Id': 'Astrid', 'LanguageCode': 'sv-SE', 'LanguageName': 'Swedish', 'Gender': 'Female', 'Name': 'Astrid'},
{'Id': 'Tatyana', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Female', 'Name': 'Tatyana'},
{'Id': 'Maxim', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Male', 'Name': 'Maxim'},
{'Id': 'Carmen', 'LanguageCode': 'ro-RO', 'LanguageName': 'Romanian', 'Gender': 'Female', 'Name': 'Carmen'},
{'Id': 'Ines', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Female', 'Name': 'Inês'},
{'Id': 'Cristiano', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Male', 'Name': 'Cristiano'},
{'Id': 'Vitoria', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Female', 'Name': 'Vitória'},
{'Id': 'Ricardo', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Male', 'Name': 'Ricardo'},
{'Id': 'Maja', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Maja'},
{'Id': 'Jan', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jan'},
{'Id': 'Ewa', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Ewa'},
{'Id': 'Ruben', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Male', 'Name': 'Ruben'},
{'Id': 'Lotte', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Female', 'Name': 'Lotte'},
{'Id': 'Liv', 'LanguageCode': 'nb-NO', 'LanguageName': 'Norwegian', 'Gender': 'Female', 'Name': 'Liv'},
{'Id': 'Giorgio', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Male', 'Name': 'Giorgio'},
{'Id': 'Carla', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Female', 'Name': 'Carla'},
{'Id': 'Karl', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Male', 'Name': 'Karl'},
{'Id': 'Dora', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Female', 'Name': 'Dóra'},
{'Id': 'Mathieu', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Male', 'Name': 'Mathieu'},
{'Id': 'Celine', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Female', 'Name': 'Céline'},
{'Id': 'Chantal', 'LanguageCode': 'fr-CA', 'LanguageName': 'Canadian French', 'Gender': 'Female', 'Name': 'Chantal'},
{'Id': 'Penelope', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Female', 'Name': 'Penélope'},
{'Id': 'Miguel', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Male', 'Name': 'Miguel'},
{'Id': 'Enrique', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Male', 'Name': 'Enrique'},
{'Id': 'Conchita', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Female', 'Name': 'Conchita'},
{'Id': 'Geraint', 'LanguageCode': 'en-GB-WLS', 'LanguageName': 'Welsh English', 'Gender': 'Male', 'Name': 'Geraint'},
{'Id': 'Salli', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Salli'},
{'Id': 'Kimberly', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kimberly'},
{'Id': 'Kendra', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kendra'},
{'Id': 'Justin', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Justin'},
{'Id': 'Joey', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Joey'},
{'Id': 'Ivy', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Ivy'},
{'Id': 'Raveena', 'LanguageCode': 'en-IN', 'LanguageName': 'Indian English', 'Gender': 'Female', 'Name': 'Raveena'},
{'Id': 'Emma', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Emma'},
{'Id': 'Brian', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Male', 'Name': 'Brian'},
{'Id': 'Amy', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Amy'},
{'Id': 'Russell', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Male', 'Name': 'Russell'},
{'Id': 'Nicole', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Female', 'Name': 'Nicole'},
{'Id': 'Vicki', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Vicki'},
{'Id': 'Marlene', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Marlene'},
{'Id': 'Hans', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Male', 'Name': 'Hans'},
{'Id': 'Naja', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Female', 'Name': 'Naja'},
{'Id': 'Mads', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Male', 'Name': 'Mads'},
{'Id': 'Gwyneth', 'LanguageCode': 'cy-GB', 'LanguageName': 'Welsh', 'Gender': 'Female', 'Name': 'Gwyneth'},
{'Id': 'Jacek', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jacek'}
]
# {...} is also shorthand set syntax
LANGUAGE_CODES = {'cy-GB', 'da-DK', 'de-DE', 'en-AU', 'en-GB', 'en-GB-WLS', 'en-IN', 'en-US', 'es-ES', 'es-US',
'fr-CA', 'fr-FR', 'is-IS', 'it-IT', 'ja-JP', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-BR', 'pt-PT', 'ro-RO',
'ru-RU', 'sv-SE', 'tr-TR'}
VOICE_IDS = {'Geraint', 'Gwyneth', 'Mads', 'Naja', 'Hans', 'Marlene', 'Nicole', 'Russell', 'Amy', 'Brian', 'Emma',
'Raveena', 'Ivy', 'Joanna', 'Joey', 'Justin', 'Kendra', 'Kimberly', 'Salli', 'Conchita', 'Enrique',
'Miguel', 'Penelope', 'Chantal', 'Celine', 'Mathieu', 'Dora', 'Karl', 'Carla', 'Giorgio', 'Mizuki',
'Liv', 'Lotte', 'Ruben', 'Ewa', 'Jacek', 'Jan', 'Maja', 'Ricardo', 'Vitoria', 'Cristiano', 'Ines',
'Carmen', 'Maxim', 'Tatyana', 'Astrid', 'Filiz'}

188
moto/polly/responses.py Normal file
View file

@ -0,0 +1,188 @@
from __future__ import unicode_literals
import json
import re
from six.moves.urllib.parse import urlsplit
from moto.core.responses import BaseResponse
from .models import polly_backends
from .resources import LANGUAGE_CODES, VOICE_IDS
LEXICON_NAME_REGEX = re.compile(r'^[0-9A-Za-z]{1,20}$')
class PollyResponse(BaseResponse):
@property
def polly_backend(self):
return polly_backends[self.region]
@property
def json(self):
if not hasattr(self, '_json'):
self._json = json.loads(self.body)
return self._json
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
def _get_action(self):
# Amazon is now naming things /v1/api_name
url_parts = urlsplit(self.uri).path.lstrip('/').split('/')
# [0] = 'v1'
return url_parts[1]
# DescribeVoices
def voices(self):
language_code = self._get_param('LanguageCode')
next_token = self._get_param('NextToken')
if language_code is not None and language_code not in LANGUAGE_CODES:
msg = "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " \
"Member must satisfy enum value set: [{1}]".format(language_code, ', '.join(LANGUAGE_CODES))
return msg, dict(status=400)
voices = self.polly_backend.describe_voices(language_code, next_token)
return json.dumps({'Voices': voices})
def lexicons(self):
# Dish out requests based on methods
# anything after the /v1/lexicons/
args = urlsplit(self.uri).path.lstrip('/').split('/')[2:]
if self.method == 'GET':
if len(args) == 0:
return self._get_lexicons_list()
else:
return self._get_lexicon(*args)
elif self.method == 'PUT':
return self._put_lexicons(*args)
elif self.method == 'DELETE':
return self._delete_lexicon(*args)
return self._error('InvalidAction', 'Bad route')
# PutLexicon
def _put_lexicons(self, lexicon_name):
if LEXICON_NAME_REGEX.match(lexicon_name) is None:
return self._error('InvalidParameterValue', 'Lexicon name must match [0-9A-Za-z]{1,20}')
if 'Content' not in self.json:
return self._error('MissingParameter', 'Content is missing from the body')
self.polly_backend.put_lexicon(lexicon_name, self.json['Content'])
return ''
# ListLexicons
def _get_lexicons_list(self):
next_token = self._get_param('NextToken')
result = {
'Lexicons': self.polly_backend.list_lexicons(next_token)
}
return json.dumps(result)
# GetLexicon
def _get_lexicon(self, lexicon_name):
try:
lexicon = self.polly_backend.get_lexicon(lexicon_name)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
result = {
'Lexicon': {
'Name': lexicon_name,
'Content': lexicon.content
},
'LexiconAttributes': lexicon.to_dict()['Attributes']
}
return json.dumps(result)
# DeleteLexicon
def _delete_lexicon(self, lexicon_name):
try:
self.polly_backend.delete_lexicon(lexicon_name)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
return ''
# SynthesizeSpeech
def speech(self):
# Sanity check params
args = {
'lexicon_names': None,
'sample_rate': 22050,
'speech_marks': None,
'text': None,
'text_type': 'text'
}
if 'LexiconNames' in self.json:
for lex in self.json['LexiconNames']:
try:
self.polly_backend.get_lexicon(lex)
except KeyError:
return self._error('LexiconNotFoundException', 'Lexicon not found')
args['lexicon_names'] = self.json['LexiconNames']
if 'OutputFormat' not in self.json:
return self._error('MissingParameter', 'Missing parameter OutputFormat')
if self.json['OutputFormat'] not in ('json', 'mp3', 'ogg_vorbis', 'pcm'):
return self._error('InvalidParameterValue', 'Not one of json, mp3, ogg_vorbis, pcm')
args['output_format'] = self.json['OutputFormat']
if 'SampleRate' in self.json:
sample_rate = int(self.json['SampleRate'])
if sample_rate not in (8000, 16000, 22050):
return self._error('InvalidSampleRateException', 'The specified sample rate is not valid.')
args['sample_rate'] = sample_rate
if 'SpeechMarkTypes' in self.json:
for value in self.json['SpeechMarkTypes']:
if value not in ('sentance', 'ssml', 'viseme', 'word'):
return self._error('InvalidParameterValue', 'Not one of sentance, ssml, viseme, word')
args['speech_marks'] = self.json['SpeechMarkTypes']
if 'Text' not in self.json:
return self._error('MissingParameter', 'Missing parameter Text')
args['text'] = self.json['Text']
if 'TextType' in self.json:
if self.json['TextType'] not in ('ssml', 'text'):
return self._error('InvalidParameterValue', 'Not one of ssml, text')
args['text_type'] = self.json['TextType']
if 'VoiceId' not in self.json:
return self._error('MissingParameter', 'Missing parameter VoiceId')
if self.json['VoiceId'] not in VOICE_IDS:
return self._error('InvalidParameterValue', 'Not one of {0}'.format(', '.join(VOICE_IDS)))
args['voice_id'] = self.json['VoiceId']
# More validation
if len(args['text']) > 3000:
return self._error('TextLengthExceededException', 'Text too long')
if args['speech_marks'] is not None and args['output_format'] != 'json':
return self._error('MarksNotSupportedForFormatException', 'OutputFormat must be json')
if args['speech_marks'] is not None and args['text_type'] == 'text':
return self._error('SsmlMarksNotSupportedForTextTypeException', 'TextType must be ssml')
content_type = 'audio/json'
if args['output_format'] == 'mp3':
content_type = 'audio/mpeg'
elif args['output_format'] == 'ogg_vorbis':
content_type = 'audio/ogg'
elif args['output_format'] == 'pcm':
content_type = 'audio/pcm'
headers = {'Content-Type': content_type}
return '\x00\x00\x00\x00\x00\x00\x00\x00', headers

13
moto/polly/urls.py Normal file
View file

@ -0,0 +1,13 @@
from __future__ import unicode_literals
from .responses import PollyResponse
url_bases = [
"https?://polly.(.+).amazonaws.com",
]
url_paths = {
'{0}/v1/voices': PollyResponse.dispatch,
'{0}/v1/lexicons/(?P<lexicon>[^/]+)': PollyResponse.dispatch,
'{0}/v1/lexicons': PollyResponse.dispatch,
'{0}/v1/speech': PollyResponse.dispatch,
}

5
moto/polly/utils.py Normal file
View file

@ -0,0 +1,5 @@
from __future__ import unicode_literals
def make_arn_for_lexicon(account_id, name, region_name):
return "arn:aws:polly:{0}:{1}:lexicon/{2}".format(region_name, account_id, name)

View file

@ -71,3 +71,25 @@ class ClusterSnapshotAlreadyExistsError(RedshiftClientError):
'ClusterSnapshotAlreadyExists',
"Cannot create the snapshot because a snapshot with the "
"identifier {0} already exists".format(snapshot_identifier))
class InvalidParameterValueError(RedshiftClientError):
def __init__(self, message):
super(InvalidParameterValueError, self).__init__(
'InvalidParameterValue',
message)
class ResourceNotFoundFaultError(RedshiftClientError):
code = 404
def __init__(self, resource_type=None, resource_name=None, message=None):
if resource_type and not resource_name:
msg = "resource of type '{0}' not found.".format(resource_type)
else:
msg = "{0} ({1}) not found.".format(resource_type, resource_name)
if message:
msg = message
super(ResourceNotFoundFaultError, self).__init__(
'ResourceNotFoundFault', msg)

View file

@ -15,11 +15,51 @@ from .exceptions import (
ClusterSnapshotAlreadyExistsError,
ClusterSnapshotNotFoundError,
ClusterSubnetGroupNotFoundError,
InvalidParameterValueError,
InvalidSubnetError,
ResourceNotFoundFaultError
)
class Cluster(BaseModel):
ACCOUNT_ID = 123456789012
class TaggableResourceMixin(object):
resource_type = None
def __init__(self, region_name, tags):
self.region = region_name
self.tags = tags or []
@property
def resource_id(self):
return None
@property
def arn(self):
return "arn:aws:redshift:{region}:{account_id}:{resource_type}:{resource_id}".format(
region=self.region,
account_id=ACCOUNT_ID,
resource_type=self.resource_type,
resource_id=self.resource_id)
def create_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in new_keys]
self.tags.extend(tags)
return self.tags
def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in tag_keys]
return self.tags
class Cluster(TaggableResourceMixin, BaseModel):
resource_type = 'cluster'
def __init__(self, redshift_backend, cluster_identifier, node_type, master_username,
master_user_password, db_name, cluster_type, cluster_security_groups,
@ -27,7 +67,8 @@ class Cluster(BaseModel):
preferred_maintenance_window, cluster_parameter_group_name,
automated_snapshot_retention_period, port, cluster_version,
allow_version_upgrade, number_of_nodes, publicly_accessible,
encrypted, region):
encrypted, region_name, tags=None):
super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
self.status = 'available'
@ -57,13 +98,12 @@ class Cluster(BaseModel):
else:
self.cluster_security_groups = ["Default"]
self.region = region
if availability_zone:
self.availability_zone = availability_zone
else:
# This could probably be smarter, but there doesn't appear to be a
# way to pull AZs for a region in boto
self.availability_zone = region + "a"
self.availability_zone = region_name + "a"
if cluster_type == 'single-node':
self.number_of_nodes = 1
@ -106,7 +146,7 @@ class Cluster(BaseModel):
number_of_nodes=properties.get('NumberOfNodes'),
publicly_accessible=properties.get("PubliclyAccessible"),
encrypted=properties.get("Encrypted"),
region=region_name,
region_name=region_name,
)
return cluster
@ -149,6 +189,10 @@ class Cluster(BaseModel):
if parameter_group.cluster_parameter_group_name in self.cluster_parameter_group_name
]
@property
def resource_id(self):
return self.cluster_identifier
def to_json(self):
return {
"MasterUsername": self.master_username,
@ -180,18 +224,21 @@ class Cluster(BaseModel):
"ClusterIdentifier": self.cluster_identifier,
"AllowVersionUpgrade": self.allow_version_upgrade,
"Endpoint": {
"Address": '{}.{}.redshift.amazonaws.com'.format(
self.cluster_identifier,
self.region),
"Address": self.endpoint,
"Port": self.port
},
"PendingModifiedValues": []
"PendingModifiedValues": [],
"Tags": self.tags
}
class SubnetGroup(BaseModel):
class SubnetGroup(TaggableResourceMixin, BaseModel):
def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids):
resource_type = 'subnetgroup'
def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids,
region_name, tags=None):
super(SubnetGroup, self).__init__(region_name, tags)
self.ec2_backend = ec2_backend
self.cluster_subnet_group_name = cluster_subnet_group_name
self.description = description
@ -208,6 +255,7 @@ class SubnetGroup(BaseModel):
cluster_subnet_group_name=resource_name,
description=properties.get("Description"),
subnet_ids=properties.get("SubnetIds", []),
region_name=region_name
)
return subnet_group
@ -219,6 +267,10 @@ class SubnetGroup(BaseModel):
def vpc_id(self):
return self.subnets[0].vpc_id
@property
def resource_id(self):
return self.cluster_subnet_group_name
def to_json(self):
return {
"VpcId": self.vpc_id,
@ -232,27 +284,39 @@ class SubnetGroup(BaseModel):
"Name": subnet.availability_zone
},
} for subnet in self.subnets],
"Tags": self.tags
}
class SecurityGroup(BaseModel):
class SecurityGroup(TaggableResourceMixin, BaseModel):
def __init__(self, cluster_security_group_name, description):
resource_type = 'securitygroup'
def __init__(self, cluster_security_group_name, description, region_name, tags=None):
super(SecurityGroup, self).__init__(region_name, tags)
self.cluster_security_group_name = cluster_security_group_name
self.description = description
@property
def resource_id(self):
return self.cluster_security_group_name
def to_json(self):
return {
"EC2SecurityGroups": [],
"IPRanges": [],
"Description": self.description,
"ClusterSecurityGroupName": self.cluster_security_group_name,
"Tags": self.tags
}
class ParameterGroup(BaseModel):
class ParameterGroup(TaggableResourceMixin, BaseModel):
def __init__(self, cluster_parameter_group_name, group_family, description):
resource_type = 'parametergroup'
def __init__(self, cluster_parameter_group_name, group_family, description, region_name, tags=None):
super(ParameterGroup, self).__init__(region_name, tags)
self.cluster_parameter_group_name = cluster_parameter_group_name
self.group_family = group_family
self.description = description
@ -266,34 +330,41 @@ class ParameterGroup(BaseModel):
cluster_parameter_group_name=resource_name,
description=properties.get("Description"),
group_family=properties.get("ParameterGroupFamily"),
region_name=region_name
)
return parameter_group
@property
def resource_id(self):
return self.cluster_parameter_group_name
def to_json(self):
return {
"ParameterGroupFamily": self.group_family,
"Description": self.description,
"ParameterGroupName": self.cluster_parameter_group_name,
"Tags": self.tags
}
class Snapshot(BaseModel):
class Snapshot(TaggableResourceMixin, BaseModel):
def __init__(self, cluster, snapshot_identifier, tags=None):
resource_type = 'snapshot'
def __init__(self, cluster, snapshot_identifier, region_name, tags=None):
super(Snapshot, self).__init__(region_name, tags)
self.cluster = copy.copy(cluster)
self.snapshot_identifier = snapshot_identifier
self.snapshot_type = 'manual'
self.status = 'available'
self.tags = tags or []
self.create_time = iso_8601_datetime_with_milliseconds(
datetime.datetime.now())
@property
def arn(self):
return "arn:aws:redshift:{0}:1234567890:snapshot:{1}/{2}".format(
self.cluster.region,
self.cluster.cluster_identifier,
self.snapshot_identifier)
def resource_id(self):
return "{cluster_id}/{snapshot_id}".format(
cluster_id=self.cluster.cluster_identifier,
snapshot_id=self.snapshot_identifier)
def to_json(self):
return {
@ -315,26 +386,36 @@ class Snapshot(BaseModel):
class RedshiftBackend(BaseBackend):
def __init__(self, ec2_backend):
def __init__(self, ec2_backend, region_name):
self.region = region_name
self.clusters = {}
self.subnet_groups = {}
self.security_groups = {
"Default": SecurityGroup("Default", "Default Redshift Security Group")
"Default": SecurityGroup("Default", "Default Redshift Security Group", self.region)
}
self.parameter_groups = {
"default.redshift-1.0": ParameterGroup(
"default.redshift-1.0",
"redshift-1.0",
"Default Redshift parameter group",
self.region
)
}
self.ec2_backend = ec2_backend
self.snapshots = OrderedDict()
self.RESOURCE_TYPE_MAP = {
'cluster': self.clusters,
'parametergroup': self.parameter_groups,
'securitygroup': self.security_groups,
'snapshot': self.snapshots,
'subnetgroup': self.subnet_groups
}
def reset(self):
ec2_backend = self.ec2_backend
region_name = self.region
self.__dict__ = {}
self.__init__(ec2_backend)
self.__init__(ec2_backend, region_name)
def create_cluster(self, **cluster_kwargs):
cluster_identifier = cluster_kwargs['cluster_identifier']
@ -373,9 +454,10 @@ class RedshiftBackend(BaseBackend):
return self.clusters.pop(cluster_identifier)
raise ClusterNotFoundError(cluster_identifier)
def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids):
def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids,
region_name, tags=None):
subnet_group = SubnetGroup(
self.ec2_backend, cluster_subnet_group_name, description, subnet_ids)
self.ec2_backend, cluster_subnet_group_name, description, subnet_ids, region_name, tags)
self.subnet_groups[cluster_subnet_group_name] = subnet_group
return subnet_group
@ -393,9 +475,9 @@ class RedshiftBackend(BaseBackend):
return self.subnet_groups.pop(subnet_identifier)
raise ClusterSubnetGroupNotFoundError(subnet_identifier)
def create_cluster_security_group(self, cluster_security_group_name, description):
def create_cluster_security_group(self, cluster_security_group_name, description, region_name, tags=None):
security_group = SecurityGroup(
cluster_security_group_name, description)
cluster_security_group_name, description, region_name, tags)
self.security_groups[cluster_security_group_name] = security_group
return security_group
@ -414,9 +496,9 @@ class RedshiftBackend(BaseBackend):
raise ClusterSecurityGroupNotFoundError(security_group_identifier)
def create_cluster_parameter_group(self, cluster_parameter_group_name,
group_family, description):
group_family, description, region_name, tags=None):
parameter_group = ParameterGroup(
cluster_parameter_group_name, group_family, description)
cluster_parameter_group_name, group_family, description, region_name, tags)
self.parameter_groups[cluster_parameter_group_name] = parameter_group
return parameter_group
@ -435,17 +517,17 @@ class RedshiftBackend(BaseBackend):
return self.parameter_groups.pop(parameter_group_name)
raise ClusterParameterGroupNotFoundError(parameter_group_name)
def create_snapshot(self, cluster_identifier, snapshot_identifier, tags):
def create_cluster_snapshot(self, cluster_identifier, snapshot_identifier, region_name, tags):
cluster = self.clusters.get(cluster_identifier)
if not cluster:
raise ClusterNotFoundError(cluster_identifier)
if self.snapshots.get(snapshot_identifier) is not None:
raise ClusterSnapshotAlreadyExistsError(snapshot_identifier)
snapshot = Snapshot(cluster, snapshot_identifier, tags)
snapshot = Snapshot(cluster, snapshot_identifier, region_name, tags)
self.snapshots[snapshot_identifier] = snapshot
return snapshot
def describe_snapshots(self, cluster_identifier, snapshot_identifier):
def describe_cluster_snapshots(self, cluster_identifier=None, snapshot_identifier=None):
if cluster_identifier:
for snapshot in self.snapshots.values():
if snapshot.cluster.cluster_identifier == cluster_identifier:
@ -459,7 +541,7 @@ class RedshiftBackend(BaseBackend):
return self.snapshots.values()
def delete_snapshot(self, snapshot_identifier):
def delete_cluster_snapshot(self, snapshot_identifier):
if snapshot_identifier not in self.snapshots:
raise ClusterSnapshotNotFoundError(snapshot_identifier)
@ -467,23 +549,105 @@ class RedshiftBackend(BaseBackend):
deleted_snapshot.status = 'deleted'
return deleted_snapshot
def describe_tags_for_resource_type(self, resource_type):
def restore_from_cluster_snapshot(self, **kwargs):
snapshot_identifier = kwargs.pop('snapshot_identifier')
snapshot = self.describe_cluster_snapshots(snapshot_identifier=snapshot_identifier)[0]
create_kwargs = {
"node_type": snapshot.cluster.node_type,
"master_username": snapshot.cluster.master_username,
"master_user_password": snapshot.cluster.master_user_password,
"db_name": snapshot.cluster.db_name,
"cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node',
"availability_zone": snapshot.cluster.availability_zone,
"port": snapshot.cluster.port,
"cluster_version": snapshot.cluster.cluster_version,
"number_of_nodes": snapshot.cluster.number_of_nodes,
"encrypted": snapshot.cluster.encrypted,
"tags": snapshot.cluster.tags
}
create_kwargs.update(kwargs)
return self.create_cluster(**create_kwargs)
def _get_resource_from_arn(self, arn):
try:
arn_breakdown = arn.split(':')
resource_type = arn_breakdown[5]
if resource_type == 'snapshot':
resource_id = arn_breakdown[6].split('/')[1]
else:
resource_id = arn_breakdown[6]
except IndexError:
resource_type = resource_id = arn
resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if resources is None:
message = (
"Tagging is not supported for this type of resource: '{0}' "
"(the ARN is potentially malformed, please check the ARN "
"documentation for more information)".format(resource_type))
raise ResourceNotFoundFaultError(message=message)
try:
resource = resources[resource_id]
except KeyError:
raise ResourceNotFoundFaultError(resource_type, resource_id)
else:
return resource
@staticmethod
def _describe_tags_for_resources(resources):
tagged_resources = []
if resource_type == 'Snapshot':
for snapshot in self.snapshots.values():
for tag in snapshot.tags:
data = {
'ResourceName': snapshot.arn,
'ResourceType': 'snapshot',
'Tag': {
'Key': tag['Key'],
'Value': tag['Value']
}
for resource in resources:
for tag in resource.tags:
data = {
'ResourceName': resource.arn,
'ResourceType': resource.resource_type,
'Tag': {
'Key': tag['Key'],
'Value': tag['Value']
}
tagged_resources.append(data)
}
tagged_resources.append(data)
return tagged_resources
def _describe_tags_for_resource_type(self, resource_type):
resources = self.RESOURCE_TYPE_MAP.get(resource_type)
if not resources:
raise ResourceNotFoundFaultError(resource_type=resource_type)
return self._describe_tags_for_resources(resources.values())
def _describe_tags_for_resource_name(self, resource_name):
resource = self._get_resource_from_arn(resource_name)
return self._describe_tags_for_resources([resource])
def create_tags(self, resource_name, tags):
resource = self._get_resource_from_arn(resource_name)
resource.create_tags(tags)
def describe_tags(self, resource_name, resource_type):
if resource_name and resource_type:
raise InvalidParameterValueError(
"You cannot filter a list of resources using an Amazon "
"Resource Name (ARN) and a resource type together in the "
"same request. Retry the request using either an ARN or "
"a resource type, but not both.")
if resource_type:
return self._describe_tags_for_resource_type(resource_type.lower())
if resource_name:
return self._describe_tags_for_resource_name(resource_name)
# If name and type are not specified, return all tagged resources.
# TODO: Implement aws marker pagination
tagged_resources = []
for resource_type in self.RESOURCE_TYPE_MAP:
try:
tagged_resources += self._describe_tags_for_resource_type(resource_type)
except ResourceNotFoundFaultError:
pass
return tagged_resources
def delete_tags(self, resource_name, tag_keys):
resource = self._get_resource_from_arn(resource_name)
resource.delete_tags(tag_keys)
redshift_backends = {}
for region in boto.redshift.regions():
redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name])
redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name], region.name)

View file

@ -57,6 +57,15 @@ class RedshiftResponse(BaseResponse):
count += 1
return unpacked_list
def unpack_list_params(self, label):
unpacked_list = list()
count = 1
while self._get_param('{0}.{1}'.format(label, count)):
unpacked_list.append(self._get_param(
'{0}.{1}'.format(label, count)))
count += 1
return unpacked_list
def create_cluster(self):
cluster_kwargs = {
"cluster_identifier": self._get_param('ClusterIdentifier'),
@ -78,7 +87,8 @@ class RedshiftResponse(BaseResponse):
"number_of_nodes": self._get_int_param('NumberOfNodes'),
"publicly_accessible": self._get_param("PubliclyAccessible"),
"encrypted": self._get_param("Encrypted"),
"region": self.region,
"region_name": self.region,
"tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
}
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
@ -94,23 +104,8 @@ class RedshiftResponse(BaseResponse):
})
def restore_from_cluster_snapshot(self):
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshots = self.redshift_backend.describe_snapshots(
None,
snapshot_identifier)
snapshot = snapshots[0]
kwargs_from_snapshot = {
"node_type": snapshot.cluster.node_type,
"master_username": snapshot.cluster.master_username,
"master_user_password": snapshot.cluster.master_user_password,
"db_name": snapshot.cluster.db_name,
"cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node',
"availability_zone": snapshot.cluster.availability_zone,
"port": snapshot.cluster.port,
"cluster_version": snapshot.cluster.cluster_version,
"number_of_nodes": snapshot.cluster.number_of_nodes,
}
kwargs_from_request = {
restore_kwargs = {
"snapshot_identifier": self._get_param('SnapshotIdentifier'),
"cluster_identifier": self._get_param('ClusterIdentifier'),
"port": self._get_int_param('Port'),
"availability_zone": self._get_param('AvailabilityZone'),
@ -129,12 +124,9 @@ class RedshiftResponse(BaseResponse):
'PreferredMaintenanceWindow'),
"automated_snapshot_retention_period": self._get_int_param(
'AutomatedSnapshotRetentionPeriod'),
"region": self.region,
"encrypted": False,
"region_name": self.region,
}
kwargs_from_snapshot.update(kwargs_from_request)
cluster_kwargs = kwargs_from_snapshot
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
return self.get_response({
"RestoreFromClusterSnapshotResponse": {
@ -230,11 +222,14 @@ class RedshiftResponse(BaseResponse):
# according to the AWS documentation
if not subnet_ids:
subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
subnet_group = self.redshift_backend.create_cluster_subnet_group(
cluster_subnet_group_name=cluster_subnet_group_name,
description=description,
subnet_ids=subnet_ids,
region_name=self.region,
tags=tags
)
return self.get_response({
@ -280,10 +275,13 @@ class RedshiftResponse(BaseResponse):
cluster_security_group_name = self._get_param(
'ClusterSecurityGroupName')
description = self._get_param('Description')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
security_group = self.redshift_backend.create_cluster_security_group(
cluster_security_group_name=cluster_security_group_name,
description=description,
region_name=self.region,
tags=tags
)
return self.get_response({
@ -331,11 +329,14 @@ class RedshiftResponse(BaseResponse):
cluster_parameter_group_name = self._get_param('ParameterGroupName')
group_family = self._get_param('ParameterGroupFamily')
description = self._get_param('Description')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
parameter_group = self.redshift_backend.create_cluster_parameter_group(
cluster_parameter_group_name,
group_family,
description,
self.region,
tags
)
return self.get_response({
@ -381,11 +382,12 @@ class RedshiftResponse(BaseResponse):
def create_cluster_snapshot(self):
cluster_identifier = self._get_param('ClusterIdentifier')
snapshot_identifier = self._get_param('SnapshotIdentifier')
tags = self.unpack_complex_list_params(
'Tags.Tag', ('Key', 'Value'))
snapshot = self.redshift_backend.create_snapshot(cluster_identifier,
snapshot_identifier,
tags)
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
snapshot = self.redshift_backend.create_cluster_snapshot(cluster_identifier,
snapshot_identifier,
self.region,
tags)
return self.get_response({
'CreateClusterSnapshotResponse': {
"CreateClusterSnapshotResult": {
@ -399,9 +401,9 @@ class RedshiftResponse(BaseResponse):
def describe_cluster_snapshots(self):
cluster_identifier = self._get_param('ClusterIdentifier')
snapshot_identifier = self._get_param('DBSnapshotIdentifier')
snapshots = self.redshift_backend.describe_snapshots(cluster_identifier,
snapshot_identifier)
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshots = self.redshift_backend.describe_cluster_snapshots(cluster_identifier,
snapshot_identifier)
return self.get_response({
"DescribeClusterSnapshotsResponse": {
"DescribeClusterSnapshotsResult": {
@ -415,7 +417,7 @@ class RedshiftResponse(BaseResponse):
def delete_cluster_snapshot(self):
snapshot_identifier = self._get_param('SnapshotIdentifier')
snapshot = self.redshift_backend.delete_snapshot(snapshot_identifier)
snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier)
return self.get_response({
"DeleteClusterSnapshotResponse": {
@ -428,13 +430,26 @@ class RedshiftResponse(BaseResponse):
}
})
def create_tags(self):
resource_name = self._get_param('ResourceName')
tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value'))
self.redshift_backend.create_tags(resource_name, tags)
return self.get_response({
"CreateTagsResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a",
}
}
})
def describe_tags(self):
resource_name = self._get_param('ResourceName')
resource_type = self._get_param('ResourceType')
if resource_type != 'Snapshot':
raise NotImplementedError(
"The describe_tags action has not been fully implemented.")
tagged_resources = \
self.redshift_backend.describe_tags_for_resource_type(resource_type)
tagged_resources = self.redshift_backend.describe_tags(resource_name,
resource_type)
return self.get_response({
"DescribeTagsResponse": {
"DescribeTagsResult": {
@ -445,3 +460,17 @@ class RedshiftResponse(BaseResponse):
}
}
})
def delete_tags(self):
resource_name = self._get_param('ResourceName')
tag_keys = self.unpack_list_params('TagKeys.TagKey')
self.redshift_backend.delete_tags(resource_name, tag_keys)
return self.get_response({
"DeleteTagsResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a",
}
}
})

View file

@ -4,7 +4,7 @@ from .responses import S3ResponseInstance
url_bases = [
"https?://s3(.*).amazonaws.com",
"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
r"https?://(?P<bucket_name>[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com"
]

View file

@ -1,22 +1,23 @@
from __future__ import unicode_literals
import argparse
import json
import re
import sys
import argparse
import six
from six.moves.urllib.parse import urlencode
from threading import Lock
import six
from flask import Flask
from flask.testing import FlaskClient
from six.moves.urllib.parse import urlencode
from werkzeug.routing import BaseConverter
from werkzeug.serving import run_simple
from moto.backends import BACKENDS
from moto.core.utils import convert_flask_to_httpretty_response
HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"]
@ -61,7 +62,7 @@ class DomainDispatcherApplication(object):
host = "instance_metadata"
else:
host = environ['HTTP_HOST'].split(':')[0]
if host == "localhost":
if host in {'localhost', 'motoserver'} or host.startswith("192.168."):
# Fall back to parsing auth header to find service
# ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request']
try:

View file

@ -12,6 +12,8 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.sqs import sqs_backends
from moto.awslambda import lambda_backends
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter
)
@ -88,6 +90,11 @@ class Subscription(BaseModel):
elif self.protocol in ['http', 'https']:
post_data = self.get_post_data(message, message_id)
requests.post(self.endpoint, json=post_data)
elif self.protocol == 'lambda':
# TODO: support bad function name
function_name = self.endpoint.split(":")[-1]
region = self.arn.split(':')[3]
lambda_backends[region].send_message(function_name, message)
def get_post_data(self, message, message_id):
return {
@ -221,6 +228,12 @@ class SNSBackend(BaseBackend):
except KeyError:
raise SNSNotFoundError("Topic with arn {0} not found".format(arn))
def get_topic_from_phone_number(self, number):
for subscription in self.subscriptions.values():
if subscription.protocol == 'sms' and subscription.endpoint == number:
return subscription.topic.arn
raise SNSNotFoundError('Could not find valid subscription')
def set_topic_attribute(self, topic_arn, attribute_name, attribute_value):
topic = self.get_topic(topic_arn)
setattr(topic, attribute_name, attribute_value)

View file

@ -6,6 +6,8 @@ from collections import defaultdict
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from .models import sns_backends
from .exceptions import SNSNotFoundError
from .utils import is_e164
class SNSResponse(BaseResponse):
@ -136,6 +138,13 @@ class SNSResponse(BaseResponse):
topic_arn = self._get_param('TopicArn')
endpoint = self._get_param('Endpoint')
protocol = self._get_param('Protocol')
if protocol == 'sms' and not is_e164(endpoint):
return self._error(
'InvalidParameter',
'Phone number does not meet the E164 format'
), dict(status=400)
subscription = self.backend.subscribe(topic_arn, endpoint, protocol)
if self.request_json:
@ -229,7 +238,28 @@ class SNSResponse(BaseResponse):
def publish(self):
target_arn = self._get_param('TargetArn')
topic_arn = self._get_param('TopicArn')
arn = target_arn if target_arn else topic_arn
phone_number = self._get_param('PhoneNumber')
if phone_number is not None:
# Check phone is correct syntax (e164)
if not is_e164(phone_number):
return self._error(
'InvalidParameter',
'Phone number does not meet the E164 format'
), dict(status=400)
# Look up topic arn by phone number
try:
arn = self.backend.get_topic_from_phone_number(phone_number)
except SNSNotFoundError:
return self._error(
'ParameterValueInvalid',
'Could not find topic associated with phone number'
), dict(status=400)
elif target_arn is not None:
arn = target_arn
else:
arn = topic_arn
message = self._get_param('Message')
message_id = self.backend.publish(arn, message)

View file

@ -1,6 +1,9 @@
from __future__ import unicode_literals
import re
import uuid
E164_REGEX = re.compile(r'^\+?[1-9]\d{1,14}$')
def make_arn_for_topic(account_id, name, region_name):
return "arn:aws:sns:{0}:{1}:{2}".format(region_name, account_id, name)
@ -9,3 +12,7 @@ def make_arn_for_topic(account_id, name, region_name):
def make_arn_for_subscription(topic_arn):
subscription_id = uuid.uuid4()
return "{0}:{1}".format(topic_arn, subscription_id)
def is_e164(number):
return E164_REGEX.match(number) is not None