Implement basic SNS message filtering (#1521)

* Add support for FilterPolicy to sns subscription set_filter_attributes

* Add basic support for sns message filtering

This adds support for exact string value matching along with AND/OR
logic as described here:

https://docs.aws.amazon.com/sns/latest/dg/message-filtering.html

It does not provide support for:
- Anything-but string matching
- Prefix string matching
- Numeric Value Matching

The above filter policies (if configured) will not match messages.
This commit is contained in:
Iain Bullard 2018-03-21 15:49:11 +00:00 committed by Jack Danger
commit d3d9557d49
5 changed files with 194 additions and 8 deletions

View file

@ -42,11 +42,12 @@ class Topic(BaseModel):
self.subscriptions_confimed = 0
self.subscriptions_deleted = 0
def publish(self, message, subject=None):
def publish(self, message, subject=None, message_attributes=None):
message_id = six.text_type(uuid.uuid4())
subscriptions, _ = self.sns_backend.list_subscriptions(self.arn)
for subscription in subscriptions:
subscription.publish(message, message_id, subject=subject)
subscription.publish(message, message_id, subject=subject,
message_attributes=message_attributes)
return message_id
def get_cfn_attribute(self, attribute_name):
@ -81,9 +82,14 @@ class Subscription(BaseModel):
self.protocol = protocol
self.arn = make_arn_for_subscription(self.topic.arn)
self.attributes = {}
self._filter_policy = None # filter policy as a dict, not json.
self.confirmed = False
def publish(self, message, message_id, subject=None):
def publish(self, message, message_id, subject=None,
message_attributes=None):
if not self._matches_filter_policy(message_attributes):
return
if self.protocol == 'sqs':
queue_name = self.endpoint.split(":")[-1]
region = self.endpoint.split(":")[3]
@ -98,6 +104,28 @@ class Subscription(BaseModel):
region = self.arn.split(':')[3]
lambda_backends[region].send_message(function_name, message, subject=subject)
def _matches_filter_policy(self, message_attributes):
# TODO: support Anything-but matching, prefix matching and
# numeric value matching.
if not self._filter_policy:
return True
if message_attributes is None:
message_attributes = {}
def _field_match(field, rules, message_attributes):
if field not in message_attributes:
return False
for rule in rules:
if isinstance(rule, six.string_types):
# only string value matching is supported
if message_attributes[field] == rule:
return True
return False
return all(_field_match(field, rules, message_attributes)
for field, rules in six.iteritems(self._filter_policy))
def get_post_data(self, message, message_id, subject):
return {
"Type": "Notification",
@ -274,13 +302,14 @@ class SNSBackend(BaseBackend):
else:
return self._get_values_nexttoken(self.subscriptions, next_token)
def publish(self, arn, message, subject=None):
def publish(self, arn, message, subject=None, message_attributes=None):
if subject is not None and len(subject) >= 100:
raise ValueError('Subject must be less than 100 characters')
try:
topic = self.get_topic(arn)
message_id = topic.publish(message, subject=subject)
message_id = topic.publish(message, subject=subject,
message_attributes=message_attributes)
except SNSNotFoundError:
endpoint = self.get_endpoint(arn)
message_id = endpoint.publish(message)
@ -352,7 +381,7 @@ class SNSBackend(BaseBackend):
return subscription.attributes
def set_subscription_attributes(self, arn, name, value):
if name not in ['RawMessageDelivery', 'DeliveryPolicy']:
if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']:
raise SNSInvalidParameter('AttributeName')
# TODO: should do validation
@ -363,6 +392,9 @@ class SNSBackend(BaseBackend):
subscription.attributes[name] = value
if name == 'FilterPolicy':
subscription._filter_policy = json.loads(value)
sns_backends = {}
for region in boto.sns.regions():

View file

@ -241,6 +241,10 @@ class SNSResponse(BaseResponse):
phone_number = self._get_param('PhoneNumber')
subject = self._get_param('Subject')
message_attributes = self._get_map_prefix('MessageAttributes.entry',
key_end='Name',
value_end='Value')
if phone_number is not None:
# Check phone is correct syntax (e164)
if not is_e164(phone_number):
@ -265,7 +269,9 @@ class SNSResponse(BaseResponse):
message = self._get_param('Message')
try:
message_id = self.backend.publish(arn, message, subject=subject)
message_id = self.backend.publish(
arn, message, subject=subject,
message_attributes=message_attributes)
except ValueError as err:
error_response = self._error('InvalidParameter', str(err))
return error_response, dict(status=400)

View file

@ -30,7 +30,7 @@ class SQSResponse(BaseResponse):
@property
def attribute(self):
if not hasattr(self, '_attribute'):
self._attribute = self._get_map_prefix('Attribute', key_end='Name', value_end='Value')
self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value')
return self._attribute
def _get_queue_name(self):