Add CloudFormation support for SageMaker Endpoint Configs and Endpoints (#3863)
* Create SageMaker EndpointConfig with CloudFormation Implement attributes for SM Endpoint Configs with CloudFormation Delete SM Endpoint Configs with CloudFormation Update SM Endpoint Configs with CloudFormation * Fix typos in SM CF Model update test and refactor helper function for CF stack outputs * Fixup weird commas in SM CF Test Configs from using black * Create SageMaker Endpoints with CloudFormation * Fix typos in SM CF update tests
This commit is contained in:
parent
f6dda54a6c
commit
9b3e932822
3 changed files with 450 additions and 54 deletions
|
|
@ -42,6 +42,14 @@ class TestConfig:
|
|||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
pass
|
||||
|
||||
def run_setup_procedure(self, sagemaker_client):
|
||||
"""Provides a method to set up resources with a SageMaker client.
|
||||
|
||||
Note: This procedure should be called while within a `mock_sagemaker`
|
||||
context so that no actual resources are created with the sagemaker_client.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class NotebookInstanceTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Notebook Instances."""
|
||||
|
|
@ -186,6 +194,131 @@ class ModelTestConfig(TestConfig):
|
|||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"],}},
|
||||
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "ModelName"]}},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
|
||||
class EndpointConfigTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Endpoint Configs."""
|
||||
|
||||
@property
|
||||
def resource_name(self):
|
||||
return "TestEndpointConfig"
|
||||
|
||||
@property
|
||||
def describe_function_name(self):
|
||||
return "describe_endpoint_config"
|
||||
|
||||
@property
|
||||
def name_parameter(self):
|
||||
return "EndpointConfigName"
|
||||
|
||||
@property
|
||||
def arn_parameter(self):
|
||||
return "EndpointConfigArn"
|
||||
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
num_production_variants = kwargs.get("num_production_variants", 1)
|
||||
|
||||
production_variants = [
|
||||
{
|
||||
"InitialInstanceCount": 1,
|
||||
"InitialVariantWeight": 1,
|
||||
"InstanceType": "ml.c4.xlarge",
|
||||
"ModelName": self.resource_name,
|
||||
"VariantName": "variant-name-{}".format(i),
|
||||
}
|
||||
for i in range(num_production_variants)
|
||||
]
|
||||
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
self.resource_name: {
|
||||
"Type": "AWS::SageMaker::EndpointConfig",
|
||||
"Properties": {"ProductionVariants": production_variants},
|
||||
},
|
||||
},
|
||||
}
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {
|
||||
"Value": {"Fn::GetAtt": [self.resource_name, "EndpointConfigName"]}
|
||||
},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
def run_setup_procedure(self, sagemaker_client):
|
||||
"""Adds Model that can be referenced in the CloudFormation template."""
|
||||
|
||||
sagemaker_client.create_model(
|
||||
ModelName=self.resource_name,
|
||||
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||
PrimaryContainer={
|
||||
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class EndpointTestConfig(TestConfig):
|
||||
"""Test configuration for SageMaker Endpoints."""
|
||||
|
||||
@property
|
||||
def resource_name(self):
|
||||
return "TestEndpoint"
|
||||
|
||||
@property
|
||||
def describe_function_name(self):
|
||||
return "describe_endpoint"
|
||||
|
||||
@property
|
||||
def name_parameter(self):
|
||||
return "EndpointName"
|
||||
|
||||
@property
|
||||
def arn_parameter(self):
|
||||
return "EndpointArn"
|
||||
|
||||
def get_cloudformation_template(self, include_outputs=True, **kwargs):
|
||||
endpoint_config_name = kwargs.get("endpoint_config_name", self.resource_name)
|
||||
|
||||
template = {
|
||||
"AWSTemplateFormatVersion": "2010-09-09",
|
||||
"Resources": {
|
||||
self.resource_name: {
|
||||
"Type": "AWS::SageMaker::Endpoint",
|
||||
"Properties": {"EndpointConfigName": endpoint_config_name},
|
||||
},
|
||||
},
|
||||
}
|
||||
if include_outputs:
|
||||
template["Outputs"] = {
|
||||
"Arn": {"Value": {"Ref": self.resource_name}},
|
||||
"Name": {"Value": {"Fn::GetAtt": [self.resource_name, "EndpointName"]}},
|
||||
}
|
||||
return json.dumps(template)
|
||||
|
||||
def run_setup_procedure(self, sagemaker_client):
|
||||
"""Adds Model and Endpoint Config that can be referenced in the CloudFormation template."""
|
||||
|
||||
sagemaker_client.create_model(
|
||||
ModelName=self.resource_name,
|
||||
ExecutionRoleArn="arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID),
|
||||
PrimaryContainer={
|
||||
"Image": "404615174143.dkr.ecr.us-east-2.amazonaws.com/linear-learner:1",
|
||||
},
|
||||
)
|
||||
sagemaker_client.create_endpoint_config(
|
||||
EndpointConfigName=self.resource_name,
|
||||
ProductionVariants=[
|
||||
{
|
||||
"InitialInstanceCount": 1,
|
||||
"InitialVariantWeight": 1,
|
||||
"InstanceType": "ml.c4.xlarge",
|
||||
"ModelName": self.resource_name,
|
||||
"VariantName": "variant-name-1",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue