Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)
* Basic upport for Endpoints, EndpointConfigs and TrainingJobs * Dropped extraneous pass statement. Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
a123a22eeb
commit
ba99c61477
5 changed files with 1007 additions and 6 deletions
|
|
@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
|
|||
@amzn_request_id
|
||||
def list_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
try:
|
||||
if ":notebook-instance/" in arn:
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
elif ":endpoint/" in arn:
|
||||
tags = self.sagemaker_backend.get_endpoint_tags(arn)
|
||||
elif ":training-job/" in arn:
|
||||
tags = self.sagemaker_backend.get_training_job_tags(arn)
|
||||
else:
|
||||
tags = []
|
||||
except AWSError:
|
||||
tags = []
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def create_endpoint_config(self):
|
||||
try:
|
||||
endpoint_config = self.sagemaker_backend.create_endpoint_config(
|
||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||
production_variants=self._get_param("ProductionVariants"),
|
||||
data_capture_config=self._get_param("DataCaptureConfig"),
|
||||
tags=self._get_param("Tags"),
|
||||
kms_key_id=self._get_param("KmsKeyId"),
|
||||
)
|
||||
response = {
|
||||
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_endpoint_config(self):
|
||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_endpoint_config(self):
|
||||
endpoint_config_name = self._get_param("EndpointConfigName")
|
||||
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_endpoint(self):
|
||||
try:
|
||||
endpoint = self.sagemaker_backend.create_endpoint(
|
||||
endpoint_name=self._get_param("EndpointName"),
|
||||
endpoint_config_name=self._get_param("EndpointConfigName"),
|
||||
tags=self._get_param("Tags"),
|
||||
)
|
||||
response = {
|
||||
"EndpointArn": endpoint.endpoint_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_endpoint(self):
|
||||
endpoint_name = self._get_param("EndpointName")
|
||||
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_endpoint(self):
|
||||
endpoint_name = self._get_param("EndpointName")
|
||||
self.sagemaker_backend.delete_endpoint(endpoint_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def create_training_job(self):
|
||||
try:
|
||||
training_job = self.sagemaker_backend.create_training_job(
|
||||
training_job_name=self._get_param("TrainingJobName"),
|
||||
hyper_parameters=self._get_param("HyperParameters"),
|
||||
algorithm_specification=self._get_param("AlgorithmSpecification"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
input_data_config=self._get_param("InputDataConfig"),
|
||||
output_data_config=self._get_param("OutputDataConfig"),
|
||||
resource_config=self._get_param("ResourceConfig"),
|
||||
vpc_config=self._get_param("VpcConfig"),
|
||||
stopping_condition=self._get_param("StoppingCondition"),
|
||||
tags=self._get_param("Tags"),
|
||||
enable_network_isolation=self._get_param(
|
||||
"EnableNetworkIsolation", False
|
||||
),
|
||||
enable_inter_container_traffic_encryption=self._get_param(
|
||||
"EnableInterContainerTrafficEncryption", False
|
||||
),
|
||||
enable_managed_spot_training=self._get_param(
|
||||
"EnableManagedSpotTraining", False
|
||||
),
|
||||
checkpoint_config=self._get_param("CheckpointConfig"),
|
||||
debug_hook_config=self._get_param("DebugHookConfig"),
|
||||
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
|
||||
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
|
||||
experiment_config=self._get_param("ExperimentConfig"),
|
||||
)
|
||||
response = {
|
||||
"TrainingJobArn": training_job.training_job_arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_training_job(self):
|
||||
training_job_name = self._get_param("TrainingJobName")
|
||||
response = self.sagemaker_backend.describe_training_job(training_job_name)
|
||||
return json.dumps(response)
|
||||
|
||||
@amzn_request_id
|
||||
def delete_training_job(self):
|
||||
training_job_name = self._get_param("TrainingJobName")
|
||||
self.sagemaker_backend.delete_training_job(training_job_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue