Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)

* Basic upport for Endpoints, EndpointConfigs and TrainingJobs

* Dropped extraneous pass statement.

Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
jweite 2020-07-19 10:06:48 -04:00 committed by GitHub
commit ba99c61477
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 1007 additions and 6 deletions

View file

@ -122,6 +122,120 @@ class SageMakerResponse(BaseResponse):
@amzn_request_id
def list_tags(self):
arn = self._get_param("ResourceArn")
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
try:
if ":notebook-instance/" in arn:
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
elif ":endpoint/" in arn:
tags = self.sagemaker_backend.get_endpoint_tags(arn)
elif ":training-job/" in arn:
tags = self.sagemaker_backend.get_training_job_tags(arn)
else:
tags = []
except AWSError:
tags = []
response = {"Tags": tags}
return 200, {}, json.dumps(response)
@amzn_request_id
def create_endpoint_config(self):
try:
endpoint_config = self.sagemaker_backend.create_endpoint_config(
endpoint_config_name=self._get_param("EndpointConfigName"),
production_variants=self._get_param("ProductionVariants"),
data_capture_config=self._get_param("DataCaptureConfig"),
tags=self._get_param("Tags"),
kms_key_id=self._get_param("KmsKeyId"),
)
response = {
"EndpointConfigArn": endpoint_config.endpoint_config_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
response = self.sagemaker_backend.describe_endpoint_config(endpoint_config_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint_config(self):
endpoint_config_name = self._get_param("EndpointConfigName")
self.sagemaker_backend.delete_endpoint_config(endpoint_config_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_endpoint(self):
try:
endpoint = self.sagemaker_backend.create_endpoint(
endpoint_name=self._get_param("EndpointName"),
endpoint_config_name=self._get_param("EndpointConfigName"),
tags=self._get_param("Tags"),
)
response = {
"EndpointArn": endpoint.endpoint_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_endpoint(self):
endpoint_name = self._get_param("EndpointName")
response = self.sagemaker_backend.describe_endpoint(endpoint_name)
return json.dumps(response)
@amzn_request_id
def delete_endpoint(self):
endpoint_name = self._get_param("EndpointName")
self.sagemaker_backend.delete_endpoint(endpoint_name)
return 200, {}, json.dumps("{}")
@amzn_request_id
def create_training_job(self):
try:
training_job = self.sagemaker_backend.create_training_job(
training_job_name=self._get_param("TrainingJobName"),
hyper_parameters=self._get_param("HyperParameters"),
algorithm_specification=self._get_param("AlgorithmSpecification"),
role_arn=self._get_param("RoleArn"),
input_data_config=self._get_param("InputDataConfig"),
output_data_config=self._get_param("OutputDataConfig"),
resource_config=self._get_param("ResourceConfig"),
vpc_config=self._get_param("VpcConfig"),
stopping_condition=self._get_param("StoppingCondition"),
tags=self._get_param("Tags"),
enable_network_isolation=self._get_param(
"EnableNetworkIsolation", False
),
enable_inter_container_traffic_encryption=self._get_param(
"EnableInterContainerTrafficEncryption", False
),
enable_managed_spot_training=self._get_param(
"EnableManagedSpotTraining", False
),
checkpoint_config=self._get_param("CheckpointConfig"),
debug_hook_config=self._get_param("DebugHookConfig"),
debug_rule_configurations=self._get_param("DebugRuleConfigurations"),
tensor_board_output_config=self._get_param("TensorBoardOutputConfig"),
experiment_config=self._get_param("ExperimentConfig"),
)
response = {
"TrainingJobArn": training_job.training_job_arn,
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_training_job(self):
training_job_name = self._get_param("TrainingJobName")
response = self.sagemaker_backend.describe_training_job(training_job_name)
return json.dumps(response)
@amzn_request_id
def delete_training_job(self):
training_job_name = self._get_param("TrainingJobName")
self.sagemaker_backend.delete_training_job(training_job_name)
return 200, {}, json.dumps("{}")