diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 95903b82..20d5af1c 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -1,4 +1,5 @@ import copy +import itertools from collections import defaultdict from boto.ec2.instance import Instance as BotoInstance, Reservation @@ -312,36 +313,36 @@ class SecurityGroup(object): class SecurityGroupBackend(object): def __init__(self): - self.groups = {} - self.vpc_groups = {} - + # the key in the dict group is the vpc_id or None (non-vpc) + self.groups = defaultdict(dict) super(SecurityGroupBackend, self).__init__() def create_security_group(self, name, description, vpc_id=None, force=False): group_id = random_security_group_id() if not force: - existing_group = self.get_security_group_from_name(name) + existing_group = self.get_security_group_from_name(name, vpc_id) if existing_group: return None group = SecurityGroup(group_id, name, description, vpc_id=vpc_id) - self.groups[group_id] = group + + self.groups[vpc_id][group_id] = group return group def describe_security_groups(self): - return self.groups.values() + return itertools.chain(*[x.values() for x in self.groups.values()]) - def delete_security_group(self, name_or_group_id): - if name_or_group_id in self.groups: + def delete_security_group(self, name_or_group_id, vpc_id): + if name_or_group_id in self.groups[vpc_id]: # Group Id - return self.groups.pop(name_or_group_id) + return self.groups[vpc_id].pop(name_or_group_id) else: # Group Name - group = self.get_security_group_from_name(name_or_group_id) + group = self.get_security_group_from_name(name_or_group_id, vpc_id) if group: - return self.groups.pop(group.id) + return self.groups[vpc_id].pop(group.id) - def get_security_group_from_name(self, name): - for group_id, group in self.groups.iteritems(): + def get_security_group_from_name(self, name, vpc_id): + for group_id, group in self.groups[vpc_id].iteritems(): if group.name == name: return group @@ -350,16 +351,16 @@ class SecurityGroupBackend(object): default_group = ec2_backend.create_security_group("default", "The default security group", force=True) return default_group - def authorize_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None): - group = self.get_security_group_from_name(group_name) + def authorize_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None): + group = self.get_security_group_from_name(group_name, vpc_id) source_groups = [] for source_group_name in source_group_names: - source_groups.append(self.get_security_group_from_name(source_group_name)) + source_groups.append(self.get_security_group_from_name(source_group_name, vpc_id)) security_rule = SecurityRule(ip_protocol, from_port, to_port, ip_ranges, source_groups) group.ingress_rules.append(security_rule) - def revoke_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None): + def revoke_security_group_ingress(self, group_name, ip_protocol, from_port, to_port, ip_ranges=None, source_group_names=None, vpc_id=None): group = self.get_security_group_from_name(group_name) source_groups = [] for source_group_name in source_group_names: diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index 606faebc..6b33c360 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -31,8 +31,8 @@ def test_create_and_describe_vpc_security_group(): security_group.name.should.equal('test security group') security_group.description.should.equal('this is a test security group') - # Trying to create another group with the same name should throw an error - conn.create_security_group.when.called_with('test security group', 'this is a test security group').should.throw(EC2ResponseError) + # Trying to create another group with the same name in the same VPC should throw an error + conn.create_security_group.when.called_with('test security group', 'this is a test security group', vpc_id).should.throw(EC2ResponseError) all_groups = conn.get_all_security_groups() @@ -41,6 +41,21 @@ def test_create_and_describe_vpc_security_group(): all_groups.should.have.length_of(1) all_groups[0].name.should.equal('test security group') +@mock_ec2 +def test_create_two_security_groups_with_same_name_in_different_vpc(): + conn = boto.connect_ec2('the_key', 'the_secret') + vpc_id = 'vpc-5300000c' + vpc_id2 = 'vpc-5300000d' + + sg1 = conn.create_security_group('test security group', 'this is a test security group', vpc_id) + sg2 = conn.create_security_group('test security group', 'this is a test security group', vpc_id2) + + all_groups = conn.get_all_security_groups() + + all_groups.should.have.length_of(2) + all_groups[0].name.should.equal('test security group') + all_groups[1].name.should.equal('test security group') + @mock_ec2 def test_deleting_security_groups():