From fab37942c49bbc97d052df4f5f37e390d3d7d96d Mon Sep 17 00:00:00 2001 From: "Chris St. Pierre" Date: Thu, 8 May 2014 10:41:28 -0400 Subject: [PATCH] Consistent _get_multi_param() function in responses This abstracts _get_multi_param() into BaseResponse and makes it always ensure that the string it has been given ends with a '.'. It had been implemented in three different places, and in use it rarely postpended a trailing period, which could make it match parameters it shouldn't have. --- moto/autoscaling/responses.py | 9 +-------- moto/core/responses.py | 8 ++++++++ moto/ec2/responses/instances.py | 4 ---- moto/ec2/responses/spot_instances.py | 5 +---- 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/moto/autoscaling/responses.py b/moto/autoscaling/responses.py index 32eb7b30..267db1f9 100644 --- a/moto/autoscaling/responses.py +++ b/moto/autoscaling/responses.py @@ -6,18 +6,11 @@ from .models import autoscaling_backend class AutoScalingResponse(BaseResponse): - - def _get_param(self, param_name): - return self.querystring.get(param_name, [None])[0] - def _get_int_param(self, param_name): value = self._get_param(param_name) if value is not None: return int(value) - def _get_multi_param(self, param_prefix): - return [value[0] for key, value in self.querystring.items() if key.startswith(param_prefix)] - def _get_list_prefix(self, param_prefix): results = [] param_index = 1 @@ -43,7 +36,7 @@ class AutoScalingResponse(BaseResponse): name=self._get_param('LaunchConfigurationName'), image_id=self._get_param('ImageId'), key_name=self._get_param('KeyName'), - security_groups=self._get_multi_param('SecurityGroups.member.'), + security_groups=self._get_multi_param('SecurityGroups.member'), user_data=self._get_param('UserData'), instance_type=self._get_param('InstanceType'), instance_monitoring=instance_monitoring, diff --git a/moto/core/responses.py b/moto/core/responses.py index 8e6c5bfc..fee64184 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -66,6 +66,14 @@ class BaseResponse(object): def _get_param(self, param_name): return self.querystring.get(param_name, [None])[0] + def _get_multi_param(self, param_prefix): + if param_prefix.endswith("."): + prefix = param_prefix + else: + prefix = param_prefix + "." + return [value[0] for key, value in self.querystring.items() + if key.startswith(prefix)] + def metadata_response(request, full_url, headers): """ diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 3b442eae..12b52607 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -8,10 +8,6 @@ from moto.ec2.exceptions import InvalidIdError class InstanceResponse(BaseResponse): - def _get_multi_param(self, param_prefix): - return [value[0] for key, value in self.querystring.items() - if key.startswith(param_prefix + ".")] - def describe_instances(self): instance_ids = instance_ids_from_querystring(self.querystring) if instance_ids: diff --git a/moto/ec2/responses/spot_instances.py b/moto/ec2/responses/spot_instances.py index 2c3d61a4..5ce8ee12 100644 --- a/moto/ec2/responses/spot_instances.py +++ b/moto/ec2/responses/spot_instances.py @@ -13,9 +13,6 @@ class SpotInstances(BaseResponse): if value is not None: return int(value) - def _get_multi_param(self, param_prefix): - return [value[0] for key, value in self.querystring.items() if key.startswith(param_prefix)] - def cancel_spot_instance_requests(self): request_ids = self._get_multi_param('SpotInstanceRequestId') requests = ec2_backend.cancel_spot_instance_requests(request_ids) @@ -49,7 +46,7 @@ class SpotInstances(BaseResponse): launch_group = self._get_param('LaunchGroup') availability_zone_group = self._get_param('AvailabilityZoneGroup') key_name = self._get_param('LaunchSpecification.KeyName') - security_groups = self._get_multi_param('LaunchSpecification.SecurityGroup.') + security_groups = self._get_multi_param('LaunchSpecification.SecurityGroup') user_data = self._get_param('LaunchSpecification.UserData') instance_type = self._get_param('LaunchSpecification.InstanceType') placement = self._get_param('LaunchSpecification.Placement.AvailabilityZone')