Lambda improvements (#1344)

* Revamped the lambda function storage to do versioning.

* Flake8

* .

* Fixes

* Swapped around an if
This commit is contained in:
Terry Cain 2017-11-26 21:28:28 +00:00 committed by GitHub
commit d5ee48eedd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 326 additions and 125 deletions

View file

@ -2,6 +2,7 @@ from __future__ import unicode_literals
import base64
from collections import defaultdict
import copy
import datetime
import docker.errors
import hashlib
@ -17,18 +18,23 @@ import tarfile
import calendar
import threading
import traceback
import weakref
import requests.adapters
import boto.awslambda
from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import RESTError
from moto.core.utils import unix_time_millis
from moto.s3.models import s3_backend
from moto.logs.models import logs_backends
from moto.s3.exceptions import MissingBucket, MissingKey
from moto import settings
from .utils import make_function_arn
logger = logging.getLogger(__name__)
ACCOUNT_ID = '123456789012'
try:
from tempfile import TemporaryDirectory
@ -121,7 +127,7 @@ class _DockerDataVolumeContext:
class LambdaFunction(BaseModel):
def __init__(self, spec, region, validate_s3=True):
def __init__(self, spec, region, validate_s3=True, version=1):
# required
self.region = region
self.code = spec['Code']
@ -161,7 +167,7 @@ class LambdaFunction(BaseModel):
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
# auto-generated
self.version = '$LATEST'
self.version = version
self.last_modified = datetime.datetime.utcnow().strftime(
'%Y-%m-%d %H:%M:%S')
@ -203,11 +209,15 @@ class LambdaFunction(BaseModel):
self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = 'arn:aws:lambda:{}:123456789012:function:{}'.format(
self.region, self.function_name)
self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name, version)
self.tags = dict()
def set_version(self, version):
self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name, version)
self.version = version
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
@property
def vpc_config(self):
config = self._vpc_config.copy()
@ -231,7 +241,7 @@ class LambdaFunction(BaseModel):
"Role": self.role,
"Runtime": self.run_time,
"Timeout": self.timeout,
"Version": self.version,
"Version": str(self.version),
"VpcConfig": self.vpc_config,
}
@ -389,8 +399,7 @@ class LambdaFunction(BaseModel):
from moto.cloudformation.exceptions import \
UnformattedGetAttTemplateException
if attribute_name == 'Arn':
return 'arn:aws:lambda:{0}:123456789012:function:{1}'.format(
self.region, self.function_name)
return make_function_arn(self.region, ACCOUNT_ID, self.function_name)
raise UnformattedGetAttTemplateException()
@staticmethod
@ -446,9 +455,121 @@ class LambdaVersion(BaseModel):
return LambdaVersion(spec)
class LambdaStorage(object):
def __init__(self):
# Format 'func_name' {'alias': {}, 'versions': []}
self._functions = {}
self._arns = weakref.WeakValueDictionary()
def _get_latest(self, name):
return self._functions[name]['latest']
def _get_version(self, name, version):
index = version - 1
try:
return self._functions[name]['versions'][index]
except IndexError:
return None
def _get_alias(self, name, alias):
return self._functions[name]['alias'].get(alias, None)
def get_function(self, name, qualifier=None):
if name not in self._functions:
return None
if qualifier is None:
return self._get_latest(name)
try:
return self._get_version(name, int(qualifier))
except ValueError:
return self._functions[name]['latest']
def get_arn(self, arn):
return self._arns.get(arn, None)
def put_function(self, fn):
"""
:param fn: Function
:type fn: LambdaFunction
"""
if fn.function_name in self._functions:
self._functions[fn.function_name]['latest'] = fn
else:
self._functions[fn.function_name] = {
'latest': fn,
'versions': [],
'alias': weakref.WeakValueDictionary()
}
self._arns[fn.function_arn] = fn
def publish_function(self, name):
if name not in self._functions:
return None
if not self._functions[name]['latest']:
return None
new_version = len(self._functions[name]['versions']) + 1
fn = copy.copy(self._functions[name]['latest'])
fn.set_version(new_version)
self._functions[name]['versions'].append(fn)
return fn
def del_function(self, name, qualifier=None):
if name in self._functions:
if not qualifier:
# Something is still reffing this so delete all arns
latest = self._functions[name]['latest'].function_arn
del self._arns[latest]
for fn in self._functions[name]['versions']:
del self._arns[fn.function_arn]
del self._functions[name]
return True
elif qualifier == '$LATEST':
self._functions[name]['latest'] = None
# If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']:
del self._functions[name]
return True
else:
fn = self.get_function(name, qualifier)
if fn:
self._functions[name]['versions'].remove(fn)
# If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']:
del self._functions[name]
return True
return False
def all(self):
result = []
for function_group in self._functions.values():
if function_group['latest'] is not None:
result.append(function_group['latest'])
result.extend(function_group['versions'])
return result
class LambdaBackend(BaseBackend):
def __init__(self, region_name):
self._functions = {}
self._lambdas = LambdaStorage()
self.region_name = region_name
def reset(self):
@ -456,31 +577,31 @@ class LambdaBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
def has_function(self, function_name):
return function_name in self._functions
def has_function_arn(self, function_arn):
return self.get_function_by_arn(function_arn) is not None
def create_function(self, spec):
fn = LambdaFunction(spec, self.region_name)
self._functions[fn.function_name] = fn
function_name = spec.get('FunctionName', None)
if function_name is None:
raise RESTError('InvalidParameterValueException', 'Missing FunctionName')
fn = LambdaFunction(spec, self.region_name, version='$LATEST')
self._lambdas.put_function(fn)
return fn
def get_function(self, function_name):
return self._functions[function_name]
def publish_function(self, function_name):
return self._lambdas.publish_function(function_name)
def get_function(self, function_name, qualifier=None):
return self._lambdas.get_function(function_name, qualifier)
def get_function_by_arn(self, function_arn):
for function in self._functions.values():
if function.function_arn == function_arn:
return function
return None
return self._lambdas.get_arn(function_arn)
def delete_function(self, function_name):
del self._functions[function_name]
def delete_function(self, function_name, qualifier=None):
return self._lambdas.del_function(function_name, qualifier)
def list_functions(self):
return self._functions.values()
return self._lambdas.all()
def send_message(self, function_name, message):
event = {
@ -515,23 +636,31 @@ class LambdaBackend(BaseBackend):
]
}
self._functions[function_name].invoke(json.dumps(event), {}, {})
self._functions[function_name][-1].invoke(json.dumps(event), {}, {})
pass
def list_tags(self, resource):
return self.get_function_by_arn(resource).tags
def tag_resource(self, resource, tags):
self.get_function_by_arn(resource).tags.update(tags)
fn = self.get_function_by_arn(resource)
if not fn:
return False
fn.tags.update(tags)
return True
def untag_resource(self, resource, tagKeys):
function = self.get_function_by_arn(resource)
for key in tagKeys:
try:
del function.tags[key]
except KeyError:
pass
# Don't care
fn = self.get_function_by_arn(resource)
if fn:
for key in tagKeys:
try:
del fn.tags[key]
except KeyError:
pass
# Don't care
return True
return False
def add_policy(self, function_name, policy):
self.get_function(function_name).policy = policy