Skip to content

Commit

Permalink
Updated hpo_xgboost_direct_marketing_sagemaker_python_sdk.ipynb for SageMaker SDK v2 (#1898)
Browse files Browse the repository at this point in the history

* Updated hpo_xgboost_direct_marketing_sagemaker_python_sdk.ipynb for SageMaker SDK v2

The notebook needed updates to be compatible with v2 of the SageMaker Python SDK, which is now the default SDK version in SageMaker.

Issue #, if available:

Description of changes:
The notebook needed updates to be compatible with v2 of the SageMaker Python SDK, which is now the default SDK version in SageMaker.

Fixed:
- Use sagemaker.inputs.TrainingInput instead of sagemaker.s3_input.
- Use sagemaker.image_uris.retrieve instead of get_image_uri.
- Use the Estimator arguments instance_count and instance_type instead of train_instance_count and train_instance_type.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

* Update hpo_xgboost_direct_marketing_sagemaker_python_sdk.ipynb

Removed redundant get_image_uri import.
  • Loading branch information
eitansela authored Jan 4, 2021
1 parent f4cee69 commit 99dd54c
Showing 1 changed file with 21 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"isConfigCell": true
},
"outputs": [],
"source": [
"import sagemaker\n",
"import boto3\n",
"from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner\n",
"from sagemaker.inputs import TrainingInput\n",
"\n",
"import numpy as np # For matrix operations and numerical processing\n",
"import pandas as pd # For munging tabular data\n",
Expand Down Expand Up @@ -84,9 +84,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"!wget -N https://archive.ics.uci.edu/ml/machine-learning-databases/00222/bank-additional.zip\n",
Expand All @@ -103,9 +101,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('./bank-additional/bank-additional-full.csv', sep=';')\n",
Expand Down Expand Up @@ -192,9 +188,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"data['no_previous_contact'] = np.where(data['pdays'] == 999, 1, 0) # Indicator variable to capture when pdays takes a value of 999\n",
Expand All @@ -217,9 +211,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"model_data = model_data.drop(['duration', 'emp.var.rate', 'cons.price.idx', 'cons.conf.idx', 'euribor3m', 'nr.employed'], axis=1)"
Expand All @@ -237,9 +229,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"train_data, validation_data, test_data = np.split(model_data.sample(frac=1, random_state=1729), [int(0.7 * len(model_data)), int(0.9*len(model_data))]) \n",
Expand All @@ -259,9 +249,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"boto3.Session().resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train/train.csv')).upload_file('train.csv')\n",
Expand Down Expand Up @@ -292,21 +280,17 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"from sagemaker.amazon.amazon_estimator import get_image_uri\n",
"\n",
"sess = sagemaker.Session()\n",
"\n",
"container = get_image_uri(region, 'xgboost', repo_version='latest')\n",
"container = sagemaker.image_uris.retrieve('xgboost', boto3.Session().region_name, 'latest')\n",
"\n",
"xgb = sagemaker.estimator.Estimator(container,\n",
" role, \n",
" train_instance_count=1, \n",
" train_instance_type='ml.m4.xlarge',\n",
" instance_count=1, \n",
" instance_type='ml.m4.xlarge',\n",
" output_path='s3://{}/{}/output'.format(bucket, prefix),\n",
" sagemaker_session=sess)\n",
"\n",
Expand All @@ -331,9 +315,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"hyperparameter_ranges = {'eta': ContinuousParameter(0, 1),\n",
Expand All @@ -352,9 +334,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"objective_metric_name = 'validation:auc'"
Expand All @@ -374,9 +354,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"tuner = HyperparameterTuner(xgb,\n",
Expand All @@ -397,13 +375,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"s3_input_train = sagemaker.s3_input(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
"s3_input_validation = sagemaker.s3_input(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')\n",
"s3_input_train = TrainingInput(s3_data='s3://{}/{}/train'.format(bucket, prefix), content_type='csv')\n",
"s3_input_validation = TrainingInput(s3_data='s3://{}/{}/validation/'.format(bucket, prefix), content_type='csv')\n",
"\n",
"tuner.fit({'train': s3_input_train, 'validation': s3_input_validation}, include_cls_metadata=False)"
]
Expand All @@ -418,9 +394,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"boto3.client('sagemaker').describe_hyper_parameter_tuning_job(\n",
Expand All @@ -447,9 +421,9 @@
"metadata": {
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (Data Science)",
"display_name": "conda_python3",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-2:429704687514:image/datascience-1.0"
"name": "conda_python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -461,7 +435,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.6.10"
},
"notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
},
Expand Down

0 comments on commit 99dd54c

Please sign in to comment.