diff --git a/moto/ssm/models.py b/moto/ssm/models.py index bdc98e61..656a1483 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from collections import defaultdict from moto.core import BaseBackend, BaseModel +from moto.core.exceptions import RESTError from moto.ec2 import ec2_backends import datetime @@ -58,11 +59,86 @@ class Parameter(BaseModel): return r +MAX_TIMEOUT_SECONDS = 3600 + + +class Command(BaseModel): + def __init__(self, comment='', document_name='', timeout_seconds=MAX_TIMEOUT_SECONDS, + instance_ids=None, max_concurrency='', max_errors='', + notification_config=None, output_s3_bucket_name='', + output_s3_key_prefix='', output_s3_region='', parameters=None, + service_role_arn='', targets=None): + + if instance_ids is None: + instance_ids = [] + + if notification_config is None: + notification_config = {} + + if parameters is None: + parameters = {} + + if targets is None: + targets = [] + + self.error_count = 0 + self.completed_count = len(instance_ids) + self.target_count = len(instance_ids) + self.command_id = str(uuid.uuid4()) + self.status = 'Success' + self.status_details = 'Details placeholder' + + now = datetime.datetime.now() + self.requested_date_time = now.isoformat() + expires_after = now + datetime.timedelta(0, timeout_seconds) + self.expires_after = expires_after.isoformat() + + self.comment = comment + self.document_name = document_name + self.instance_ids = instance_ids + self.max_concurrency = max_concurrency + self.max_errors = max_errors + self.notification_config = notification_config + self.output_s3_bucket_name = output_s3_bucket_name + self.output_s3_key_prefix = output_s3_key_prefix + self.output_s3_region = output_s3_region + self.parameters = parameters + self.service_role_arn = service_role_arn + self.targets = targets + + def response_object(self): + r = { + 'CommandId': self.command_id, + 'Comment': self.comment, + 'CompletedCount': self.completed_count, + 'DocumentName': self.document_name, + 'ErrorCount': self.error_count, + 'ExpiresAfter': self.expires_after, + 'InstanceIds': self.instance_ids, + 'MaxConcurrency': self.max_concurrency, + 'MaxErrors': self.max_errors, + 'NotificationConfig': self.notification_config, + 'OutputS3Region': self.output_s3_region, + 'OutputS3BucketName': self.output_s3_bucket_name, + 'OutputS3KeyPrefix': self.output_s3_key_prefix, + 'Parameters': self.parameters, + 'RequestedDateTime': self.requested_date_time, + 'ServiceRole': self.service_role_arn, + 'Status': self.status, + 'StatusDetails': self.status_details, + 'TargetCount': self.target_count, + 'Targets': self.targets, + } + + return r + + class SimpleSystemManagerBackend(BaseBackend): def __init__(self): self._parameters = {} self._resource_tags = defaultdict(lambda: defaultdict(dict)) + self._commands = [] def delete_parameter(self, name): try: @@ -167,38 +243,61 @@ class SimpleSystemManagerBackend(BaseBackend): return self._resource_tags[resource_type][resource_id] def send_command(self, **kwargs): - instances = kwargs.get('InstanceIds', []) - now = datetime.datetime.now() - expires_after = now + datetime.timedelta(0, int(kwargs.get('TimeoutSeconds', 3600))) + command = Command( + comment=kwargs.get('Comment', ''), + document_name=kwargs.get('DocumentName'), + timeout_seconds=kwargs.get('TimeoutSeconds', 3600), + instance_ids=kwargs.get('InstanceIds', []), + max_concurrency=kwargs.get('MaxConcurrency', '50'), + max_errors=kwargs.get('MaxErrors', '0'), + notification_config=kwargs.get('NotificationConfig', { + 'NotificationArn': 'string', + 'NotificationEvents': ['Success'], + 'NotificationType': 'Command' + }), + output_s3_bucket_name=kwargs.get('OutputS3BucketName', ''), + output_s3_key_prefix=kwargs.get('OutputS3KeyPrefix', ''), + output_s3_region=kwargs.get('OutputS3Region', ''), + parameters=kwargs.get('Parameters', {}), + service_role_arn=kwargs.get('ServiceRoleArn', ''), + targets=kwargs.get('Targets', [])) + + self._commands.append(command) return { - 'Command': { - 'CommandId': str(uuid.uuid4()), - 'DocumentName': kwargs['DocumentName'], - 'Comment': kwargs.get('Comment'), - 'ExpiresAfter': expires_after.isoformat(), - 'Parameters': kwargs['Parameters'], - 'InstanceIds': kwargs['InstanceIds'], - 'Targets': kwargs.get('targets'), - 'RequestedDateTime': now.isoformat(), - 'Status': 'Success', - 'StatusDetails': 'string', - 'OutputS3Region': kwargs.get('OutputS3Region'), - 'OutputS3BucketName': kwargs.get('OutputS3BucketName'), - 'OutputS3KeyPrefix': kwargs.get('OutputS3KeyPrefix'), - 'MaxConcurrency': 'string', - 'MaxErrors': 'string', - 'TargetCount': len(instances), - 'CompletedCount': len(instances), - 'ErrorCount': 0, - 'ServiceRole': kwargs.get('ServiceRoleArn'), - 'NotificationConfig': { - 'NotificationArn': 'string', - 'NotificationEvents': ['Success'], - 'NotificationType': 'Command' - } - } + 'Command': command.response_object() } + def list_commands(self, **kwargs): + """ + https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_ListCommands.html + """ + commands = self._commands + + command_id = kwargs.get('CommandId', None) + if command_id: + commands = [self.get_command_by_id(command_id)] + instance_id = kwargs.get('InstanceId', None) + if instance_id: + commands = self.get_commands_by_instance_id(instance_id) + + return { + 'Commands': [command.response_object() for command in commands] + } + + def get_command_by_id(self, id): + command = next( + (command for command in self._commands if command.command_id == id), None) + + if command is None: + raise RESTError('InvalidCommandId', 'Invalid command id.') + + return command + + def get_commands_by_instance_id(self, instance_id): + return [ + command for command in self._commands + if instance_id in command.instance_ids] + ssm_backends = {} for region, ec2_backend in ec2_backends.items(): diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index e35eca5e..fd0d8b63 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -205,3 +205,8 @@ class SimpleSystemManagerResponse(BaseResponse): return json.dumps( self.ssm_backend.send_command(**self.request_params) ) + + def list_commands(self): + return json.dumps( + self.ssm_backend.list_commands(**self.request_params) + ) diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index e58879bc..7a0685d5 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -4,6 +4,10 @@ import boto3 import botocore.exceptions import sure # noqa import datetime +import uuid + +from botocore.exceptions import ClientError +from nose.tools import assert_raises from moto import mock_ssm @@ -608,3 +612,59 @@ def test_send_command(): cmd['OutputS3KeyPrefix'].should.equal('pref') cmd['ExpiresAfter'].should.be.greater_than(before) + + # test sending a command without any optional parameters + response = client.send_command( + DocumentName=ssm_document) + + cmd = response['Command'] + + cmd['CommandId'].should_not.be(None) + cmd['DocumentName'].should.equal(ssm_document) + + +@mock_ssm +def test_list_commands(): + client = boto3.client('ssm', region_name='us-east-1') + + ssm_document = 'AWS-RunShellScript' + params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + + response = client.send_command( + InstanceIds=['i-123456'], + DocumentName=ssm_document, + Parameters=params, + OutputS3Region='us-east-2', + OutputS3BucketName='the-bucket', + OutputS3KeyPrefix='pref') + + cmd = response['Command'] + cmd_id = cmd['CommandId'] + + # get the command by id + response = client.list_commands( + CommandId=cmd_id) + + cmds = response['Commands'] + len(cmds).should.equal(1) + cmds[0]['CommandId'].should.equal(cmd_id) + + # add another command with the same instance id to test listing by + # instance id + client.send_command( + InstanceIds=['i-123456'], + DocumentName=ssm_document) + + response = client.list_commands( + InstanceId='i-123456') + + cmds = response['Commands'] + len(cmds).should.equal(2) + + for cmd in cmds: + cmd['InstanceIds'].should.contain('i-123456') + + # test the error case for an invalid command id + with assert_raises(ClientError): + response = client.list_commands( + CommandId=str(uuid.uuid4()))