update cal_mfu.py #5388
Merged 3 commits on Sep 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions scripts/cal_mfu.py
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2 # gemm (add + mul)


def compute_model_flops(
    model_name_or_path: str,
    batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of the model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

    # attn projector module
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = batch_size * seq_length * embedding_flops_per_token
    if tie_word_embeddings is False:
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:  # the backward pass costs roughly twice the forward pass
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:  # gradient checkpointing repeats the forward pass
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops

    if include_flashattn:  # flash attention recomputes attention during the backward pass
        total_flops += sdpa_flops

    return total_flops
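# Rough sanity check of the formula above, using a hypothetical LLaMA-7B-like config
# (hidden_size=4096, intermediate_size=11008, num_hidden_layers=32) with batch_size=1
# and seq_length=1024: the MLP term alone is
#   1 * 1024 * 32 * (3 * 2 * 4096 * 11008) ≈ 8.9e12 FLOPs per forward pass,
# and the full forward + backward total for a 7B-scale model lands in the tens of
# teraFLOPs per step; compute_mfu multiplies this per-step figure by
# train_steps_per_second and divides by the device peak.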


def compute_device_flops() -> float:
    r"""
    Gets the theoretical peak FLOPs per second of all visible devices.
    """
    device_name = torch.cuda.get_device_name()  # assumes a homogeneous setup: only device 0 is queried
    device_count = torch.cuda.device_count()
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * device_count
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * device_count
    elif "V100" in device_name:
        return 125 * 1e12 * device_count
    elif "4090" in device_name:
        return 98 * 1e12 * device_count
    else:
        raise NotImplementedError("Device not supported: {}.".format(device_name))


def compute_mfu(
    model_name_or_path: str,
    batch_size: int,
    seq_length: int,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
) -> float:
    r"""
    Computes the MFU for a given model and set of hyper-parameters.
    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": 100,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)

    run_exp(args)
    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
        result = json.load(f)

    # MFU = achieved FLOPs per second / theoretical peak FLOPs per second
    mfu_value = (
        result["train_steps_per_second"]
        * compute_model_flops(model_name_or_path, batch_size, seq_length)
        / compute_device_flops()
    )
    print("MFU: {:.2f}%".format(mfu_value * 100))


if __name__ == "__main__":
    fire.Fire(compute_mfu)
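
For reference, the per-step model FLOPs and the device peak can be inspected without launching the full 100-step run by calling the two helpers directly. A minimal sketch, assuming LLaMA-Factory is installed, a supported GPU is visible, the scripts directory is on PYTHONPATH, and the model path is only a placeholder:

from cal_mfu import compute_device_flops, compute_model_flops

model_flops = compute_model_flops("meta-llama/Llama-2-7b-hf", batch_size=1, seq_length=1024)
device_flops = compute_device_flops()
print("Model FLOPs per step: {:.3e}".format(model_flops))
print("Peak device FLOPs/s:  {:.3e}".format(device_flops))
# MFU would then be train_steps_per_second * model_flops / device_flops,
# exactly as computed at the end of compute_mfu.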