diff --git a/moto/swf/models/__init__.py b/moto/swf/models/__init__.py index 29fda4a3..6b5c04cf 100644 --- a/moto/swf/models/__init__.py +++ b/moto/swf/models/__init__.py @@ -9,7 +9,6 @@ from ..exceptions import ( SWFUnknownResourceFault, SWFDomainAlreadyExistsFault, SWFDomainDeprecatedFault, - SWFSerializationException, SWFTypeAlreadyExistsFault, SWFTypeDeprecatedFault, SWFValidationException, @@ -50,32 +49,12 @@ class SWFBackend(BaseBackend): return matching[0] return None - def _check_none_or_string(self, parameter): - if parameter is not None: - self._check_string(parameter) - - def _check_string(self, parameter): - if not isinstance(parameter, six.string_types): - raise SWFSerializationException(parameter) - - def _check_none_or_list_of_strings(self, parameter): - if parameter is not None: - self._check_list_of_strings(parameter) - - def _check_list_of_strings(self, parameter): - if not isinstance(parameter, list): - raise SWFSerializationException(parameter) - for i in parameter: - if not isinstance(i, six.string_types): - raise SWFSerializationException(parameter) - def _process_timeouts(self): for domain in self.domains: for wfe in domain.workflow_executions: wfe._process_timeouts() def list_domains(self, status, reverse_order=None): - self._check_string(status) domains = [domain for domain in self.domains if domain.status == status] domains = sorted(domains, key=lambda domain: domain.name) @@ -85,9 +64,6 @@ class SWFBackend(BaseBackend): def register_domain(self, name, workflow_execution_retention_period_in_days, description=None): - self._check_string(name) - self._check_string(workflow_execution_retention_period_in_days) - self._check_none_or_string(description) if self._get_domain(name, ignore_empty=True): raise SWFDomainAlreadyExistsFault(name) domain = Domain(name, workflow_execution_retention_period_in_days, @@ -95,19 +71,15 @@ class SWFBackend(BaseBackend): self.domains.append(domain) def deprecate_domain(self, name): - self._check_string(name) domain = self._get_domain(name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(name) domain.status = "DEPRECATED" def describe_domain(self, name): - self._check_string(name) return self._get_domain(name) def list_types(self, kind, domain_name, status, reverse_order=None): - self._check_string(domain_name) - self._check_string(status) domain = self._get_domain(domain_name) _types = domain.find_types(kind, status) _types = sorted(_types, key=lambda domain: domain.name) @@ -116,11 +88,6 @@ class SWFBackend(BaseBackend): return _types def register_type(self, kind, domain_name, name, version, **kwargs): - self._check_string(domain_name) - self._check_string(name) - self._check_string(version) - for value in kwargs.values(): - self._check_none_or_string(value) domain = self._get_domain(domain_name) _type = domain.get_type(kind, name, version, ignore_empty=True) if _type: @@ -130,9 +97,6 @@ class SWFBackend(BaseBackend): domain.add_type(_type) def deprecate_type(self, kind, domain_name, name, version): - self._check_string(domain_name) - self._check_string(name) - self._check_string(version) domain = self._get_domain(domain_name) _type = domain.get_type(kind, name, version) if _type.status == "DEPRECATED": @@ -140,23 +104,12 @@ class SWFBackend(BaseBackend): _type.status = "DEPRECATED" def describe_type(self, kind, domain_name, name, version): - self._check_string(domain_name) - self._check_string(name) - self._check_string(version) domain = self._get_domain(domain_name) return domain.get_type(kind, name, version) def start_workflow_execution(self, domain_name, workflow_id, workflow_name, workflow_version, tag_list=None, **kwargs): - self._check_string(domain_name) - self._check_string(workflow_id) - self._check_string(workflow_name) - self._check_string(workflow_version) - self._check_none_or_list_of_strings(tag_list) - for value in kwargs.values(): - self._check_none_or_string(value) - domain = self._get_domain(domain_name) wf_type = domain.get_type("workflow", workflow_name, workflow_version) if wf_type.status == "DEPRECATED": @@ -169,17 +122,12 @@ class SWFBackend(BaseBackend): return wfe def describe_workflow_execution(self, domain_name, run_id, workflow_id): - self._check_string(domain_name) - self._check_string(run_id) - self._check_string(workflow_id) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) return domain.get_workflow_execution(workflow_id, run_id=run_id) def poll_for_decision_task(self, domain_name, task_list, identity=None): - self._check_string(domain_name) - self._check_string(task_list) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -211,8 +159,6 @@ class SWFBackend(BaseBackend): return None def count_pending_decision_tasks(self, domain_name, task_list): - self._check_string(domain_name) - self._check_string(task_list) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -225,8 +171,6 @@ class SWFBackend(BaseBackend): def respond_decision_task_completed(self, task_token, decisions=None, execution_context=None): - self._check_string(task_token) - self._check_none_or_string(execution_context) # process timeouts on all objects self._process_timeouts() # let's find decision task @@ -278,8 +222,6 @@ class SWFBackend(BaseBackend): execution_context=execution_context) def poll_for_activity_task(self, domain_name, task_list, identity=None): - self._check_string(domain_name) - self._check_string(task_list) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -311,8 +253,6 @@ class SWFBackend(BaseBackend): return None def count_pending_activity_tasks(self, domain_name, task_list): - self._check_string(domain_name) - self._check_string(task_list) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -362,8 +302,6 @@ class SWFBackend(BaseBackend): return activity_task def respond_activity_task_completed(self, task_token, result=None): - self._check_string(task_token) - self._check_none_or_string(result) # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) @@ -371,10 +309,6 @@ class SWFBackend(BaseBackend): wfe.complete_activity_task(activity_task.task_token, result=result) def respond_activity_task_failed(self, task_token, reason=None, details=None): - self._check_string(task_token) - # TODO: implement length limits on reason and details (common pb with client libs) - self._check_none_or_string(reason) - self._check_none_or_string(details) # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) @@ -383,12 +317,6 @@ class SWFBackend(BaseBackend): def terminate_workflow_execution(self, domain_name, workflow_id, child_policy=None, details=None, reason=None, run_id=None): - self._check_string(domain_name) - self._check_string(workflow_id) - self._check_none_or_string(child_policy) - self._check_none_or_string(details) - self._check_none_or_string(reason) - self._check_none_or_string(run_id) # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) @@ -396,8 +324,6 @@ class SWFBackend(BaseBackend): wfe.terminate(child_policy=child_policy, details=details, reason=reason) def record_activity_task_heartbeat(self, task_token, details=None): - self._check_string(task_token) - self._check_none_or_string(details) # process timeouts on all objects self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) diff --git a/moto/swf/responses.py b/moto/swf/responses.py index 0b8557a2..7f418863 100644 --- a/moto/swf/responses.py +++ b/moto/swf/responses.py @@ -5,6 +5,7 @@ from moto.core.responses import BaseResponse from werkzeug.exceptions import HTTPException from moto.core.utils import camelcase_to_underscores, method_names_from_class +from .exceptions import SWFSerializationException from .models import swf_backends @@ -19,10 +20,31 @@ class SWFResponse(BaseResponse): def _params(self): return json.loads(self.body.decode("utf-8")) + def _check_none_or_string(self, parameter): + if parameter is not None: + self._check_string(parameter) + + def _check_string(self, parameter): + if not isinstance(parameter, six.string_types): + raise SWFSerializationException(parameter) + + def _check_none_or_list_of_strings(self, parameter): + if parameter is not None: + self._check_list_of_strings(parameter) + + def _check_list_of_strings(self, parameter): + if not isinstance(parameter, list): + raise SWFSerializationException(parameter) + for i in parameter: + if not isinstance(i, six.string_types): + raise SWFSerializationException(parameter) + def _list_types(self, kind): domain_name = self._params["domain"] status = self._params["registrationStatus"] reverse_order = self._params.get("reverseOrder", None) + self._check_string(domain_name) + self._check_string(status) types = self.swf_backend.list_types(kind, domain_name, status, reverse_order=reverse_order) return json.dumps({ "typeInfos": [_type.to_medium_dict() for _type in types] @@ -33,6 +55,9 @@ class SWFResponse(BaseResponse): _type_args = self._params["{0}Type".format(kind)] name = _type_args["name"] version = _type_args["version"] + self._check_string(domain) + self._check_string(name) + self._check_string(version) _type = self.swf_backend.describe_type(kind, domain, name, version) return json.dumps(_type.to_full_dict()) @@ -42,12 +67,16 @@ class SWFResponse(BaseResponse): _type_args = self._params["{0}Type".format(kind)] name = _type_args["name"] version = _type_args["version"] + self._check_string(domain) + self._check_string(name) + self._check_string(version) self.swf_backend.deprecate_type(kind, domain, name, version) return "" # TODO: implement pagination def list_domains(self): status = self._params["registrationStatus"] + self._check_string(status) reverse_order = self._params.get("reverseOrder", None) domains = self.swf_backend.list_domains(status, reverse_order=reverse_order) return json.dumps({ @@ -58,17 +87,22 @@ class SWFResponse(BaseResponse): name = self._params["name"] retention = self._params["workflowExecutionRetentionPeriodInDays"] description = self._params.get("description") + self._check_string(retention) + self._check_string(name) + self._check_none_or_string(description) domain = self.swf_backend.register_domain(name, retention, description=description) return "" def deprecate_domain(self): name = self._params["name"] + self._check_string(name) domain = self.swf_backend.deprecate_domain(name) return "" def describe_domain(self): name = self._params["name"] + self._check_string(name) domain = self.swf_backend.describe_domain(name) return json.dumps(domain.to_full_dict()) @@ -90,6 +124,17 @@ class SWFResponse(BaseResponse): default_task_schedule_to_start_timeout = self._params.get("defaultTaskScheduleToStartTimeout") default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout") description = self._params.get("description") + + self._check_string(domain) + self._check_string(name) + self._check_string(version) + self._check_none_or_string(task_list) + self._check_none_or_string(default_task_heartbeat_timeout) + self._check_none_or_string(default_task_schedule_to_close_timeout) + self._check_none_or_string(default_task_schedule_to_start_timeout) + self._check_none_or_string(default_task_start_to_close_timeout) + self._check_none_or_string(description) + # TODO: add defaultTaskPriority when boto gets to support it activity_type = self.swf_backend.register_type( "activity", domain, name, version, task_list=task_list, @@ -123,6 +168,16 @@ class SWFResponse(BaseResponse): default_task_start_to_close_timeout = self._params.get("defaultTaskStartToCloseTimeout") default_execution_start_to_close_timeout = self._params.get("defaultExecutionStartToCloseTimeout") description = self._params.get("description") + + self._check_string(domain) + self._check_string(name) + self._check_string(version) + self._check_none_or_string(task_list) + self._check_none_or_string(default_child_policy) + self._check_none_or_string(default_task_start_to_close_timeout) + self._check_none_or_string(default_execution_start_to_close_timeout) + self._check_none_or_string(description) + # TODO: add defaultTaskPriority when boto gets to support it # TODO: add defaultLambdaRole when boto gets to support it workflow_type = self.swf_backend.register_type( @@ -157,6 +212,17 @@ class SWFResponse(BaseResponse): tag_list = self._params.get("tagList") task_start_to_close_timeout = self._params.get("taskStartToCloseTimeout") + self._check_string(domain) + self._check_string(workflow_id) + self._check_string(workflow_name) + self._check_string(workflow_version) + self._check_none_or_string(task_list) + self._check_none_or_string(child_policy) + self._check_none_or_string(execution_start_to_close_timeout) + self._check_none_or_string(input_) + self._check_none_or_list_of_strings(tag_list) + self._check_none_or_string(task_start_to_close_timeout) + wfe = self.swf_backend.start_workflow_execution( domain, workflow_id, workflow_name, workflow_version, task_list=task_list, child_policy=child_policy, @@ -175,6 +241,10 @@ class SWFResponse(BaseResponse): run_id = _workflow_execution["runId"] workflow_id = _workflow_execution["workflowId"] + self._check_string(domain_name) + self._check_string(run_id) + self._check_string(workflow_id) + wfe = self.swf_backend.describe_workflow_execution(domain_name, run_id, workflow_id) return json.dumps(wfe.to_full_dict()) @@ -195,6 +265,10 @@ class SWFResponse(BaseResponse): task_list = self._params["taskList"]["name"] identity = self._params.get("identity") reverse_order = self._params.get("reverseOrder", None) + + self._check_string(domain_name) + self._check_string(task_list) + decision = self.swf_backend.poll_for_decision_task( domain_name, task_list, identity=identity ) @@ -208,6 +282,8 @@ class SWFResponse(BaseResponse): def count_pending_decision_tasks(self): domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] + self._check_string(domain_name) + self._check_string(task_list) count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) @@ -215,6 +291,8 @@ class SWFResponse(BaseResponse): task_token = self._params["taskToken"] execution_context = self._params.get("executionContext") decisions = self._params.get("decisions") + self._check_string(task_token) + self._check_none_or_string(execution_context) self.swf_backend.respond_decision_task_completed( task_token, decisions=decisions, execution_context=execution_context ) @@ -224,6 +302,9 @@ class SWFResponse(BaseResponse): domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] identity = self._params.get("identity") + self._check_string(domain_name) + self._check_string(task_list) + self._check_none_or_string(identity) activity_task = self.swf_backend.poll_for_activity_task( domain_name, task_list, identity=identity ) @@ -237,12 +318,16 @@ class SWFResponse(BaseResponse): def count_pending_activity_tasks(self): domain_name = self._params["domain"] task_list = self._params["taskList"]["name"] + self._check_string(domain_name) + self._check_string(task_list) count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) def respond_activity_task_completed(self): task_token = self._params["taskToken"] result = self._params.get("result") + self._check_string(task_token) + self._check_none_or_string(result) self.swf_backend.respond_activity_task_completed( task_token, result=result ) @@ -252,6 +337,10 @@ class SWFResponse(BaseResponse): task_token = self._params["taskToken"] reason = self._params.get("reason") details = self._params.get("details") + self._check_string(task_token) + # TODO: implement length limits on reason and details (common pb with client libs) + self._check_none_or_string(reason) + self._check_none_or_string(details) self.swf_backend.respond_activity_task_failed( task_token, reason=reason, details=details ) @@ -264,6 +353,12 @@ class SWFResponse(BaseResponse): details = self._params.get("details") reason = self._params.get("reason") run_id = self._params.get("runId") + self._check_string(domain_name) + self._check_string(workflow_id) + self._check_none_or_string(child_policy) + self._check_none_or_string(details) + self._check_none_or_string(reason) + self._check_none_or_string(run_id) self.swf_backend.terminate_workflow_execution( domain_name, workflow_id, child_policy=child_policy, details=details, reason=reason, run_id=run_id @@ -273,6 +368,8 @@ class SWFResponse(BaseResponse): def record_activity_task_heartbeat(self): task_token = self._params["taskToken"] details = self._params.get("details") + self._check_string(task_token) + self._check_none_or_string(details) self.swf_backend.record_activity_task_heartbeat( task_token, details=details )