Skip to content

Commit

Permalink
Update SMP v2 notebooks to use latest PyTorch 2.3.1, TSM 2.4.0 release (#4678)
Browse files Browse the repository at this point in the history

* Update SMP v2 notebooks to use latest PT2.3.1-TSM2.4.0 release.

* Update SMP v2 shared_scripts

* Update minimum sagemaker pysdk version to 2.224
  • Loading branch information
viclzhu authored Jun 24, 2024
1 parent a75d5f2 commit 2283df7
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade \"sagemaker>=2.212\"\n",
"%pip install --upgrade \"sagemaker>=2.224\"\n",
"%pip install sagemaker-experiments"
]
},
Expand Down Expand Up @@ -882,8 +882,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade \"sagemaker>=2.212\"\n",
"%pip install --upgrade \"sagemaker>=2.224\"\n",
"%pip install sagemaker-experiments"
]
},
Expand Down Expand Up @@ -873,8 +873,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down Expand Up @@ -955,8 +955,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade \"sagemaker>=2.212\"\n",
"%pip install --upgrade \"sagemaker>=2.224\"\n",
"%pip install sagemaker-experiments"
]
},
Expand Down Expand Up @@ -867,8 +867,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade \"sagemaker>=2.212\"\n",
"%pip install --upgrade \"sagemaker>=2.224\"\n",
"%pip install sagemaker-experiments"
]
},
Expand Down Expand Up @@ -831,8 +831,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down Expand Up @@ -913,8 +913,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade \"sagemaker>=2.215\"\n",
"%pip install --upgrade \"sagemaker>=2.224\"\n",
"%pip install sagemaker-experiments"
]
},
Expand Down Expand Up @@ -916,8 +916,8 @@
" }\n",
" },\n",
" },\n",
" py_version=\"py310\",\n",
" framework_version=\"2.2.0\",\n",
" py_version=\"py311\",\n",
" framework_version=\"2.3.1\",\n",
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
" output_path=s3_output_bucket,\n",
" max_run=86400,\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
accelerate>=0.12.0
datasets>=2.16.1
datasets>=2.19.1
einops
evaluate
expecttest
flash-attn>=2.3.6
flash-attn>=2.3.6,<2.4
h5py
humanize
hypothesis
Expand All @@ -14,4 +14,4 @@ protobuf
scikit-learn
sentencepiece!=0.1.92
tensorboard
transformers>=4.37.1
transformers>=4.40.1
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def main(args):
len(args.num_kept_checkpoints),
)
if len(set(ckpt_lens)) != 1:
raise ValueError(f"Len mismtach for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
raise ValueError(f"Len mismatch for checkpoint dir, freq vs num to keep: {ckpt_lens}.")

if args.distributed_backend == "smddp":
import smdistributed.dataparallel.torch.torch_smddp # pylint: disable=unused-import
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,22 @@ def compute_num_params(model):


def compute_tflops(args, global_batch_size, step_time, world_size):
# Based on
# Based on
# https://github.com/NVIDIA/Megatron-LM/blob/ba773259dbe5735fbd91ca41e7f4ded60b335c52/megatron/training/training.py#L65
num_experts_routed_to = 1 if args.moe > 1 else args.num_experts_per_tok
if args.num_key_value_heads is None:
# Attention projection size.
kv_channels = args.hidden_width // args.num_heads
query_projection_size = kv_channels * args.num_heads
query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_width

# Group Query Attention.
if not args.num_key_value_heads:
args.num_key_value_heads = args.num_heads

# MoE.
num_experts_routed_to = 1 if args.moe == 0 else args.num_experts_per_tok
gated_linear_multiplier = 3/2 if args.moe > 0 else 1

# Compute the number of floating point operations
num_flops = (
12
* global_batch_size
Expand All @@ -47,13 +58,26 @@ def compute_tflops(args, global_batch_size, step_time, world_size):
* args.hidden_width
* args.hidden_width
* (
1
+ ((args.intermediate_size / args.hidden_width) * num_experts_routed_to)
+ (args.num_key_value_heads / args.num_heads)
+ (args.max_context_width / args.hidden_width)
# Attention.
(
(
1
+ (args.num_key_value_heads / args.num_heads)
+ (args.max_context_width / args.hidden_width)
) * query_projection_to_hidden_size_ratio
)
# MLP.
+ (
(args.intermediate_size / args.hidden_width)
* num_experts_routed_to
* gated_linear_multiplier
)
# Logit.
+ (args.vocab_size / (2 * args.num_layers * args.hidden_width))
)
)

# Convert to TFLOPs per GPU
tflops_per_gpu = num_flops / (
step_time * 10**12 * world_size)
return tflops_per_gpu
Expand Down

0 comments on commit 2283df7

Please sign in to comment.