diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 3a277523..e7e8a1dd 100755 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1360,22 +1360,25 @@ class SecurityGroupBackend(object): return group def describe_security_groups(self, group_ids=None, groupnames=None, filters=None): - all_groups = itertools.chain(*[x.values() - for x in self.groups.values()]) - groups = [] + matches = itertools.chain(*[x.values() + for x in self.groups.values()]) + if group_ids: + matches = [grp for grp in matches + if grp.id in group_ids] + if len(group_ids) > len(matches): + unknown_ids = set(group_ids) - set(matches) + raise InvalidSecurityGroupNotFoundError(unknown_ids) + if groupnames: + matches = [grp for grp in matches + if grp.name in groupnames] + if len(groupnames) > len(matches): + unknown_names = set(groupnames) - set(matches) + raise InvalidSecurityGroupNotFoundError(unknown_names) + if filters: + matches = [grp for grp in matches + if grp.matches_filters(filters)] - if group_ids or groupnames or filters: - for group in all_groups: - if ((group_ids and group.id not in group_ids) or - (groupnames and group.name not in groupnames)): - continue - if filters and not group.matches_filters(filters): - continue - groups.append(group) - else: - groups = all_groups - - return groups + return matches def _delete_security_group(self, vpc_id, group_id): if self.groups[vpc_id][group_id].enis: @@ -1772,11 +1775,17 @@ class EBSBackend(object): self.volumes[volume_id] = volume return volume - def describe_volumes(self, filters=None): + def describe_volumes(self, volume_ids=None, filters=None): + matches = self.volumes.values() + if volume_ids: + matches = [vol for vol in matches + if vol.id in volume_ids] + if len(volume_ids) > len(matches): + unknown_ids = set(volume_ids) - set(matches) + raise InvalidVolumeIdError(unknown_ids) if filters: - volumes = self.volumes.values() - return generic_filter(filters, volumes) - return self.volumes.values() + matches = generic_filter(filters, matches) + return matches def get_volume(self, volume_id): volume = self.volumes.get(volume_id, None) @@ -1824,11 +1833,17 @@ class EBSBackend(object): self.snapshots[snapshot_id] = snapshot return snapshot - def describe_snapshots(self, filters=None): + def describe_snapshots(self, snapshot_ids=None, filters=None): + matches = self.snapshots.values() + if snapshot_ids: + matches = [snap for snap in matches + if snap.id in snapshot_ids] + if len(snapshot_ids) > len(matches): + unknown_ids = set(snapshot_ids) - set(matches) + raise InvalidSnapshotIdError(unknown_ids) if filters: - snapshots = self.snapshots.values() - return generic_filter(filters, snapshots) - return self.snapshots.values() + matches = generic_filter(filters, matches) + return matches def get_snapshot(self, snapshot_id): snapshot = self.snapshots.get(snapshot_id, None) @@ -1947,12 +1962,16 @@ class VPCBackend(object): return self.vpcs.get(vpc_id) def get_all_vpcs(self, vpc_ids=None, filters=None): + matches = self.vpcs.values() if vpc_ids: - vpcs = [vpc for vpc in self.vpcs.values() if vpc.id in vpc_ids] - else: - vpcs = self.vpcs.values() - - return generic_filter(filters, vpcs) + matches = [vpc for vpc in matches + if vpc.id in vpc_ids] + if len(vpc_ids) > len(matches): + unknown_ids = set(vpc_ids) - set(matches) + raise InvalidVPCIdError(unknown_ids) + if filters: + matches = generic_filter(filters, matches) + return matches def delete_vpc(self, vpc_id): # Delete route table if only main route table remains. @@ -2189,16 +2208,19 @@ class SubnetBackend(object): return subnet def get_all_subnets(self, subnet_ids=None, filters=None): - subnets = [] + # Extract a list of all subnets + matches = itertools.chain(*[x.values() + for x in self.subnets.values()]) if subnet_ids: - for subnet_id in subnet_ids: - for items in self.subnets.values(): - if subnet_id in items: - subnets.append(items[subnet_id]) - else: - for items in self.subnets.values(): - subnets.extend(items.values()) - return generic_filter(filters, subnets) + matches = [sn for sn in matches + if sn.id in subnet_ids] + if len(subnet_ids) > len(matches): + unknown_ids = set(subnet_ids) - set(matches) + raise InvalidSubnetIdError(unknown_ids) + if filters: + matches = generic_filter(filters, matches) + + return matches def delete_subnet(self, subnet_id): for subnets in self.subnets.values(): diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index 8f12dc91..37b3e9a0 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -54,20 +54,14 @@ class ElasticBlockStore(BaseResponse): def describe_snapshots(self): filters = filters_from_querystring(self.querystring) snapshot_ids = self._get_multi_param('SnapshotId') - snapshots = self.ec2_backend.describe_snapshots(filters=filters) - # Describe snapshots to handle filter on snapshot_ids - snapshots = [ - s for s in snapshots if s.id in snapshot_ids] if snapshot_ids else snapshots + snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) return template.render(snapshots=snapshots) def describe_volumes(self): filters = filters_from_querystring(self.querystring) volume_ids = self._get_multi_param('VolumeId') - volumes = self.ec2_backend.describe_volumes(filters=filters) - # Describe volumes to handle filter on volume_ids - volumes = [ - v for v in volumes if v.id in volume_ids] if volume_ids else volumes + volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) return template.render(volumes=volumes) diff --git a/tests/test_ec2/test_elastic_block_store.py b/tests/test_ec2/test_elastic_block_store.py index b238e68f..4427d484 100644 --- a/tests/test_ec2/test_elastic_block_store.py +++ b/tests/test_ec2/test_elastic_block_store.py @@ -83,6 +83,12 @@ def test_filter_volume_by_id(): vol2 = conn.get_all_volumes(volume_ids=[volume1.id, volume2.id]) vol2.should.have.length_of(2) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_volumes(volume_ids=['vol-does_not_exist']) + cm.exception.code.should.equal('InvalidVolume.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_volume_filters(): @@ -302,6 +308,12 @@ def test_filter_snapshot_by_id(): s.volume_id.should.be.within([volume2.id, volume3.id]) s.region.name.should.equal(conn.region.name) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_snapshots(snapshot_ids=['snap-does_not_exist']) + cm.exception.code.should.equal('InvalidSnapshot.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_snapshot_filters(): diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 21ecad11..45e6e327 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -348,6 +348,15 @@ def test_get_all_security_groups(): resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_security_groups(groupnames=['does_not_exist']) + cm.exception.code.should.equal('InvalidGroup.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + resp.should.have.length_of(1) + resp[0].id.should.equal(sg1.id) + resp = conn.get_all_security_groups(filters={'vpc-id': ['vpc-mjm05d27']}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) @@ -681,3 +690,9 @@ def test_get_all_security_groups_filter_with_same_vpc_id(): security_groups = conn.get_all_security_groups( group_ids=[security_group.id], filters={'vpc-id': [vpc_id]}) security_groups.should.have.length_of(1) + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_security_groups(group_ids=['does_not_exist']) + cm.exception.code.should.equal('InvalidGroup.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none diff --git a/tests/test_ec2/test_subnets.py b/tests/test_ec2/test_subnets.py index 38565a28..99e6d45d 100644 --- a/tests/test_ec2/test_subnets.py +++ b/tests/test_ec2/test_subnets.py @@ -158,6 +158,32 @@ def test_modify_subnet_attribute_validation(): SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'}) +@mock_ec2_deprecated +def test_subnet_get_by_id(): + ec2 = boto.ec2.connect_to_region('us-west-1') + conn = boto.vpc.connect_to_region('us-west-1') + vpcA = conn.create_vpc("10.0.0.0/16") + subnetA = conn.create_subnet( + vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB = conn.create_vpc("10.0.0.0/16") + subnetB1 = conn.create_subnet( + vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetB2 = conn.create_subnet( + vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + + subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id]) + subnets_by_id.should.have.length_of(2) + subnets_by_id = tuple(map(lambda s: s.id, subnets_by_id)) + subnetA.id.should.be.within(subnets_by_id) + subnetB1.id.should.be.within(subnets_by_id) + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_subnets(subnet_ids=['subnet-does_not_exist']) + cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + @mock_ec2_deprecated def test_get_subnets_filtering(): ec2 = boto.ec2.connect_to_region('us-west-1') diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index 904603f6..fc0a93cb 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -113,6 +113,12 @@ def test_vpc_get_by_id(): vpc1.id.should.be.within(vpc_ids) vpc2.id.should.be.within(vpc_ids) + with assert_raises(EC2ResponseError) as cm: + conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist']) + cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + @mock_ec2_deprecated def test_vpc_get_by_cidr_block():