From 08c4eff0b26518cf14c7eacaa1eb060bfc17ea50 Mon Sep 17 00:00:00 2001 From: Nuwan Goonasekera Date: Mon, 18 Sep 2017 23:12:39 +0530 Subject: [PATCH] Added invalid id exceptions when filtering snapshots and volumes --- moto/ec2/models.py | 28 +++++++++++++++------- moto/ec2/responses/elastic_block_store.py | 10 ++------ tests/test_ec2/test_elastic_block_store.py | 12 ++++++++++ 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 8e5bcdfb..4b143eea 100755 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1775,11 +1775,17 @@ class EBSBackend(object): self.volumes[volume_id] = volume return volume - def describe_volumes(self, filters=None): + def describe_volumes(self, volume_ids=None, filters=None): + matches = self.volumes.values() + if volume_ids: + matches = [vol for vol in matches + if vol.id in volume_ids] + if len(volume_ids) > len(matches): + unknown_ids = set(volume_ids) - set(matches) + raise InvalidVolumeIdError(unknown_ids) if filters: - volumes = self.volumes.values() - return generic_filter(filters, volumes) - return self.volumes.values() + matches = generic_filter(filters, matches) + return matches def get_volume(self, volume_id): volume = self.volumes.get(volume_id, None) @@ -1827,11 +1833,17 @@ class EBSBackend(object): self.snapshots[snapshot_id] = snapshot return snapshot - def describe_snapshots(self, filters=None): + def describe_snapshots(self, snapshot_ids=None, filters=None): + matches = self.snapshots.values() + if snapshot_ids: + matches = [vol for vol in matches + if vol.id in snapshot_ids] + if len(snapshot_ids) > len(matches): + unknown_ids = set(snapshot_ids) - set(matches) + raise InvalidSnapshotIdError(unknown_ids) if filters: - snapshots = self.snapshots.values() - return generic_filter(filters, snapshots) - return self.snapshots.values() + matches = generic_filter(filters, matches) + return matches def get_snapshot(self, snapshot_id): snapshot = self.snapshots.get(snapshot_id, None) diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index 8f12dc91..37b3e9a0 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -54,20 +54,14 @@ class ElasticBlockStore(BaseResponse): def describe_snapshots(self): filters = filters_from_querystring(self.querystring) snapshot_ids = self._get_multi_param('SnapshotId') - snapshots = self.ec2_backend.describe_snapshots(filters=filters) - # Describe snapshots to handle filter on snapshot_ids - snapshots = [ - s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots + snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) return template.render(snapshots=snapshots) def describe_volumes(self): filters = filters_from_querystring(self.querystring) volume_ids = self._get_multi_param('VolumeId') - volumes = self.ec2_backend.describe_volumes(filters=filters) - # Describe volumes to handle filter on volume_ids - volumes = [ - v for v in volumes if v.id in volume_ids] if volume_ids else volumes + volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) return template.render(volumes=volumes) diff --git a/tests/test_ec2/test_elastic_block_store.py b/tests/test_ec2/test_elastic_block_store.py index b238e68f..4427d484 100644 --- a/tests/test_ec2/test_elastic_block_store.py +++ b/tests/test_ec2/test_elastic_block_store.py @@ -83,6 +83,12 @@ def test_filter_volume_by_id(): vol2 = conn.get_all_volumes(volume_ids=[volume1.id, volume2.id]) vol2.should.have.length_of(2) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_volumes(volume_ids=['vol-does_not_exist']) + cm.exception.code.should.equal('InvalidVolume.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_volume_filters(): @@ -302,6 +308,12 @@ def test_filter_snapshot_by_id(): s.volume_id.should.be.within([volume2.id, volume3.id]) s.region.name.should.equal(conn.region.name) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_snapshots(snapshot_ids=['snap-does_not_exist']) + cm.exception.code.should.equal('InvalidSnapshot.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_snapshot_filters():