Skip to content

Commit

Permalink
notebook fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shaernev committed Aug 1, 2024
1 parent 0abd974 commit 3285576
Showing 1 changed file with 9 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install sagemaker ipywidgets --upgrade --quiet"
"!pip install sagemaker jupyterlab --upgrade --quiet\n",
"!pip install ipywidgets==7.6.5"
]
},
{
Expand Down Expand Up @@ -234,9 +235,7 @@
"metadata": {},
"outputs": [],
"source": [
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
"from sagemaker.model import Model\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"from sagemaker.utils import name_from_base\n",
"\n",
"# model_version=\"*\" fetches the latest version of the model\n",
Expand All @@ -247,45 +246,17 @@
"\n",
"inference_instance_type = \"ml.p2.xlarge\"\n",
"\n",
"# Retrieve the inference docker container uri\n",
"deploy_image_uri = image_uris.retrieve(\n",
" region=None,\n",
" framework=None, # automatically inferred from model_id\n",
" image_scope=\"inference\",\n",
"# Create the SageMaker JumpStart model instance\n",
"model = JumpStartModel(\n",
" model_id=infer_model_id,\n",
" model_version=infer_model_version,\n",
" instance_type=inference_instance_type,\n",
")\n",
"\n",
"# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.\n",
"deploy_source_uri = script_uris.retrieve(\n",
" model_id=infer_model_id, model_version=infer_model_version, script_scope=\"inference\"\n",
")\n",
"\n",
"\n",
"# Retrieve the base model uri\n",
"base_model_uri = model_uris.retrieve(\n",
" model_id=infer_model_id, model_version=infer_model_version, model_scope=\"inference\"\n",
")\n",
"\n",
"\n",
"# Create the SageMaker model instance\n",
"model = Model(\n",
" image_uri=deploy_image_uri,\n",
" source_dir=deploy_source_uri,\n",
" model_data=base_model_uri,\n",
" entry_point=\"inference.py\", # entry point file in source_dir and present in deploy_source_uri\n",
" role=aws_role,\n",
" predictor_cls=Predictor,\n",
" name=endpoint_name,\n",
")\n",
"\n",
"# deploy the Model. Note that we need to pass Predictor class when we deploy model through Model class,\n",
"# for being able to run inference through the sagemaker API.\n",
"base_model_predictor = model.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" predictor_cls=Predictor,\n",
" endpoint_name=endpoint_name,\n",
")"
]
Expand Down Expand Up @@ -355,8 +326,7 @@
" return query_response\n",
"\n",
"\n",
"def parse_response(query_response):\n",
" model_predictions = json.loads(query_response)\n",
"def parse_response(model_predictions):\n",
" normalized_boxes, classes, scores, labels = (\n",
" model_predictions[\"normalized_boxes\"],\n",
" model_predictions[\"classes\"],\n",
Expand Down Expand Up @@ -837,8 +807,10 @@
"outputs": [],
"source": [
"query_response = query(finetuned_predictor, pedestrian_image_file_name)\n",
"model_predictions = json.loads(query_response)\n",
"\n",
"\n",
"normalized_boxes, classes_names, confidences = parse_response(query_response)\n",
"normalized_boxes, classes_names, confidences = parse_response(model_predictions)\n",
"display_predictions(pedestrian_image_file_name, normalized_boxes, classes_names, confidences)"
]
},
Expand Down

0 comments on commit 3285576

Please sign in to comment.