From 51df02e7cf922d17e3fcb34993ada166444f11cd Mon Sep 17 00:00:00 2001 From: Steve Pulec Date: Mon, 20 Feb 2017 14:31:19 -0500 Subject: [PATCH] Cleanup Server host parsing. --- moto/core/models.py | 23 +++++++++++++---------- moto/server.py | 12 +++++++++--- tests/test_core/test_server.py | 6 +++--- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/moto/core/models.py b/moto/core/models.py index 9570a86d..9675d514 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -3,6 +3,7 @@ from __future__ import absolute_import import functools import inspect +import os import re from moto.packages.responses import responses @@ -48,7 +49,9 @@ class BaseMockAWS(object): if self.__class__.nested_count < 0: raise RuntimeError('Called stop() before start().') - self.disable_patching() + + if self.__class__.nested_count == 0: + self.disable_patching() def decorate_callable(self, func, reset): def wrapper(*args, **kwargs): @@ -108,9 +111,8 @@ class HttprettyMockAWS(BaseMockAWS): ) def disable_patching(self): - if self.__class__.nested_count == 0: - HTTPretty.disable() - HTTPretty.reset() + HTTPretty.disable() + HTTPretty.reset() RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, @@ -142,14 +144,15 @@ class ResponsesMockAWS(BaseMockAWS): pattern['stream'] = True def disable_patching(self): - if self.__class__.nested_count == 0: - try: - responses.stop() - except AttributeError: - pass - responses.reset() + try: + responses.stop() + except AttributeError: + pass + responses.reset() + MockAWS = ResponsesMockAWS + class Model(type): def __new__(self, clsname, bases, namespace): cls = super(Model, self).__new__(self, clsname, bases, namespace) diff --git a/moto/server.py b/moto/server.py index 1780083d..321f5a9e 100644 --- a/moto/server.py +++ b/moto/server.py @@ -42,8 +42,14 @@ class DomainDispatcherApplication(object): raise RuntimeError('Invalid host: "%s"' % host) - def get_application(self, host): - host = host.split(':')[0] + def get_application(self, environ): + host = environ['HTTP_HOST'].split(':')[0] + if host == "localhost": + # Fall back to parsing auth header to find service + # ['Credential=sdffdsa', '20170220', 'us-east-1', 'sns', 'aws4_request'] + _, _, region, service, _ = environ['HTTP_AUTHORIZATION'].split(",")[0].split()[1].split("/") + host = "{service}.{region}.amazonaws.com".format(service=service, region=region) + with self.lock: backend = self.get_backend_for_host(host) app = self.app_instances.get(backend, None) @@ -53,7 +59,7 @@ class DomainDispatcherApplication(object): return app def __call__(self, environ, start_response): - backend_app = self.get_application(environ['HTTP_HOST']) + backend_app = self.get_application(environ) return backend_app(environ, start_response) diff --git a/tests/test_core/test_server.py b/tests/test_core/test_server.py index 3ee08465..a0fb328c 100644 --- a/tests/test_core/test_server.py +++ b/tests/test_core/test_server.py @@ -32,19 +32,19 @@ def test_port_argument(run_simple): def test_domain_dispatched(): dispatcher = DomainDispatcherApplication(create_backend_app) - backend_app = dispatcher.get_application("email.us-east1.amazonaws.com") + backend_app = dispatcher.get_application({"HTTP_HOST": "email.us-east1.amazonaws.com"}) keys = list(backend_app.view_functions.keys()) keys[0].should.equal('EmailResponse.dispatch') def test_domain_without_matches(): dispatcher = DomainDispatcherApplication(create_backend_app) - dispatcher.get_application.when.called_with("not-matching-anything.com").should.throw(RuntimeError) + dispatcher.get_application.when.called_with({"HTTP_HOST": "not-matching-anything.com"}).should.throw(RuntimeError) def test_domain_dispatched_with_service(): # If we pass a particular service, always return that. dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") - backend_app = dispatcher.get_application("s3.us-east1.amazonaws.com") + backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) keys = set(backend_app.view_functions.keys()) keys.should.contain('ResponseObject.key_response')