Add INT8 mixed-precision training #748

Merged 54 commits into main on Sep 9, 2024
Changes from all commits (54 commits)
b2e99ec
initial commit
gau-nernst Aug 26, 2024
255abe9
expose some UX. update test
gau-nernst Aug 26, 2024
efb53bf
add test. update bench
gau-nernst Aug 26, 2024
0a510f5
update test. add doc
gau-nernst Aug 26, 2024
f80ea8c
fix ngpu
gau-nernst Aug 26, 2024
4a404ce
fix FSDP
gau-nernst Aug 26, 2024
42abc15
fix
gau-nernst Aug 26, 2024
e826d48
fix fsdp test
gau-nernst Aug 26, 2024
2ab9df3
fix
gau-nernst Aug 26, 2024
c89b950
grammar
gau-nernst Aug 26, 2024
cde7e8f
simplify fsdp test
gau-nernst Aug 26, 2024
691da9d
update benchmark script
gau-nernst Aug 27, 2024
3540e79
update
gau-nernst Aug 27, 2024
f9d4e2a
make claim more conservative
gau-nernst Aug 27, 2024
9448b4d
Merge branch 'main' into int8_mp
gau-nernst Aug 28, 2024
64f707a
register fused adam
gau-nernst Aug 28, 2024
d04e8b3
Merge branch 'pytorch:main' into int8_mp
gau-nernst Sep 2, 2024
4f8d63d
Merge branch 'main' into int8_mp
gau-nernst Sep 3, 2024
b3770d3
update benchmark script
gau-nernst Sep 3, 2024
f39fdac
Merge branch 'main' into int8_mp
gau-nernst Sep 4, 2024
dd33823
add more ops
gau-nernst Sep 4, 2024
b96769a
update default
gau-nernst Sep 4, 2024
2b16ebb
use TorchAOBaseTensor
gau-nernst Sep 4, 2024
117cc60
fix fsdp param_dtype
gau-nernst Sep 4, 2024
ae37058
fix param_dtype
gau-nernst Sep 4, 2024
ae4eb21
dtype check to prevent unnecessary errors
gau-nernst Sep 4, 2024
730c90c
move checks
gau-nernst Sep 4, 2024
c470a24
add note
gau-nernst Sep 4, 2024
7c1d760
fix
gau-nernst Sep 4, 2024
0e15e2d
simplify script
gau-nernst Sep 4, 2024
22c11bc
Merge branch 'main' into int8_mp
gau-nernst Sep 5, 2024
208188c
add module-based UX
gau-nernst Sep 5, 2024
77aafdb
fix
gau-nernst Sep 5, 2024
ce6a5d5
Merge branch 'main' into int8_mp
gau-nernst Sep 6, 2024
d367f77
use FP8 impl of __torch_dispatch__
gau-nernst Sep 6, 2024
d24a894
rename _dynamice interface
gau-nernst Sep 6, 2024
fb09b24
update test
gau-nernst Sep 6, 2024
3372644
fix compile on 2.4
gau-nernst Sep 6, 2024
9e05b5c
log torch version
gau-nernst Sep 6, 2024
6e4e684
make log interval customizable
gau-nernst Sep 6, 2024
b395858
make naming for explicit
gau-nernst Sep 6, 2024
986c590
update readme
gau-nernst Sep 6, 2024
35df447
some change
gau-nernst Sep 6, 2024
7164551
fix big bug
gau-nernst Sep 6, 2024
b14ab6d
add docstring. update _get_linear_inserter
gau-nernst Sep 6, 2024
dbbc90f
add TorchAOBaseTensor back
gau-nernst Sep 6, 2024
8d918f1
fix FSDP
gau-nernst Sep 7, 2024
b4bd411
Merge branch 'main' into int8_mp
gau-nernst Sep 7, 2024
d67a933
update FSDP test. add autocast support
gau-nernst Sep 7, 2024
7352335
Merge branch 'main' into int8_mp
gau-nernst Sep 7, 2024
6122aaa
reduce iter
gau-nernst Sep 9, 2024
8dab7cc
Merge branch 'main' into int8_mp
gau-nernst Sep 9, 2024
0d65b26
update int8_mm fallback
gau-nernst Sep 9, 2024
6082d30
put leading dims logic to _dynamic_int8_mm
gau-nernst Sep 9, 2024
45 changes: 45 additions & 0 deletions benchmarks/quantized_training/benchmark_int8mm.py
@@ -0,0 +1,45 @@
import pandas as pd
import torch
from triton.testing import do_bench

from torchao.prototype.quantized_training.int8_mm import int8_mm_dequant


def bench_f(f, *args):
return do_bench(lambda: f(*args), fast_flush=False, return_mode="median")


shapes = [(sz, sz, sz) for sz in [1024, 2048, 4096]]

# Llama-8B shapes
shapes += [
# linear in attention
(32_768, 4096, 4096),
(4096, 4096, 32_768),
# linear in feed-forward
(32_768, 14_336, 4096),
(32_768, 4096, 14_336),
(14_336, 4096, 32_768),
]

data = []
for M, N, K in shapes:
print(f"{M=}, {N=}, {K=}")

A_bf16 = torch.randn(M, K).bfloat16().cuda()
B_bf16 = torch.randn(N, K).bfloat16().cuda()
A_i8 = torch.randint(-128, 127, size=(M, K), dtype=torch.int8).cuda()
B_i8 = torch.randint(-128, 127, size=(N, K), dtype=torch.int8).cuda()
A_scale = torch.randn(M).bfloat16().cuda()
B_scale = torch.randn(N).bfloat16().cuda()

# benchmark F.linear() i.e. A @ B.T
bf16_time = bench_f(torch.mm, A_bf16, B_bf16.T)
i8_time = bench_f(torch._int_mm, A_i8, B_i8.T)
i8_dequant_time = bench_f(int8_mm_dequant, A_i8, B_i8.T, A_scale, B_scale)

sample = [M, N, K, bf16_time / i8_time, bf16_time / i8_dequant_time]
data.append(sample)

df = pd.DataFrame(data, columns=["M", "N", "K", "CuBLAS INT8 speedup", "Triton INT8 dequant speedup"])
print(df.to_markdown())
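For reference, the semantics the benchmark assumes for int8_mm_dequant: an INT8 matmul accumulated in int32, then dequantized with per-row scales of both operands. This is inferred from the call site int8_mm_dequant(A_i8, B_i8.T, A_scale, B_scale) above and is only a sketch, not the Triton kernel's actual implementation (which fuses the matmul and the dequant).

# Plain-PyTorch sketch of what the fused Triton kernel benchmarked above is assumed to compute.
import torch

def int8_mm_dequant_ref(A_i8, B_i8_t, A_scale, B_scale):
    out_i32 = torch._int_mm(A_i8, B_i8_t)  # (M, K) @ (K, N) -> (M, N), int32 accumulator
    # scale rows of A and columns of the (already transposed) B, cast back to BF16
    return (out_i32 * A_scale[:, None] * B_scale[None, :]).to(torch.bfloat16)

# e.g. reusing the tensors from the loop above:
#   int8_mm_dequant_ref(A_i8, B_i8.T, A_scale, B_scale)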
57 changes: 37 additions & 20 deletions benchmarks/quantized_training/pretrain_llama2.py
@@ -3,12 +3,14 @@
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
# INT8 MP: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_mixed_precision

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
import time
from functools import partial
from pathlib import Path

@@ -18,22 +20,34 @@
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from torchao._models.llama.model import ModelArgs, Transformer
from torchao._models.llama.model import ModelArgs, Transformer, transformer_configs
from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.prototype.quantized_training import (
int8_mixed_precision_training,
int8_weight_only_quantized_training,
)
from torchao.quantization.quant_api import quantize_


# not official models
transformer_configs.update(
(
("470M", dict(n_layer=24, n_head=16, dim=1024, intermediate_size=4096)),
("1B", dict(n_layer=24, n_head=24, dim=1536, intermediate_size=6144)),
)
)


# hack from fairseq
# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
def enable_activation_checkpointing(m: torch.nn.Module):
assert not hasattr(m, "_forward")
m._forward = m.forward
m.forward = partial(checkpoint, m.forward)
m.forward = partial(checkpoint, m.forward, use_reentrant=False)


def get_loss(model: Transformer, batch: torch.Tensor):
logits = model(batch)[:, :-1].flatten(0, 1)
logits = model(batch)[:, :-1].float().flatten(0, 1)
labels = batch[:, 1:].flatten()
return torch.nn.functional.cross_entropy(logits, labels)

@@ -77,12 +91,7 @@ def get_tinystories():

if __name__ == "__main__":
parser = argparse.ArgumentParser()
# default config is 470M
parser.add_argument("--d_model", type=int, default=1024)
parser.add_argument("--depth", type=int, default=24)
parser.add_argument("--ffn_size", type=int, default=4096)
parser.add_argument("--head_dim", type=int, default=64)

parser.add_argument("--model", default="470M", choices=transformer_configs.keys())
parser.add_argument("--quantize")
parser.add_argument("--activation_checkpointing", action="store_true")
parser.add_argument("--compile", action="store_true")
@@ -98,44 +107,48 @@ def get_tinystories():
parser.add_argument("--project", default="int8_quantized_training")
parser.add_argument("--run_name")
parser.add_argument("--seed", type=int)
parser.add_argument("--log_interval", type=int, default=10)
args = parser.parse_args()

if args.seed is not None:
torch.manual_seed(args.seed)

config = ModelArgs(
block_size=args.seq_len,
n_layer=args.depth,
n_head=args.d_model // args.head_dim,
dim=args.d_model,
intermediate_size=args.ffn_size,
)
config = ModelArgs.from_name(args.model)
config.block_size = args.seq_len
model = Transformer(config).bfloat16().cuda()
with torch.device("cuda"):
model.setup_caches(args.batch_size, args.seq_len, training=True)
if args.activation_checkpointing:
for layer in model.layers:
enable_activation_checkpointing(layer)

# don't apply int8_mixed_precision to LM head, since it can cause convergence issues.
# TODO: might want to do the same for int8_weight_only to standardize.
if args.quantize == "int8_weight_only":
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
elif args.quantize == "int8_mixed_precision":
quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False)
elif args.quantize is not None:
raise ValueError(f"Unsupported quantize={args.quantize}")

print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")
torch.cuda.reset_peak_memory_stats() # don't count memory occupied by unquantized weights

# only use optimizers from torchao.prototype.low_bit_optim to support quantized training
if args.optim == "AdamW":
args.optim = "_AdamW"
optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

data = get_tinystories().cuda()
args.torch_version = torch.__version__
run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

step = 0
log_interval = 50
pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
model.train()
_get_loss = torch.compile(get_loss) if args.compile else get_loss
time0 = time.time()

while step < args.n_steps:
# randomly select a continuous chunk, then reshape it
@@ -145,13 +158,17 @@ def get_tinystories():
loss = _get_loss(model, batch)
loss.backward()

if step % log_interval == 0:
if step % args.log_interval == 0:
log_dict = dict(
loss=loss.item(),
lr=optim.param_groups[0]["lr"],
max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
max_memory_active=torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1e9,
max_memory_reserved=torch.cuda.max_memory_reserved() / 1e9,
)
if step > 0:
time1 = time.time()
log_dict["tokens_per_second"] = (args.log_interval * args.batch_size * args.seq_len) / (time1 - time0)
time0 = time1
run.log(log_dict, step=step)
pbar.set_postfix(loss=log_dict["loss"])
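
# Quick sanity check on the tokens-per-second metric added above: tokens processed in one
# logging window divided by that window's wall-clock time. Values below are illustrative,
# not the script's defaults.
log_interval, batch_size, seq_len = 10, 8, 2048
window_tokens = log_interval * batch_size * seq_len  # 163_840 tokens per logging window
elapsed_s = 10.0                                      # wall-clock time for those 10 steps
tokens_per_second = window_tokens / elapsed_s         # 16_384 tokens/s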
