129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
from __future__ import unicode_literals
|
|
|
|
import os
|
|
import json
|
|
import boto3
|
|
import base64
|
|
|
|
from moto.core import BaseBackend, BaseModel
|
|
from moto.dynamodb2.models import dynamodb_backends
|
|
|
|
|
|
class ShardIterator(BaseModel):
    """Read position inside a single stream shard.

    Every instance receives a random opaque ``id``; the ``arn`` property
    embeds it in the token that is handed to clients and used as the key in
    the backend's ``shard_iterators`` registry.
    """

    def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None):
        """Resolve the sequence number this iterator starts reading from.

        :param streams_backend: backend whose ``shard_iterators`` dict
            receives the follow-up iterators created by ``get``.
        :param stream_shard: shard to read from.
        :param shard_iterator_type: ``TRIM_HORIZON``, ``LATEST``,
            ``AT_SEQUENCE_NUMBER`` or ``AFTER_SEQUENCE_NUMBER``.
        :param sequence_number: reference position for the two
            ``*_SEQUENCE_NUMBER`` types, as an int or a numeric string.
            Ignored by the other types.
        :raises ValueError: if ``shard_iterator_type`` is not one of the four
            known types, or ``sequence_number`` is a non-numeric string.
        :raises TypeError: if ``AFTER_SEQUENCE_NUMBER`` is requested without
            a ``sequence_number``.
        """
        self.id = base64.b64encode(os.urandom(472)).decode('utf-8')
        self.streams_backend = streams_backend
        self.stream_shard = stream_shard
        self.shard_iterator_type = shard_iterator_type

        if shard_iterator_type == 'TRIM_HORIZON':
            self.sequence_number = stream_shard.starting_sequence_number
        elif shard_iterator_type == 'LATEST':
            # One past the newest item currently held by the shard.
            self.sequence_number = (stream_shard.starting_sequence_number +
                                    len(stream_shard.items))
        elif shard_iterator_type == 'AT_SEQUENCE_NUMBER':
            # A missing sequence number is passed through unchanged, exactly
            # as before; anything else is normalised to int because the
            # Streams API transports sequence numbers as strings.
            if sequence_number is None:
                self.sequence_number = None
            else:
                self.sequence_number = int(sequence_number)
        elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER':
            if sequence_number is None:
                raise TypeError(
                    'AFTER_SEQUENCE_NUMBER requires a sequence_number')
            self.sequence_number = int(sequence_number) + 1
        else:
            # Fail here rather than leaving ``sequence_number`` unset and
            # crashing with an AttributeError on the first ``get``.
            raise ValueError(
                'Unknown shard iterator type: {}'.format(shard_iterator_type))

    @property
    def arn(self):
        """Token for this iterator: ``<table arn>/stream/<label>|1|<id>``."""
        return '{}/stream/{}|1|{}'.format(
            self.stream_shard.table.table_arn,
            self.stream_shard.table.latest_stream_label,
            self.id)

    def to_json(self):
        """Return the response payload exposing this iterator's token."""
        return {
            'ShardIterator': self.arn
        }

    def get(self, limit=1000):
        """Fetch up to ``limit`` records and mint the follow-up iterator.

        The follow-up iterator is registered in
        ``streams_backend.shard_iterators`` under its ``arn`` before this
        method returns.

        :param limit: maximum number of records requested from the shard.
        :returns: dict with ``NextShardIterator`` (token of the follow-up
            iterator) and ``Records`` (whatever the shard returned).
        """
        items = self.stream_shard.get(self.sequence_number, limit)

        # Materialise the sequence numbers so emptiness is tested explicitly
        # instead of relying on ``max()`` raising ValueError, and compare them
        # numerically even when records carry them as strings.
        sequence_numbers = [int(item['dynamodb']['SequenceNumber'])
                            for item in items]
        if sequence_numbers:
            # Continue right after the newest record just handed out.
            new_shard_iterator = ShardIterator(self.streams_backend,
                                               self.stream_shard,
                                               'AFTER_SEQUENCE_NUMBER',
                                               max(sequence_numbers))
        else:
            # Nothing was read: the next iterator stays at the same position.
            new_shard_iterator = ShardIterator(self.streams_backend,
                                               self.stream_shard,
                                               'AT_SEQUENCE_NUMBER',
                                               self.sequence_number)

        self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator
        return {
            'NextShardIterator': new_shard_iterator.arn,
            'Records': items
        }
