moto/tests/test_sagemaker/cloudformation_test_configs.py
Zach Churchill f6dda54a6c
Add CloudFormation support for SageMaker Models (#3861)
* Create a formal interface for SM Cloudformation test configurations

* Create SageMaker Models with CloudFormation

* Utilize six for adding metaclass to TestConfig

* Update SM backend to return Model objects instead of response objects
2021-04-16 15:23:05 +01:00

191 lines
5.5 KiB
Python

import json
from abc import ABCMeta, abstractmethod
import six
from moto.sts.models import ACCOUNT_ID
@six.add_metaclass(ABCMeta)
class TestConfig:
    """Abstract description of one SageMaker resource type under CloudFormation test.

    Every concrete subclass bundles the handful of facts the SageMaker
    CloudFormation tests (`test_sagemaker_cloudformation.py`) need about a
    resource type. Passing one config object through
    `pytest.mark.parametrize` keeps those tests readable, instead of
    threading many loose values through the decorator.
    """

    @property
    @abstractmethod
    def resource_name(self):
        """Logical ID given to the resource in the generated template."""

    @property
    @abstractmethod
    def describe_function_name(self):
        """Name of the describe call for this resource (e.g. ``describe_model``)."""

    @property
    @abstractmethod
    def name_parameter(self):
        """Field identifying the resource by name (e.g. ``ModelName``)."""

    @property
    @abstractmethod
    def arn_parameter(self):
        """Field holding the resource's ARN (e.g. ``ModelArn``)."""

    @abstractmethod
    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Produce a CloudFormation template declaring this resource.

        Args:
            include_outputs: whether the template should also declare stack
                outputs for the resource.
            **kwargs: resource-specific overrides; each subclass documents
                the keys it understands.

        Returns:
            The template, presumably serialized as a JSON string; every
            subclass in this module returns ``json.dumps(...)``.
        """
class NotebookInstanceTestConfig(TestConfig):
    """Test configuration for SageMaker Notebook Instances."""

    @property
    def resource_name(self):
        """Logical ID of the notebook instance inside the template."""
        return "TestNotebook"

    @property
    def describe_function_name(self):
        """Describe call used to look the notebook instance up."""
        return "describe_notebook_instance"

    @property
    def name_parameter(self):
        """Field carrying the notebook instance's name."""
        return "NotebookInstanceName"

    @property
    def arn_parameter(self):
        """Field carrying the notebook instance's ARN."""
        return "NotebookInstanceArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Build a template declaring one ``AWS::SageMaker::NotebookInstance``.

        Args:
            include_outputs: when true, add an ``Arn`` output (via ``Ref``)
                and a ``Name`` output (via ``Fn::GetAtt`` on
                ``NotebookInstanceName``).
            **kwargs: optional ``instance_type`` (defaults to
                ``ml.c4.xlarge``) and ``role_arn`` (defaults to a fake IAM
                role in ``ACCOUNT_ID``); any other keys are ignored.

        Returns:
            The template serialized with ``json.dumps``.
        """
        logical_id = self.resource_name
        notebook_properties = {
            "InstanceType": kwargs.get("instance_type", "ml.c4.xlarge"),
            "RoleArn": kwargs.get(
                "role_arn", "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
            ),
        }
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::NotebookInstance",
                    "Properties": notebook_properties,
                },
            },
        }
        if include_outputs:
            name_attribute = {"Fn::GetAtt": [logical_id, "NotebookInstanceName"]}
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": name_attribute},
            }
        return json.dumps(template)
class NotebookInstanceLifecycleConfigTestConfig(TestConfig):
    """Test configuration for SageMaker Notebook Instance Lifecycle Configs."""

    @property
    def resource_name(self):
        """Logical ID of the lifecycle config inside the template."""
        return "TestNotebookLifecycleConfig"

    @property
    def describe_function_name(self):
        """Describe call used to look the lifecycle config up."""
        return "describe_notebook_instance_lifecycle_config"

    @property
    def name_parameter(self):
        """Field carrying the lifecycle config's name."""
        return "NotebookInstanceLifecycleConfigName"

    @property
    def arn_parameter(self):
        """Field carrying the lifecycle config's ARN."""
        return "NotebookInstanceLifecycleConfigArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Build a template declaring one ``AWS::SageMaker::NotebookInstanceLifecycleConfig``.

        Args:
            include_outputs: when true, add an ``Arn`` output (via ``Ref``)
                and a ``Name`` output (via ``Fn::GetAtt`` on
                ``NotebookInstanceLifecycleConfigName``).
            **kwargs: optional ``on_create`` / ``on_start`` script contents.
                A hook is declared only when its value is not ``None``; any
                other keys are ignored.

        Returns:
            The template serialized with ``json.dumps``.
        """
        logical_id = self.resource_name
        config_properties = {}
        # (kwarg, CloudFormation property) pairs, in the order they are emitted.
        hooks = (("on_create", "OnCreate"), ("on_start", "OnStart"))
        for kwarg_name, property_name in hooks:
            script_content = kwargs.get(kwarg_name)
            if script_content is not None:
                config_properties[property_name] = [{"Content": script_content}]
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::NotebookInstanceLifecycleConfig",
                    "Properties": config_properties,
                },
            },
        }
        if include_outputs:
            name_attribute = {
                "Fn::GetAtt": [logical_id, "NotebookInstanceLifecycleConfigName"]
            }
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": name_attribute},
            }
        return json.dumps(template)
class ModelTestConfig(TestConfig):
    """Test configuration for SageMaker Models."""

    @property
    def resource_name(self):
        """Logical ID of the model inside the template."""
        return "TestModel"

    @property
    def describe_function_name(self):
        """Describe call used to look the model up."""
        return "describe_model"

    @property
    def name_parameter(self):
        """Field carrying the model's name."""
        return "ModelName"

    @property
    def arn_parameter(self):
        """Field carrying the model's ARN."""
        return "ModelArn"

    def get_cloudformation_template(self, include_outputs=True, **kwargs):
        """Build a template declaring one ``AWS::SageMaker::Model``.

        Args:
            include_outputs: when true, add an ``Arn`` output (via ``Ref``)
                and a ``Name`` output (via ``Fn::GetAtt`` on ``ModelName``).
            **kwargs: optional ``execution_role_arn`` (defaults to a fake IAM
                role in ``ACCOUNT_ID``) and ``image`` (defaults to a
                linear-learner ECR image URI); any other keys are ignored.

        Returns:
            The template serialized with ``json.dumps``.
        """
        fallback_role = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
        fallback_image = (
            "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1"
        )
        logical_id = self.resource_name
        model_properties = {
            "ExecutionRoleArn": kwargs.get("execution_role_arn", fallback_role),
            "PrimaryContainer": {"Image": kwargs.get("image", fallback_image)},
        }
        template = {
            "AWSTemplateFormatVersion": "2010-09-09",
            "Resources": {
                logical_id: {
                    "Type": "AWS::SageMaker::Model",
                    "Properties": model_properties,
                },
            },
        }
        if include_outputs:
            template["Outputs"] = {
                "Arn": {"Value": {"Ref": logical_id}},
                "Name": {"Value": {"Fn::GetAtt": [logical_id, "ModelName"]}},
            }
        return json.dumps(template)