Sagemaker models (#3105)
* First failing test, and enough framework to run it. * Rudimentary passing test. * Sagemaker Notebook Support, take-1: create, describe, start, stop, delete. * Added list_tags. * Merged in model support from https://github.com/porthunt/moto/tree/sagemaker-support. * Re-org'd * Fixed up describe_model exception when no matching model. * Segregated tests by Sagemaker entity. Model arn check by regex.. * Python2 compabitility changes. * Added sagemaker to list of known backends. Corrected urls. * Added sagemaker special case to moto.server.infer_service_region_host due to irregular url format (use of 'api' subdomain) to support server mode. * Changes for PR 3105 comments of July 10, 2020 * PR3105 July 10, 2020, 8:55 AM EDT comment: dropped unnecessary re-addition of arn when formulating model list response. * PR 3105 July 15, 2020 9:10 AM EDT Comment: clean-up SageMakerModelBackend.describe_models logic for finding the model in the dict. * Optimized imports Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
3e2a5e7ee8
commit
1b80b0a810
12 changed files with 963 additions and 0 deletions
127
moto/sagemaker/responses.py
Normal file
127
moto/sagemaker/responses.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import json
|
||||
|
||||
from moto.core.responses import BaseResponse
|
||||
from moto.core.utils import amzn_request_id
|
||||
from .exceptions import AWSError
|
||||
from .models import sagemaker_backends
|
||||
|
||||
|
||||
class SageMakerResponse(BaseResponse):
|
||||
@property
|
||||
def sagemaker_backend(self):
|
||||
return sagemaker_backends[self.region]
|
||||
|
||||
@property
|
||||
def request_params(self):
|
||||
try:
|
||||
return json.loads(self.body)
|
||||
except ValueError:
|
||||
return {}
|
||||
|
||||
def describe_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
response = self.sagemaker_backend.describe_model(model_name)
|
||||
return json.dumps(response)
|
||||
|
||||
def create_model(self):
|
||||
response = self.sagemaker_backend.create_model(**self.request_params)
|
||||
return json.dumps(response)
|
||||
|
||||
def delete_model(self):
|
||||
model_name = self._get_param("ModelName")
|
||||
response = self.sagemaker_backend.delete_model(model_name)
|
||||
return json.dumps(response)
|
||||
|
||||
def list_models(self):
|
||||
response = self.sagemaker_backend.list_models(**self.request_params)
|
||||
return json.dumps(response)
|
||||
|
||||
def _get_param(self, param, if_none=None):
|
||||
return self.request_params.get(param, if_none)
|
||||
|
||||
@amzn_request_id
|
||||
def create_notebook_instance(self):
|
||||
try:
|
||||
sagemaker_notebook = self.sagemaker_backend.create_notebook_instance(
|
||||
notebook_instance_name=self._get_param("NotebookInstanceName"),
|
||||
instance_type=self._get_param("InstanceType"),
|
||||
subnet_id=self._get_param("SubnetId"),
|
||||
security_group_ids=self._get_param("SecurityGroupIds"),
|
||||
role_arn=self._get_param("RoleArn"),
|
||||
kms_key_id=self._get_param("KmsKeyId"),
|
||||
tags=self._get_param("Tags"),
|
||||
lifecycle_config_name=self._get_param("LifecycleConfigName"),
|
||||
direct_internet_access=self._get_param("DirectInternetAccess"),
|
||||
volume_size_in_gb=self._get_param("VolumeSizeInGB"),
|
||||
accelerator_types=self._get_param("AcceleratorTypes"),
|
||||
default_code_repository=self._get_param("DefaultCodeRepository"),
|
||||
additional_code_repositories=self._get_param(
|
||||
"AdditionalCodeRepositories"
|
||||
),
|
||||
root_access=self._get_param("RootAccess"),
|
||||
)
|
||||
response = {
|
||||
"NotebookInstanceArn": sagemaker_notebook.arn,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def describe_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
try:
|
||||
notebook_instance = self.sagemaker_backend.get_notebook_instance(
|
||||
notebook_instance_name
|
||||
)
|
||||
response = {
|
||||
"NotebookInstanceArn": notebook_instance.arn,
|
||||
"NotebookInstanceName": notebook_instance.notebook_instance_name,
|
||||
"NotebookInstanceStatus": notebook_instance.status,
|
||||
"Url": notebook_instance.url,
|
||||
"InstanceType": notebook_instance.instance_type,
|
||||
"SubnetId": notebook_instance.subnet_id,
|
||||
"SecurityGroups": notebook_instance.security_group_ids,
|
||||
"RoleArn": notebook_instance.role_arn,
|
||||
"KmsKeyId": notebook_instance.kms_key_id,
|
||||
# ToDo: NetworkInterfaceId
|
||||
"LastModifiedTime": str(notebook_instance.last_modified_time),
|
||||
"CreationTime": str(notebook_instance.creation_time),
|
||||
"NotebookInstanceLifecycleConfigName": notebook_instance.lifecycle_config_name,
|
||||
"DirectInternetAccess": notebook_instance.direct_internet_access,
|
||||
"VolumeSizeInGB": notebook_instance.volume_size_in_gb,
|
||||
"AcceleratorTypes": notebook_instance.accelerator_types,
|
||||
"DefaultCodeRepository": notebook_instance.default_code_repository,
|
||||
"AdditionalCodeRepositories": notebook_instance.additional_code_repositories,
|
||||
"RootAccess": notebook_instance.root_access,
|
||||
}
|
||||
return 200, {}, json.dumps(response)
|
||||
except AWSError as err:
|
||||
return err.response()
|
||||
|
||||
@amzn_request_id
|
||||
def start_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.start_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def stop_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.stop_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def delete_notebook_instance(self):
|
||||
notebook_instance_name = self._get_param("NotebookInstanceName")
|
||||
self.sagemaker_backend.delete_notebook_instance(notebook_instance_name)
|
||||
return 200, {}, json.dumps("{}")
|
||||
|
||||
@amzn_request_id
|
||||
def list_tags(self):
|
||||
arn = self._get_param("ResourceArn")
|
||||
tags = self.sagemaker_backend.get_notebook_instance_tags(arn)
|
||||
response = {"Tags": tags}
|
||||
return 200, {}, json.dumps(response)
|
||||
Loading…
Add table
Add a link
Reference in a new issue