Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)
* Basic upport for Endpoints, EndpointConfigs and TrainingJobs * Dropped extraneous pass statement. Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
a123a22eeb
commit
ba99c61477
5 changed files with 1007 additions and 6 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
|
||||
|
|
@ -32,6 +33,288 @@ class BaseObject(BaseModel):
|
|||
return self.gen_response_object()
|
||||
|
||||
|
||||
class FakeTrainingJob(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
training_job_name,
|
||||
hyper_parameters,
|
||||
algorithm_specification,
|
||||
role_arn,
|
||||
input_data_config,
|
||||
output_data_config,
|
||||
resource_config,
|
||||
vpc_config,
|
||||
stopping_condition,
|
||||
tags,
|
||||
enable_network_isolation,
|
||||
enable_inter_container_traffic_encryption,
|
||||
enable_managed_spot_training,
|
||||
checkpoint_config,
|
||||
debug_hook_config,
|
||||
debug_rule_configurations,
|
||||
tensor_board_output_config,
|
||||
experiment_config,
|
||||
):
|
||||
self.training_job_name = training_job_name
|
||||
self.hyper_parameters = hyper_parameters
|
||||
self.algorithm_specification = algorithm_specification
|
||||
self.role_arn = role_arn
|
||||
self.input_data_config = input_data_config
|
||||
self.output_data_config = output_data_config
|
||||
self.resource_config = resource_config
|
||||
self.vpc_config = vpc_config
|
||||
self.stopping_condition = stopping_condition
|
||||
self.tags = tags
|
||||
self.enable_network_isolation = enable_network_isolation
|
||||
self.enable_inter_container_traffic_encryption = (
|
||||
enable_inter_container_traffic_encryption
|
||||
)
|
||||
self.enable_managed_spot_training = enable_managed_spot_training
|
||||
self.checkpoint_config = checkpoint_config
|
||||
self.debug_hook_config = debug_hook_config
|
||||
self.debug_rule_configurations = debug_rule_configurations
|
||||
self.tensor_board_output_config = tensor_board_output_config
|
||||
self.experiment_config = experiment_config
|
||||
self.training_job_arn = FakeTrainingJob.arn_formatter(
|
||||
training_job_name, region_name
|
||||
)
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
self.model_artifacts = {
|
||||
"S3ModelArtifacts": os.path.join(
|
||||
self.output_data_config["S3OutputPath"],
|
||||
self.training_job_name,
|
||||
"output",
|
||||
"model.tar.gz",
|
||||
)
|
||||
}
|
||||
self.training_job_status = "Completed"
|
||||
self.secondary_status = "Completed"
|
||||
self.algorithm_specification["MetricDefinitions"] = [
|
||||
{
|
||||
"Name": "test:dcg",
|
||||
"Regex": "#quality_metric: host=\\S+, test dcg <score>=(\\S+)",
|
||||
}
|
||||
]
|
||||
now_string = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.creation_time = now_string
|
||||
self.last_modified_time = now_string
|
||||
self.training_start_time = now_string
|
||||
self.training_end_time = now_string
|
||||
self.secondary_status_transitions = [
|
||||
{
|
||||
"Status": "Starting",
|
||||
"StartTime": self.creation_time,
|
||||
"EndTime": self.creation_time,
|
||||
"StatusMessage": "Preparing the instances for training",
|
||||
}
|
||||
]
|
||||
self.final_metric_data_list = [
|
||||
{
|
||||
"MetricName": "train:progress",
|
||||
"Value": 100.0,
|
||||
"Timestamp": self.creation_time,
|
||||
}
|
||||
]
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"TrainingJobArn": self.training_job_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(endpoint_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":training-job/"
|
||||
+ endpoint_name
|
||||
)
|
||||
|
||||
|
||||
class FakeEndpoint(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
endpoint_name,
|
||||
endpoint_config_name,
|
||||
production_variants,
|
||||
data_capture_config,
|
||||
tags,
|
||||
):
|
||||
self.endpoint_name = endpoint_name
|
||||
self.endpoint_arn = FakeEndpoint.arn_formatter(endpoint_name, region_name)
|
||||
self.endpoint_config_name = endpoint_config_name
|
||||
self.production_variants = production_variants
|
||||
self.data_capture_config = data_capture_config
|
||||
self.tags = tags or []
|
||||
self.endpoint_status = "InService"
|
||||
self.failure_reason = None
|
||||
self.creation_time = self.last_modified_time = datetime.now().strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"EndpointArn": self.endpoint_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(endpoint_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":endpoint/"
|
||||
+ endpoint_name
|
||||
)
|
||||
|
||||
|
||||
class FakeEndpointConfig(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
region_name,
|
||||
endpoint_config_name,
|
||||
production_variants,
|
||||
data_capture_config,
|
||||
tags,
|
||||
kms_key_id,
|
||||
):
|
||||
self.validate_production_variants(production_variants)
|
||||
|
||||
self.endpoint_config_name = endpoint_config_name
|
||||
self.endpoint_config_arn = FakeEndpointConfig.arn_formatter(
|
||||
endpoint_config_name, region_name
|
||||
)
|
||||
self.production_variants = production_variants or []
|
||||
self.data_capture_config = data_capture_config or {}
|
||||
self.tags = tags or []
|
||||
self.kms_key_id = kms_key_id
|
||||
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def validate_production_variants(self, production_variants):
|
||||
for production_variant in production_variants:
|
||||
self.validate_instance_type(production_variant["InstanceType"])
|
||||
|
||||
def validate_instance_type(self, instance_type):
|
||||
VALID_INSTANCE_TYPES = [
|
||||
"ml.r5d.12xlarge",
|
||||
"ml.r5.12xlarge",
|
||||
"ml.p2.xlarge",
|
||||
"ml.m5.4xlarge",
|
||||
"ml.m4.16xlarge",
|
||||
"ml.r5d.24xlarge",
|
||||
"ml.r5.24xlarge",
|
||||
"ml.p3.16xlarge",
|
||||
"ml.m5d.xlarge",
|
||||
"ml.m5.large",
|
||||
"ml.t2.xlarge",
|
||||
"ml.p2.16xlarge",
|
||||
"ml.m5d.12xlarge",
|
||||
"ml.inf1.2xlarge",
|
||||
"ml.m5d.24xlarge",
|
||||
"ml.c4.2xlarge",
|
||||
"ml.c5.2xlarge",
|
||||
"ml.c4.4xlarge",
|
||||
"ml.inf1.6xlarge",
|
||||
"ml.c5d.2xlarge",
|
||||
"ml.c5.4xlarge",
|
||||
"ml.g4dn.xlarge",
|
||||
"ml.g4dn.12xlarge",
|
||||
"ml.c5d.4xlarge",
|
||||
"ml.g4dn.2xlarge",
|
||||
"ml.c4.8xlarge",
|
||||
"ml.c4.large",
|
||||
"ml.c5d.xlarge",
|
||||
"ml.c5.large",
|
||||
"ml.g4dn.4xlarge",
|
||||
"ml.c5.9xlarge",
|
||||
"ml.g4dn.16xlarge",
|
||||
"ml.c5d.large",
|
||||
"ml.c5.xlarge",
|
||||
"ml.c5d.9xlarge",
|
||||
"ml.c4.xlarge",
|
||||
"ml.inf1.xlarge",
|
||||
"ml.g4dn.8xlarge",
|
||||
"ml.inf1.24xlarge",
|
||||
"ml.m5d.2xlarge",
|
||||
"ml.t2.2xlarge",
|
||||
"ml.c5d.18xlarge",
|
||||
"ml.m5d.4xlarge",
|
||||
"ml.t2.medium",
|
||||
"ml.c5.18xlarge",
|
||||
"ml.r5d.2xlarge",
|
||||
"ml.r5.2xlarge",
|
||||
"ml.p3.2xlarge",
|
||||
"ml.m5d.large",
|
||||
"ml.m5.xlarge",
|
||||
"ml.m4.10xlarge",
|
||||
"ml.t2.large",
|
||||
"ml.r5d.4xlarge",
|
||||
"ml.r5.4xlarge",
|
||||
"ml.m5.12xlarge",
|
||||
"ml.m4.xlarge",
|
||||
"ml.m5.24xlarge",
|
||||
"ml.m4.2xlarge",
|
||||
"ml.p2.8xlarge",
|
||||
"ml.m5.2xlarge",
|
||||
"ml.r5d.xlarge",
|
||||
"ml.r5d.large",
|
||||
"ml.r5.xlarge",
|
||||
"ml.r5.large",
|
||||
"ml.p3.8xlarge",
|
||||
"ml.m4.4xlarge",
|
||||
]
|
||||
if not validators.is_one_of(instance_type, VALID_INSTANCE_TYPES):
|
||||
message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: {}".format(
|
||||
instance_type, VALID_INSTANCE_TYPES
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
@property
|
||||
def response_object(self):
|
||||
response_object = self.gen_response_object()
|
||||
return {
|
||||
k: v for k, v in response_object.items() if v is not None and v != [None]
|
||||
}
|
||||
|
||||
@property
|
||||
def response_create(self):
|
||||
return {"EndpointConfigArn": self.endpoint_config_arn}
|
||||
|
||||
@staticmethod
|
||||
def arn_formatter(model_name, region_name):
|
||||
return (
|
||||
"arn:aws:sagemaker:"
|
||||
+ region_name
|
||||
+ ":"
|
||||
+ str(ACCOUNT_ID)
|
||||
+ ":endpoint-config/"
|
||||
+ model_name
|
||||
)
|
||||
|
||||
|
||||
class Model(BaseObject):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -238,6 +521,9 @@ class SageMakerModelBackend(BaseBackend):
|
|||
def __init__(self, region_name=None):
|
||||
self._models = {}
|
||||
self.notebook_instances = {}
|
||||
self.endpoint_configs = {}
|
||||
self.endpoints = {}
|
||||
self.training_jobs = {}
|
||||
self.region_name = region_name
|
||||
|
||||
def reset(self):
|
||||
|
|
@ -305,10 +591,10 @@ class SageMakerModelBackend(BaseBackend):
|
|||
self._validate_unique_notebook_instance_name(notebook_instance_name)
|
||||
|
||||
notebook_instance = FakeSagemakerNotebookInstance(
|
||||
self.region_name,
|
||||
notebook_instance_name,
|
||||
instance_type,
|
||||
role_arn,
|
||||
region_name=self.region_name,
|
||||
notebook_instance_name=notebook_instance_name,
|
||||
instance_type=instance_type,
|
||||
role_arn=role_arn,
|
||||
subnet_id=subnet_id,
|
||||
security_group_ids=security_group_ids,
|
||||
kms_key_id=kms_key_id,
|
||||
|
|
@ -392,6 +678,235 @@ class SageMakerModelBackend(BaseBackend):
|
|||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_endpoint_config(
|
||||
self,
|
||||
endpoint_config_name,
|
||||
production_variants,
|
||||
data_capture_config,
|
||||
tags,
|
||||
kms_key_id,
|
||||
):
|
||||
endpoint_config = FakeEndpointConfig(
|
||||
region_name=self.region_name,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
production_variants=production_variants,
|
||||
data_capture_config=data_capture_config,
|
||||
tags=tags,
|
||||
kms_key_id=kms_key_id,
|
||||
)
|
||||
self.validate_production_variants(production_variants)
|
||||
|
||||
self.endpoint_configs[endpoint_config_name] = endpoint_config
|
||||
return endpoint_config
|
||||
|
||||
def validate_production_variants(self, production_variants):
|
||||
for production_variant in production_variants:
|
||||
if production_variant["ModelName"] not in self._models:
|
||||
message = "Could not find model '{}'.".format(
|
||||
Model.arn_for_model_name(
|
||||
production_variant["ModelName"], self.region_name
|
||||
)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def describe_endpoint_config(self, endpoint_config_name):
|
||||
try:
|
||||
return self.endpoint_configs[endpoint_config_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def delete_endpoint_config(self, endpoint_config_name):
|
||||
try:
|
||||
del self.endpoint_configs[endpoint_config_name]
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def create_endpoint(
|
||||
self, endpoint_name, endpoint_config_name, tags,
|
||||
):
|
||||
try:
|
||||
endpoint_config = self.describe_endpoint_config(endpoint_config_name)
|
||||
except KeyError:
|
||||
message = "Could not find endpoint_config '{}'.".format(
|
||||
FakeEndpointConfig.arn_formatter(endpoint_config_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
endpoint = FakeEndpoint(
|
||||
region_name=self.region_name,
|
||||
endpoint_name=endpoint_name,
|
||||
endpoint_config_name=endpoint_config_name,
|
||||
production_variants=endpoint_config["ProductionVariants"],
|
||||
data_capture_config=endpoint_config["DataCaptureConfig"],
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
self.endpoints[endpoint_name] = endpoint
|
||||
return endpoint
|
||||
|
||||
def describe_endpoint(self, endpoint_name):
|
||||
try:
|
||||
return self.endpoints[endpoint_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def delete_endpoint(self, endpoint_name):
|
||||
try:
|
||||
del self.endpoints[endpoint_name]
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeEndpoint.arn_formatter(endpoint_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def get_endpoint_by_arn(self, arn):
|
||||
endpoints = [
|
||||
endpoint
|
||||
for endpoint in self.endpoints.values()
|
||||
if endpoint.endpoint_arn == arn
|
||||
]
|
||||
if len(endpoints) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
return endpoints[0]
|
||||
|
||||
def get_endpoint_tags(self, arn):
|
||||
try:
|
||||
endpoint = self.get_endpoint_by_arn(arn)
|
||||
return endpoint.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
def create_training_job(
|
||||
self,
|
||||
training_job_name,
|
||||
hyper_parameters,
|
||||
algorithm_specification,
|
||||
role_arn,
|
||||
input_data_config,
|
||||
output_data_config,
|
||||
resource_config,
|
||||
vpc_config,
|
||||
stopping_condition,
|
||||
tags,
|
||||
enable_network_isolation,
|
||||
enable_inter_container_traffic_encryption,
|
||||
enable_managed_spot_training,
|
||||
checkpoint_config,
|
||||
debug_hook_config,
|
||||
debug_rule_configurations,
|
||||
tensor_board_output_config,
|
||||
experiment_config,
|
||||
):
|
||||
training_job = FakeTrainingJob(
|
||||
region_name=self.region_name,
|
||||
training_job_name=training_job_name,
|
||||
hyper_parameters=hyper_parameters,
|
||||
algorithm_specification=algorithm_specification,
|
||||
role_arn=role_arn,
|
||||
input_data_config=input_data_config,
|
||||
output_data_config=output_data_config,
|
||||
resource_config=resource_config,
|
||||
vpc_config=vpc_config,
|
||||
stopping_condition=stopping_condition,
|
||||
tags=tags,
|
||||
enable_network_isolation=enable_network_isolation,
|
||||
enable_inter_container_traffic_encryption=enable_inter_container_traffic_encryption,
|
||||
enable_managed_spot_training=enable_managed_spot_training,
|
||||
checkpoint_config=checkpoint_config,
|
||||
debug_hook_config=debug_hook_config,
|
||||
debug_rule_configurations=debug_rule_configurations,
|
||||
tensor_board_output_config=tensor_board_output_config,
|
||||
experiment_config=experiment_config,
|
||||
)
|
||||
self.training_jobs[training_job_name] = training_job
|
||||
return training_job
|
||||
|
||||
def describe_training_job(self, training_job_name):
|
||||
try:
|
||||
return self.training_jobs[training_job_name].response_object
|
||||
except KeyError:
|
||||
message = "Could not find training job '{}'.".format(
|
||||
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def delete_training_job(self, training_job_name):
|
||||
try:
|
||||
del self.training_jobs[training_job_name]
|
||||
except KeyError:
|
||||
message = "Could not find endpoint configuration '{}'.".format(
|
||||
FakeTrainingJob.arn_formatter(training_job_name, self.region_name)
|
||||
)
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
|
||||
def get_training_job_by_arn(self, arn):
|
||||
training_jobs = [
|
||||
training_job
|
||||
for training_job in self.training_jobs.values()
|
||||
if training_job.training_job_arn == arn
|
||||
]
|
||||
if len(training_jobs) == 0:
|
||||
message = "RecordNotFound"
|
||||
raise RESTError(
|
||||
error_type="ValidationException",
|
||||
message=message,
|
||||
template="error_json",
|
||||
)
|
||||
return training_jobs[0]
|
||||
|
||||
def get_training_job_tags(self, arn):
|
||||
try:
|
||||
training_job = self.get_training_job_by_arn(arn)
|
||||
return training_job.tags or []
|
||||
except RESTError:
|
||||
return []
|
||||
|
||||
|
||||
sagemaker_backends = {}
|
||||
for region, ec2_backend in ec2_backends.items():
|
||||
|
|
|
|||
|
|
@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
|
|||
@amzn_request_id
|
||||
def list_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
try:
|
||||
if ":notebook-instance/" in arn:
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
elif ":endpoint/" in arn:
|
||||
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
||||
elif ":training-job/" in arn:
|
||||
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
||||
else:
|
||||
tags = []
|
||||
except AWSError:
|
||||
tags = []
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_endpoint_config(self):
|
||||
try:
|
||||
endpoint_config = self.sagemaker_backend.create_endpoint_config(
|
||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||
production_variants=self._get_param("ProductionVariants"),
|
||||
data_capture_config=self._get_param("DataCaptureConfig"),
|
||||
tags=self._get_param("Tags"),
|
||||
kms_key_id=self._get_param("KmsKeyId"),
|
||||
)
|
||||
response = {
|
||||
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_endpoint_config(self):
|
||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_endpoint_config(self):
|
||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_endpoint(self):
|
||||
try:
|
||||
endpoint = self.sagemaker_backend.create_endpoint(
|
||||
endpoint_name=self._get_param("EndpointName"),
|
||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||
tags=self._get_param("Tags"),
|
||||
)
|
||||
response = {
|
||||
"EndpointArn": endpoint.endpoint_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_endpoint(self):
|
||||
endpoint_name = self._get_param("EndpointName")
|
||||
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_endpoint(self):
|
||||
endpoint_name = self._get_param("EndpointName")
|
||||
self.sagemaker_backend.delete_endpoint(endpoint_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_training_job(self):
|
||||
try:
|
||||
training_job = self.sagemaker_backend.create_training_job(
|
||||
training_job_name=self._get_param("TrainingJobName"),
|
||||
hyper_parameters=self._get_param("HyperParameters"),
|
||||
algorithm_specification=self._get_param("AlgorithmSpecification"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
input_data_config=self._get_param("InputDataConfig"),
|
||||
output_data_config=self._get_param("OutputDataConfig"),
|
||||
resource_config=self._get_param("ResourceConfig"),
|
||||
vpc_config=self._get_param("VpcConfig"),
|
||||
stopping_condition=self._get_param("StoppingCondition"),
|
||||
tags=self._get_param("Tags"),
|
||||
enable_network_isolation=self._get_param(
|
||||
"EnableNetworkIsolation", False
|
||||
),
|
||||
enable_inter_container_traffic_encryption=self._get_param(
|
||||
"EnableInterContainerTrafficEncryption", False
|
||||
),
|
||||
enable_managed_spot_training=self._get_param(
|
||||
"EnableManagedSpotTraining", False
|
||||
),
|
||||
checkpoint_config=self._get_param("CheckpointConfig"),
|
||||
debug_hook_config=self._get_param("DebugHookConfig"),
|
||||
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
|
||||
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
|
||||
experiment_config=self._get_param("ExperimentConfig"),
|
||||
)
|
||||
response = {
|
||||
"TrainingJobArn": training_job.training_job_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_training_job(self):
|
||||
training_job_name = self._get_param("TrainingJobName")
|
||||
response = self.sagemaker_backend.describe_training_job(training_job_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_training_job(self):
|
||||
training_job_name = self._get_param("TrainingJobName")
|
||||
self.sagemaker_backend.delete_training_job(training_job_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from .responses import SageMakerResponse
|
|||
|
||||
url_bases = [
|
||||
"https?://api.sagemaker.(.+).amazonaws.com",
|
||||
"https?://api-fips.sagemaker.(.+).amazonaws.com",
|
||||
]
|
||||
|
||||
url_paths = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue