Basic Support for Endpoints, EndpointConfigs and TrainingJobs (#3142)
* Basic upport for Endpoints, EndpointConfigs and TrainingJobs * Dropped extraneous pass statement. Co-authored-by: Joseph Weitekamp <jweite@amazon.com>
This commit is contained in:
parent
a123a22eeb
commit
ba99c61477
5 changed files with 1007 additions and 6 deletions
246
tests/test_sagemaker/test_sagemaker_endpoint.py
Normal file
246
tests/test_sagemaker/test_sagemaker_endpoint.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, ParamValidationError
|
||||
import sure # noqa
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
from nose.tools import assert_true, assert_equal, assert_raises
|
||||
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
GENERIC_TAGS_PARAM = [
|
||||
{"Key": "newkey1", "Value": "newval1"},
|
||||
{"Key": "newkey2", "Value": "newval2"},
|
||||
]
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith("Could not find model")
|
||||
)
|
||||
|
||||
_create_model(sagemaker, model_name)
|
||||
resp = sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
)
|
||||
resp["EndpointConfigName"].should.equal(endpoint_config_name)
|
||||
resp["ProductionVariants"].should.equal(production_variants)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_endpoint_config():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
|
||||
resp = sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
)
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint_invalid_instance_type():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
instance_type = "InvalidInstanceType"
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": instance_type,
|
||||
},
|
||||
]
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
ProductionVariants=production_variants,
|
||||
)
|
||||
assert_equal(e.exception.response["Error"]["Code"], "ValidationException")
|
||||
expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
|
||||
instance_type
|
||||
)
|
||||
assert_true(expected_message in e.exception.response["Error"]["Message"])
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_endpoint():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
endpoint_name = "MyEndpoint"
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.create_endpoint(
|
||||
EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
|
||||
)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith(
|
||||
"Could not find endpoint configuration"
|
||||
)
|
||||
)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
|
||||
|
||||
resp = sagemaker.create_endpoint(
|
||||
EndpointName=endpoint_name,
|
||||
EndpointConfigName=endpoint_config_name,
|
||||
Tags=GENERIC_TAGS_PARAM,
|
||||
)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
|
||||
)
|
||||
resp["EndpointName"].should.equal(endpoint_name)
|
||||
resp["EndpointConfigName"].should.equal(endpoint_config_name)
|
||||
resp["EndpointStatus"].should.equal("InService")
|
||||
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
|
||||
resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")
|
||||
|
||||
resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
|
||||
assert_equal(resp["Tags"], GENERIC_TAGS_PARAM)
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_delete_endpoint():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
model_name = "MyModel"
|
||||
_create_model(sagemaker, model_name)
|
||||
|
||||
endpoint_config_name = "MyEndpointConfig"
|
||||
_create_endpoint_config(sagemaker, endpoint_config_name, model_name)
|
||||
|
||||
endpoint_name = "MyEndpoint"
|
||||
_create_endpoint(sagemaker, endpoint_name, endpoint_config_name)
|
||||
|
||||
sagemaker.delete_endpoint(EndpointName=endpoint_name)
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.describe_endpoint(EndpointName=endpoint_name)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
|
||||
)
|
||||
|
||||
with assert_raises(ClientError) as e:
|
||||
sagemaker.delete_endpoint(EndpointName=endpoint_name)
|
||||
assert_true(
|
||||
e.exception.response["Error"]["Message"].startswith("Could not find endpoint")
|
||||
)
|
||||
|
||||
|
||||
def _create_model(boto_client, model_name):
|
||||
resp = boto_client.create_model(
|
||||
ModelName=model_name,
|
||||
PrimaryContainer={
|
||||
"Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
|
||||
"ModelDataUrl": "s3://MyBucket/model.tar.gz",
|
||||
},
|
||||
ExecutionRoleArn=FAKE_ROLE_ARN,
|
||||
)
|
||||
assert_equal(resp["ResponseMetadata"]["HTTPStatusCode"], 200)
|
||||
|
||||
|
||||
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
|
||||
production_variants = [
|
||||
{
|
||||
"VariantName": "MyProductionVariant",
|
||||
"ModelName": model_name,
|
||||
"InitialInstanceCount": 1,
|
||||
"InstanceType": "ml.t2.medium",
|
||||
},
|
||||
]
|
||||
resp = boto_client.create_endpoint_config(
|
||||
EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
|
||||
)
|
||||
resp["EndpointConfigArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
|
||||
resp = boto_client.create_endpoint(
|
||||
EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
|
||||
)
|
||||
resp["EndpointArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
|
||||
)
|
||||
127
tests/test_sagemaker/test_sagemaker_training.py
Normal file
127
tests/test_sagemaker/test_sagemaker_training.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import boto3
|
||||
import datetime
|
||||
import sure # noqa
|
||||
|
||||
from moto import mock_sagemaker
|
||||
from moto.sts.models import ACCOUNT_ID
|
||||
from nose.tools import assert_true, assert_equal, assert_raises, assert_regexp_matches
|
||||
|
||||
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
|
||||
TEST_REGION_NAME = "us-east-1"
|
||||
|
||||
|
||||
@mock_sagemaker
|
||||
def test_create_training_job():
|
||||
sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
|
||||
|
||||
training_job_name = "MyTrainingJob"
|
||||
container = "382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1"
|
||||
bucket = "my-bucket"
|
||||
prefix = "sagemaker/DEMO-breast-cancer-prediction/"
|
||||
|
||||
params = {
|
||||
"RoleArn": FAKE_ROLE_ARN,
|
||||
"TrainingJobName": training_job_name,
|
||||
"AlgorithmSpecification": {
|
||||
"TrainingImage": container,
|
||||
"TrainingInputMode": "File",
|
||||
},
|
||||
"ResourceConfig": {
|
||||
"InstanceCount": 1,
|
||||
"InstanceType": "ml.c4.2xlarge",
|
||||
"VolumeSizeInGB": 10,
|
||||
},
|
||||
"InputDataConfig": [
|
||||
{
|
||||
"ChannelName": "train",
|
||||
"DataSource": {
|
||||
"S3DataSource": {
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3Uri": "s3://{}/{}/train/".format(bucket, prefix),
|
||||
"S3DataDistributionType": "ShardedByS3Key",
|
||||
}
|
||||
},
|
||||
"CompressionType": "None",
|
||||
"RecordWrapperType": "None",
|
||||
},
|
||||
{
|
||||
"ChannelName": "validation",
|
||||
"DataSource": {
|
||||
"S3DataSource": {
|
||||
"S3DataType": "S3Prefix",
|
||||
"S3Uri": "s3://{}/{}/validation/".format(bucket, prefix),
|
||||
"S3DataDistributionType": "FullyReplicated",
|
||||
}
|
||||
},
|
||||
"CompressionType": "None",
|
||||
"RecordWrapperType": "None",
|
||||
},
|
||||
],
|
||||
"OutputDataConfig": {"S3OutputPath": "s3://{}/{}/".format(bucket, prefix)},
|
||||
"HyperParameters": {
|
||||
"feature_dim": "30",
|
||||
"mini_batch_size": "100",
|
||||
"predictor_type": "regressor",
|
||||
"epochs": "10",
|
||||
"num_models": "32",
|
||||
"loss": "absolute_loss",
|
||||
},
|
||||
"StoppingCondition": {"MaxRuntimeInSeconds": 60 * 60},
|
||||
}
|
||||
|
||||
resp = sagemaker.create_training_job(**params)
|
||||
resp["TrainingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
|
||||
)
|
||||
|
||||
resp = sagemaker.describe_training_job(TrainingJobName=training_job_name)
|
||||
resp["TrainingJobName"].should.equal(training_job_name)
|
||||
resp["TrainingJobArn"].should.match(
|
||||
r"^arn:aws:sagemaker:.*:.*:training-job/{}$".format(training_job_name)
|
||||
)
|
||||
assert_true(
|
||||
resp["ModelArtifacts"]["S3ModelArtifacts"].startswith(
|
||||
params["OutputDataConfig"]["S3OutputPath"]
|
||||
)
|
||||
)
|
||||
assert_true(training_job_name in (resp["ModelArtifacts"]["S3ModelArtifacts"]))
|
||||
assert_true(
|
||||
resp["ModelArtifacts"]["S3ModelArtifacts"].endswith("output/model.tar.gz")
|
||||
)
|
||||
assert_equal(resp["TrainingJobStatus"], "Completed")
|
||||
assert_equal(resp["SecondaryStatus"], "Completed")
|
||||
assert_equal(resp["HyperParameters"], params["HyperParameters"])
|
||||
assert_equal(
|
||||
resp["AlgorithmSpecification"]["TrainingImage"],
|
||||
params["AlgorithmSpecification"]["TrainingImage"],
|
||||
)
|
||||
assert_equal(
|
||||
resp["AlgorithmSpecification"]["TrainingInputMode"],
|
||||
params["AlgorithmSpecification"]["TrainingInputMode"],
|
||||
)
|
||||
assert_true("MetricDefinitions" in resp["AlgorithmSpecification"])
|
||||
assert_true("Name" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
|
||||
assert_true("Regex" in resp["AlgorithmSpecification"]["MetricDefinitions"][0])
|
||||
assert_equal(resp["RoleArn"], FAKE_ROLE_ARN)
|
||||
assert_equal(resp["InputDataConfig"], params["InputDataConfig"])
|
||||
assert_equal(resp["OutputDataConfig"], params["OutputDataConfig"])
|
||||
assert_equal(resp["ResourceConfig"], params["ResourceConfig"])
|
||||
assert_equal(resp["StoppingCondition"], params["StoppingCondition"])
|
||||
assert_true(isinstance(resp["CreationTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["TrainingStartTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["TrainingEndTime"], datetime.datetime))
|
||||
assert_true(isinstance(resp["LastModifiedTime"], datetime.datetime))
|
||||
assert_true("SecondaryStatusTransitions" in resp)
|
||||
assert_true("Status" in resp["SecondaryStatusTransitions"][0])
|
||||
assert_true("StartTime" in resp["SecondaryStatusTransitions"][0])
|
||||
assert_true("EndTime" in resp["SecondaryStatusTransitions"][0])
|
||||
assert_true("StatusMessage" in resp["SecondaryStatusTransitions"][0])
|
||||
assert_true("FinalMetricDataList" in resp)
|
||||
assert_true("MetricName" in resp["FinalMetricDataList"][0])
|
||||
assert_true("Value" in resp["FinalMetricDataList"][0])
|
||||
assert_true("Timestamp" in resp["FinalMetricDataList"][0])
|
||||
|
||||
pass
|
||||
Loading…
Add table
Add a link
Reference in a new issue