|
|
|
|
|
|
class DynamoDBStreamsBackend(BaseBackend):
    """DynamoDB Streams backend for one region.

    Stream data is read from the tables of the DynamoDB backend registered
    for the same region; this class itself only stores the shard iterators it
    has handed out.
    """

    def __init__(self, region):
        """Create an empty backend bound to ``region``."""
        self.region = region
        # Issued iterators keyed by their ``arn`` token (see ``get_records``).
        self.shard_iterators = {}

    def reset(self):
        """Drop all state and re-initialise for the same region."""
        region = self.region
        self.__dict__ = {}
        self.__init__(region)

    @property
    def dynamodb(self):
        """The DynamoDB backend registered for this backend's region."""
        return dynamodb_backends[self.region]

    def _get_table_from_arn(self, arn):
        """Look up the table named in a stream ARN.

        The table name is taken as the second ``/``-separated component of
        the sixth ``:``-separated field of ``arn``.

        :raises IndexError: if ``arn`` has too few fields for that lookup.
        """
        table_name = arn.split(':', 6)[5].split('/')[1]
        return self.dynamodb.get_table(table_name)

    def describe_stream(self, arn):
        """Return the JSON-encoded ``StreamDescription`` for ``arn``.

        For a table without a stream shard the description reports status
        ``DISABLED`` (when no stream label is set), an empty ``Shards`` list
        and null ``StreamViewType`` / ``CreationRequestDateTime`` instead of
        failing on the missing attributes.
        """
        table = self._get_table_from_arn(arn)
        shard = table.stream_shard
        # assumes stream_specification is a dict or None - it was subscripted
        # with a string key before; confirm against the table model.
        specification = table.stream_specification or {}

        resp = {'StreamDescription': {
            'StreamArn': arn,
            'StreamLabel': table.latest_stream_label,
            'StreamStatus': ('ENABLED' if table.latest_stream_label
                             else 'DISABLED'),
            'StreamViewType': specification.get('StreamViewType'),
            # Guarded the same way as 'Shards' below: no shard means there is
            # no creation timestamp to report.
            'CreationRequestDateTime': (shard.created_on.isoformat() if shard
                                        else None),
            'TableName': table.name,
            'KeySchema': table.schema,
            'Shards': ([shard.to_json()] if shard
                       else [])
        }}

        return json.dumps(resp)

    def list_streams(self, table_name=None):
        """Return JSON listing the streams of tables that have a stream label.

        :param table_name: when given, only the table with this name is
            considered.
        """
        streams = []
        for table in self.dynamodb.tables.values():
            if table_name is not None and table.name != table_name:
                continue
            if table.latest_stream_label:
                d = table.describe(base_key='Table')
                streams.append({
                    'StreamArn': d['Table']['LatestStreamArn'],
                    'TableName': d['Table']['TableName'],
                    'StreamLabel': d['Table']['LatestStreamLabel']
                })

        return json.dumps({'Streams': streams})

    def get_shard_iterator(self, arn, shard_id, shard_iterator_type, sequence_number=None):
        """Create, register and return (as JSON) a new shard iterator.

        :param arn: stream ARN identifying the table.
        :param shard_id: must equal the id of the table's stream shard.
        :param shard_iterator_type: forwarded to ``ShardIterator``.
        :param sequence_number: forwarded to ``ShardIterator``.
        :raises ValueError: if the table has no stream shard or its id does
            not match ``shard_id``.
        """
        table = self._get_table_from_arn(arn)
        shard = table.stream_shard
        # Explicit check instead of ``assert``: shard_id comes from the
        # caller, and assertions are stripped when Python runs with -O.
        if shard is None or shard.id != shard_id:
            raise ValueError(
                'Shard {} does not belong to stream {}'.format(shard_id, arn))

        shard_iterator = ShardIterator(self, shard,
                                       shard_iterator_type,
                                       sequence_number)
        self.shard_iterators[shard_iterator.arn] = shard_iterator

        return json.dumps(shard_iterator.to_json())

    def get_records(self, iterator_arn, limit):
        """Return JSON with the records read through a registered iterator.

        :param iterator_arn: token previously issued by
            ``get_shard_iterator`` or returned as ``NextShardIterator``.
        :param limit: maximum number of records to read.
        :raises KeyError: if ``iterator_arn`` is not a registered iterator.
        """
        shard_iterator = self.shard_iterators[iterator_arn]
        return json.dumps(shard_iterator.get(limit))
|
|
|
|
|
|
# Regions boto3 reports for the 'dynamodbstreams' service.
available_regions = boto3.session.Session().get_available_regions(
    'dynamodbstreams')
# One independent backend per region, keyed by region name.
dynamodbstreams_backends = dict(
    (region_name, DynamoDBStreamsBackend(region=region_name))
    for region_name in available_regions)
|