Skip to content

Commit

Permalink
Update Inference Pipeline with Scikit-learn and Linear Learner notebo…
Browse files Browse the repository at this point in the history
…ok for SageMaker v2 API. Addresses #1891 (#1892)
  • Loading branch information
danielsiwiec authored Jan 4, 2021
1 parent 99dd54c commit 48a820a
Showing 1 changed file with 15 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@
" entry_point=script_path,\n",
" role=role,\n",
" framework_version=FRAMEWORK_VERSION,\n",
" train_instance_type=\"ml.c4.xlarge\",\n",
" instance_type=\"ml.c4.xlarge\",\n",
" sagemaker_session=sagemaker_session)\n"
]
},
Expand Down Expand Up @@ -398,8 +398,8 @@
"outputs": [],
"source": [
"import boto3\n",
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"ll_image = get_image_uri(boto3.Session().region_name, 'linear-learner')"
"from sagemaker.image_uris import retrieve\n",
"ll_image = retrieve('linear-learner', boto3.Session().region_name)"
]
},
{
Expand All @@ -414,17 +414,17 @@
"ll_estimator = sagemaker.estimator.Estimator(\n",
" ll_image,\n",
" role, \n",
" train_instance_count=1, \n",
" train_instance_type='ml.m4.2xlarge',\n",
" train_volume_size = 20,\n",
" train_max_run = 3600,\n",
" instance_count=1, \n",
" instance_type='ml.m4.2xlarge',\n",
" volume_size = 20,\n",
" max_run = 3600,\n",
" input_mode= 'File',\n",
" output_path=s3_ll_output_location,\n",
" sagemaker_session=sagemaker_session)\n",
"\n",
"ll_estimator.set_hyperparameters(feature_dim=10, predictor_type='regressor', mini_batch_size=32)\n",
"\n",
"ll_train_data = sagemaker.session.s3_input(\n",
"ll_train_data = sagemaker.inputs.TrainingInput(\n",
" preprocessed_train, \n",
" distribution='FullyReplicated',\n",
" content_type='text/csv', \n",
Expand Down Expand Up @@ -494,16 +494,15 @@
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.predictor import json_serializer, csv_serializer, json_deserializer, RealTimePredictor\n",
"from sagemaker.content_types import CONTENT_TYPE_CSV, CONTENT_TYPE_JSON\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker.serializers import CSVSerializer\n",
"\n",
"payload = 'M, 0.44, 0.365, 0.125, 0.516, 0.2155, 0.114, 0.155'\n",
"actual_rings = 10\n",
"predictor = RealTimePredictor(\n",
" endpoint=endpoint_name,\n",
"predictor = Predictor(\n",
" endpoint_name=endpoint_name,\n",
" sagemaker_session=sagemaker_session,\n",
" serializer=csv_serializer,\n",
" content_type=CONTENT_TYPE_CSV,\n",
" accept=CONTENT_TYPE_JSON)\n",
" serializer=CSVSerializer())\n",
"\n",
"print(predictor.predict(payload))"
]
Expand Down Expand Up @@ -544,7 +543,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.6.10"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 48a820a

Please sign in to comment.