Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)

* Basic upport for Endpoints, EndpointConfigs and TrainingJobs

* Dropped extraneous pass statement.

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
jweite 2020-07-19 10:06:48 -04:00 committed by GitHub
commit ba99c61477
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 1007 additions and 6 deletions

View file

@ -1,5 +1,6 @@
from __future__ import unicode_literals
import os
from copy import deepcopy
from datetime import datetime
@ -32,6 +33,288 @@ class BaseObject(BaseModel):
return self.gen_response_object()
class FakeTrainingJob(BaseObject):
def __init__(
self,
region_name,
training_job_name,
hyper_parameters,
algorithm_specification,
role_arn,
input_data_config,
output_data_config,
resource_config,
vpc_config,
stopping_condition,
tags,
enable_network_isolation,
enable_inter_container_traffic_encryption,
enable_managed_spot_training,
checkpoint_config,
debug_hook_config,
debug_rule_configurations,
tensor_board_output_config,
experiment_config,
):
self.training_job_name = training_job_name
self.hyper_parameters = hyper_parameters
self.algorithm_specification = algorithm_specification
self.role_arn = role_arn
self.input_data_config = input_data_config
self.output_data_config = output_data_config
self.resource_config = resource_config
self.vpc_config = vpc_config
self.stopping_condition = stopping_condition
self.tags = tags
self.enable_network_isolation = enable_network_isolation
self.enable_inter_container_traffic_encryption = (
enable_inter_container_traffic_encryption
)
self.enable_managed_spot_training = enable_managed_spot_training
self.checkpoint_config = checkpoint_config
self.debug_hook_config = debug_hook_config
self.debug_rule_configurations = debug_rule_configurations
self.tensor_board_output_config = tensor_board_output_config
self.experiment_config = experiment_config
self.training_job_arn = FakeTrainingJob.arn_formatter(
training_job_name, region_name
)
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.model_artifacts = {
"S3ModelArtifacts": os.path.join(
self.output_data_config["S3OutputPath"],
self.training_job_name,
"output",
"model.tar.gz",
)
}
self.training_job_status = "Completed"
self.secondary_status = "Completed"
self.algorithm_specification["MetricDefinitions"] = [
{
"Name": "test:dcg",
"Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)",
}
]
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.creation_time = now_string
self.last_modified_time = now_string
self.training_start_time = now_string
self.training_end_time = now_string
self.secondary_status_transitions = [
{
"Status": "Starting",
"StartTime": self.creation_time,
"EndTime": self.creation_time,
"StatusMessage": "Preparing the instances for training",
}
]
self.final_metric_data_list = [
{
"MetricName": "train:progress",
"Value": 100.0,
"Timestamp": self.creation_time,
}
]
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"TrainingJobArn": self.training_job_arn}
@staticmethod
def arn_formatter(endpoint_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":training-job/"
+ endpoint_name
)
class FakeEndpoint(BaseObject):
def __init__(
self,
region_name,
endpoint_name,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
):
self.endpoint_name = endpoint_name
self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
self.endpoint_config_name = endpoint_config_name
self.production_variants = production_variants
self.data_capture_config = data_capture_config
self.tags = tags or []
self.endpoint_status = "InService"
self.failure_reason = None
self.creation_time = self.last_modified_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"EndpointArn": self.endpoint_arn}
@staticmethod
def arn_formatter(endpoint_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":endpoint/"
+ endpoint_name
)
class FakeEndpointConfig(BaseObject):
def __init__(
self,
region_name,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
kms_key_id,
):
self.validate_production_variants(production_variants)
self.endpoint_config_name = endpoint_config_name
self.endpoint_config_arn = FakeEndpointConfig.arn_formatter(
endpoint_config_name, region_name
)
self.production_variants = production_variants or []
self.data_capture_config = data_capture_config or {}
self.tags = tags or []
self.kms_key_id = kms_key_id
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def validate_production_variants(self, production_variants):
for production_variant in production_variants:
self.validate_instance_type(production_variant["InstanceType"])
def validate_instance_type(self, instance_type):
VALID_INSTANCE_TYPES = [
"ml.r5d.12xlarge",
"ml.r5.12xlarge",
"ml.p2.xlarge",
"ml.m5.4xlarge",
"ml.m4.16xlarge",
"ml.r5d.24xlarge",
"ml.r5.24xlarge",
"ml.p3.16xlarge",
"ml.m5d.xlarge",
"ml.m5.large",
"ml.t2.xlarge",
"ml.p2.16xlarge",
"ml.m5d.12xlarge",
"ml.inf1.2xlarge",
"ml.m5d.24xlarge",
"ml.c4.2xlarge",
"ml.c5.2xlarge",
"ml.c4.4xlarge",
"ml.inf1.6xlarge",
"ml.c5d.2xlarge",
"ml.c5.4xlarge",
"ml.g4dn.xlarge",
"ml.g4dn.12xlarge",
"ml.c5d.4xlarge",
"ml.g4dn.2xlarge",
"ml.c4.8xlarge",
"ml.c4.large",
"ml.c5d.xlarge",
"ml.c5.large",
"ml.g4dn.4xlarge",
"ml.c5.9xlarge",
"ml.g4dn.16xlarge",
"ml.c5d.large",
"ml.c5.xlarge",
"ml.c5d.9xlarge",
"ml.c4.xlarge",
"ml.inf1.xlarge",
"ml.g4dn.8xlarge",
"ml.inf1.24xlarge",
"ml.m5d.2xlarge",
"ml.t2.2xlarge",
"ml.c5d.18xlarge",
"ml.m5d.4xlarge",
"ml.t2.medium",
"ml.c5.18xlarge",
"ml.r5d.2xlarge",
"ml.r5.2xlarge",
"ml.p3.2xlarge",
"ml.m5d.large",
"ml.m5.xlarge",
"ml.m4.10xlarge",
"ml.t2.large",
"ml.r5d.4xlarge",
"ml.r5.4xlarge",
"ml.m5.12xlarge",
"ml.m4.xlarge",
"ml.m5.24xlarge",
"ml.m4.2xlarge",
"ml.p2.8xlarge",
"ml.m5.2xlarge",
"ml.r5d.xlarge",
"ml.r5d.large",
"ml.r5.xlarge",
"ml.r5.large",
"ml.p3.8xlarge",
"ml.m4.4xlarge",
]
if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
instance_type, VALID_INSTANCE_TYPES
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
@property
def response_object(self):
response_object = self.gen_response_object()
return {
k: v for k, v in response_object.items() if v is not None and v != [None]
}
@property
def response_create(self):
return {"EndpointConfigArn": self.endpoint_config_arn}
@staticmethod
def arn_formatter(model_name, region_name):
return (
"arn:aws:sagemaker:"
+ region_name
+ ":"
+ str(ACCOUNT_ID)
+ ":endpoint-config/"
+ model_name
)
class Model(BaseObject):
def __init__(
self,
@ -238,6 +521,9 @@ class SageMakerModelBackend(BaseBackend):
def __init__(self, region_name=None):
self._models = {}
self.notebook_instances = {}
self.endpoint_configs = {}
self.endpoints = {}
self.training_jobs = {}
self.region_name = region_name
def reset(self):
@ -305,10 +591,10 @@ class SageMakerModelBackend(BaseBackend):
self._validate_unique_notebook_instance_name(notebook_instance_name)
notebook_instance = FakeSagemakerNotebookInstance(
self.region_name,
notebook_instance_name,
instance_type,
role_arn,
region_name=self.region_name,
notebook_instance_name=notebook_instance_name,
instance_type=instance_type,
role_arn=role_arn,
subnet_id=subnet_id,
security_group_ids=security_group_ids,
kms_key_id=kms_key_id,
@ -392,6 +678,235 @@ class SageMakerModelBackend(BaseBackend):
except RESTError:
return []
def create_endpoint_config(
self,
endpoint_config_name,
production_variants,
data_capture_config,
tags,
kms_key_id,
):
endpoint_config = FakeEndpointConfig(
region_name=self.region_name,
endpoint_config_name=endpoint_config_name,
production_variants=production_variants,
data_capture_config=data_capture_config,
tags=tags,
kms_key_id=kms_key_id,
)
self.validate_production_variants(production_variants)
self.endpoint_configs[endpoint_config_name] = endpoint_config
return endpoint_config
def validate_production_variants(self, production_variants):
for production_variant in production_variants:
if production_variant["ModelName"] not in self._models:
message = "Could not find model '{}'.".format(
Model.arn_for_model_name(
production_variant["ModelName"], self.region_name
)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def describe_endpoint_config(self, endpoint_config_name):
try:
return self.endpoint_configs[endpoint_config_name].response_object
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_endpoint_config(self, endpoint_config_name):
try:
del self.endpoint_configs[endpoint_config_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def create_endpoint(
self, endpoint_name, endpoint_config_name, tags,
):
try:
endpoint_config = self.describe_endpoint_config(endpoint_config_name)
except KeyError:
message = "Could not find endpoint_config '{}'.".format(
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
endpoint = FakeEndpoint(
region_name=self.region_name,
endpoint_name=endpoint_name,
endpoint_config_name=endpoint_config_name,
production_variants=endpoint_config["ProductionVariants"],
data_capture_config=endpoint_config["DataCaptureConfig"],
tags=tags,
)
self.endpoints[endpoint_name] = endpoint
return endpoint
def describe_endpoint(self, endpoint_name):
try:
return self.endpoints[endpoint_name].response_object
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_endpoint(self, endpoint_name):
try:
del self.endpoints[endpoint_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_endpoint_by_arn(self, arn):
endpoints = [
endpoint
for endpoint in self.endpoints.values()
if endpoint.endpoint_arn == arn
]
if len(endpoints) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
return endpoints[0]
def get_endpoint_tags(self, arn):
try:
endpoint = self.get_endpoint_by_arn(arn)
return endpoint.tags or []
except RESTError:
return []
def create_training_job(
self,
training_job_name,
hyper_parameters,
algorithm_specification,
role_arn,
input_data_config,
output_data_config,
resource_config,
vpc_config,
stopping_condition,
tags,
enable_network_isolation,
enable_inter_container_traffic_encryption,
enable_managed_spot_training,
checkpoint_config,
debug_hook_config,
debug_rule_configurations,
tensor_board_output_config,
experiment_config,
):
training_job = FakeTrainingJob(
region_name=self.region_name,
training_job_name=training_job_name,
hyper_parameters=hyper_parameters,
algorithm_specification=algorithm_specification,
role_arn=role_arn,
input_data_config=input_data_config,
output_data_config=output_data_config,
resource_config=resource_config,
vpc_config=vpc_config,
stopping_condition=stopping_condition,
tags=tags,
enable_network_isolation=enable_network_isolation,
enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption,
enable_managed_spot_training=enable_managed_spot_training,
checkpoint_config=checkpoint_config,
debug_hook_config=debug_hook_config,
debug_rule_configurations=debug_rule_configurations,
tensor_board_output_config=tensor_board_output_config,
experiment_config=experiment_config,
)
self.training_jobs[training_job_name] = training_job
return training_job
def describe_training_job(self, training_job_name):
try:
return self.training_jobs[training_job_name].response_object
except KeyError:
message = "Could not find training job '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def delete_training_job(self, training_job_name):
try:
del self.training_jobs[training_job_name]
except KeyError:
message = "Could not find endpoint configuration '{}'.".format(
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
)
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
def get_training_job_by_arn(self, arn):
training_jobs = [
training_job
for training_job in self.training_jobs.values()
if training_job.training_job_arn == arn
]
if len(training_jobs) == 0:
message = "RecordNotFound"
raise RESTError(
error_type="ValidationException",
message=message,
template="error_json",
)
return training_jobs[0]
def get_training_job_tags(self, arn):
try:
training_job = self.get_training_job_by_arn(arn)
return training_job.tags or []
except RESTError:
return []
sagemaker_backends = {}
for region, ec2_backend in ec2_backends.items():

View file

@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def list_tags(self):
arn = self._get_param("ResourceArn")
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
try:
if ":notebook-instance/" in arn:
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
elif ":endpoint/" in arn:
tags = self.sagemaker_backend.get_endpoint_tags(arn)
elif ":training-job/" in arn:
tags = self.sagemaker_backend.get_training_job_tags(arn)
else:
tags = []
except AWSError:
tags = []
response = {"Tags": tags}
return 200, {}, json.dumps(response)
@amzn_request_id
def create_endpoint_config(self):
try:
endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config_name=self._get_param("EndpointConfigName"),
production_variants=self._get_param("ProductionVariants"),
data_capture_config=self._get_param("DataCaptureConfig"),
tags=self._get_param("Tags"),
kms_key_id=self._get_param("KmsKeyId"),
)
response = {
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_endpoint(self):
try:
endpoint = self.sagemaker_backend.create_endpoint(
endpoint_name=self._get_param("EndpointName"),
endpoint_config_name=self._get_param("EndpointConfigName"),
tags=self._get_param("Tags"),
)
response = {
"EndpointArn": endpoint.endpoint_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint(self):
endpoint_name = self._get_param("EndpointName")
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint(self):
endpoint_name = self._get_param("EndpointName")
self.sagemaker_backend.delete_endpoint(endpoint_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_training_job(self):
try:
training_job = self.sagemaker_backend.create_training_job(
training_job_name=self._get_param("TrainingJobName"),
hyper_parameters=self._get_param("HyperParameters"),
algorithm_specification=self._get_param("AlgorithmSpecification"),
role_arn=self._get_param("RoleArn"),
input_data_config=self._get_param("InputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"),
resource_config=self._get_param("ResourceConfig"),
vpc_config=self._get_param("VpcConfig"),
stopping_condition=self._get_param("StoppingCondition"),
tags=self._get_param("Tags"),
enable_network_isolation=self._get_param(
"EnableNetworkIsolation", False
),
enable_inter_container_traffic_encryption=self._get_param(
"EnableInterContainerTrafficEncryption", False
),
enable_managed_spot_training=self._get_param(
"EnableManagedSpotTraining", False
),
checkpoint_config=self._get_param("CheckpointConfig"),
debug_hook_config=self._get_param("DebugHookConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_training_job(self):
training_job_name = self._get_param("TrainingJobName")
response = self.sagemaker_backend.describe_training_job(training_job_name)
return json.dumps(response)
@amzn_request_id
def delete_training_job(self):
training_job_name = self._get_param("TrainingJobName")
self.sagemaker_backend.delete_training_job(training_job_name)
return 200, {}, json.dumps("{}")

View file

@ -3,7 +3,6 @@ from .responses import SageMakerResponse
url_bases = [
"https?://api.sagemaker.(.+).amazonaws.com",
"https?://api-fips.sagemaker.(.+).amazonaws.com",
]
url_paths = {