Add INT8 mixed-precision training #748

Merged: 54 commits into main from int8_mp, Sep 9, 2024

Commits (54)

All commits are by gau-nernst.

b2e99ec  initial commit (Aug 26, 2024)
255abe9  expose some UX. update test (Aug 26, 2024)
efb53bf  add test. update bench (Aug 26, 2024)
0a510f5  update test. add doc (Aug 26, 2024)
f80ea8c  fix ngpu (Aug 26, 2024)
4a404ce  fix FSDP (Aug 26, 2024)
42abc15  fix (Aug 26, 2024)
e826d48  fix fsdp test (Aug 26, 2024)
2ab9df3  fix (Aug 26, 2024)
c89b950  grammar (Aug 26, 2024)
cde7e8f  simplify fsdp test (Aug 26, 2024)
691da9d  update benchmark script (Aug 27, 2024)
3540e79  update (Aug 27, 2024)
f9d4e2a  make claim more conservative (Aug 27, 2024)
9448b4d  Merge branch 'main' into int8_mp (Aug 28, 2024)
64f707a  register fused adam (Aug 28, 2024)
d04e8b3  Merge branch 'pytorch:main' into int8_mp (Sep 2, 2024)
4f8d63d  Merge branch 'main' into int8_mp (Sep 3, 2024)
b3770d3  update benchmark script (Sep 3, 2024)
f39fdac  Merge branch 'main' into int8_mp (Sep 4, 2024)
dd33823  add more ops (Sep 4, 2024)
b96769a  update default (Sep 4, 2024)
2b16ebb  use TorchAOBaseTensor (Sep 4, 2024)
117cc60  fix fsdp param_dtype (Sep 4, 2024)
ae37058  fix param_dtype (Sep 4, 2024)
ae4eb21  dtype check to prevent unnecessary errors (Sep 4, 2024)
730c90c  move checks (Sep 4, 2024)
c470a24  add note (Sep 4, 2024)
7c1d760  fix (Sep 4, 2024)
0e15e2d  simplify script (Sep 4, 2024)
22c11bc  Merge branch 'main' into int8_mp (Sep 5, 2024)
208188c  add module-based UX (Sep 5, 2024)
77aafdb  fix (Sep 5, 2024)
ce6a5d5  Merge branch 'main' into int8_mp (Sep 6, 2024)
d367f77  use FP8 impl of __torch_dispatch__ (Sep 6, 2024)
d24a894  rename _dynamice interface (Sep 6, 2024)
fb09b24  update test (Sep 6, 2024)
3372644  fix compile on 2.4 (Sep 6, 2024)
9e05b5c  log torch version (Sep 6, 2024)
6e4e684  make log interval customizable (Sep 6, 2024)
b395858  make naming for explicit (Sep 6, 2024)
986c590  update readme (Sep 6, 2024)
35df447  some change (Sep 6, 2024)
7164551  fix big bug (Sep 6, 2024)
b14ab6d  add docstring. update _get_linear_inserter (Sep 6, 2024)
dbbc90f  add TorchAOBaseTensor back (Sep 6, 2024)
8d918f1  fix FSDP (Sep 7, 2024)
b4bd411  Merge branch 'main' into int8_mp (Sep 7, 2024)
d67a933  update FSDP test. add autocast support (Sep 7, 2024)
7352335  Merge branch 'main' into int8_mp (Sep 7, 2024)
6122aaa  reduce iter (Sep 9, 2024)
8dab7cc  Merge branch 'main' into int8_mp (Sep 9, 2024)
0d65b26  update int8_mm fallback (Sep 9, 2024)
6082d30  put leading dims logic to _dynamic_int8_mm (Sep 9, 2024)

14 changes: 12 additions & 2 deletions benchmarks/quantized_training/pretrain_llama2.py
@@ -20,7 +20,11 @@

from torchao._models.llama.model import ModelArgs, Transformer
from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.prototype.quantized_training import (
Int8MixedPrecisionConfig,
int8_mixed_precision_training,
int8_weight_only_quantized_training,
)
from torchao.quantization.quant_api import quantize_


@@ -116,10 +120,16 @@ def get_tinystories():
if args.activation_checkpointing:
for layer in model.layers:
enable_activation_checkpointing(layer)

# NOTE: don't apply to LM head since there are memory issues.
if args.quantize == "int8_weight_only":
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
quantize_(model.layers, int8_weight_only_quantized_training(), set_inductor_config=False)
elif args.quantize == "int8_mixed_precision":
cfg = Int8MixedPrecisionConfig(True, True, True)
quantize_(model.layers, int8_mixed_precision_training(cfg), set_inductor_config=False)
elif args.quantize is not None:
raise ValueError(f"Unsupported quantize={args.quantize}")

print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")

133 changes: 88 additions & 45 deletions test/prototype/test_quantized_training.py
@@ -9,11 +9,16 @@
from torch.testing._internal.common_utils import TestCase, instantiate_parametrized_tests, parametrize, run_tests

from torchao.prototype.low_bit_optim import _AdamW
from torchao.prototype.quantized_training import Int8QTLinearWeight, int8_weight_only_quantized_training
from torchao.prototype.quantized_training import (
int8_weight_only_quantized_training,
int8_mixed_precision_training,
quantize_int8_rowwise,
Int8MixedPrecisionConfig,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if not TORCH_VERSION_AFTER_2_3:
if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Requires torch>=2.4", allow_module_level=True)


@@ -35,18 +40,26 @@ def test_int8_stochastic_rounding(self, device):
x = torch.randn(32, device=device)
x_samples = x.view(1, -1).repeat(100_000, 1)

x_int8, x_scale = Int8QTLinearWeight.quantize(x_samples, stochastic_rounding=True)
x_int8, x_scale = quantize_int8_rowwise(x_samples, stochastic_rounding=True)
x_dequant_samples = x_int8 * x_scale.view(-1, 1)
x_dequant_mean = x_dequant_samples.mean(0)

# a more rigorous test would be to do a hypothesis testing.
# due to the statistical nature, this assertion may still fail, though very rarely.
torch.testing.assert_close(x_dequant_mean, x, atol=1e-4, rtol=1e-4)

@staticmethod
def _forward_and_backward(module, input, grad):
# clone input, since we want to inspect its gradient later
input = input.detach().clone().requires_grad_(True)
output = module(input)
output.backward(grad)
return input, output

@parametrize("leading_dims", [(), (2,), (2, 4)])
@parametrize("bias", [False, True])
@parametrize("device", _DEVICES)
def test_int8_linear(self, leading_dims, bias, device):
def test_int8_weight_only_correctness(self, leading_dims, bias, device):
_reset()
embed_dim = 32

@@ -55,20 +68,13 @@ def test_int8_linear(self, leading_dims, bias, device):
quantize_(linear_int8, int8_weight_only_quantized_training(), set_inductor_config=False)
linear_fp32.weight.data = linear_int8.weight.data.dequantize()

input_fp32 = torch.randn(leading_dims + (embed_dim,), device=device)
input_int8 = input_fp32.clone()
input_fp32.requires_grad_(True)
input_int8.requires_grad_(True)
input = torch.randn(leading_dims + (embed_dim,), device=device)
grad = torch.randn(leading_dims + (embed_dim,), device=device)

# test forward
out_fp32 = linear_fp32(input_fp32)
out_int8 = linear_int8(input_int8)
torch.testing.assert_close(out_fp32, out_int8)
input_fp32, out_fp32 = self._forward_and_backward(linear_fp32, input, grad)
input_int8, out_int8 = self._forward_and_backward(linear_int8, input, grad)

# test backward
grad = torch.randn(leading_dims + (embed_dim,), device=device)
out_fp32.backward(grad)
out_int8.backward(grad)
torch.testing.assert_close(out_fp32, out_int8)
torch.testing.assert_close(input_fp32.grad, input_int8.grad)
torch.testing.assert_close(linear_fp32.weight.grad, linear_int8.weight.grad)
if bias:
@@ -77,7 +83,7 @@
@parametrize("leading_dims", [(), (2,), (2, 4)])
@parametrize("bias", [False, True])
@parametrize("device", _DEVICES)
def test_int8_linear_compile(self, leading_dims, bias, device):
def test_int8_weight_only_compile(self, leading_dims, bias, device):
_reset()
embed_dim = 128

@@ -86,26 +92,21 @@ def test_int8_linear_compile(self, leading_dims, bias, device):
linear_compiled = copy.deepcopy(linear_eager)
linear_compiled.compile()

input_eager = torch.randn(leading_dims + (embed_dim,), device=device) * 10
input_compiled = input_eager.clone()
input_eager.requires_grad_(True)
input_compiled.requires_grad_(True)
input = torch.randn(leading_dims + (embed_dim,), device=device) * 10
grad = torch.randn(leading_dims + (embed_dim,), device=device)

out_eager = linear_eager(input_eager)
out_compiled = linear_compiled(input_compiled)
torch.testing.assert_close(out_eager, out_compiled)
input_eager, out_eager = self._forward_and_backward(linear_eager, input, grad)
input_compiled, out_compiled = self._forward_and_backward(linear_compiled, input, grad)

grad = torch.randn(leading_dims + (embed_dim,), device=device)
out_eager.backward(grad)
out_compiled.backward(grad)
torch.testing.assert_close(out_eager, out_compiled)
torch.testing.assert_close(input_eager.grad, input_compiled.grad)
torch.testing.assert_close(linear_eager.weight.grad, linear_compiled.weight.grad)
if bias:
torch.testing.assert_close(linear_eager.bias.grad, linear_compiled.bias.grad)

@parametrize("compile", [False, True])
@parametrize("device", _DEVICES)
def test_int8_linear_training(self, compile, device):
def test_int8_weight_only_training(self, compile, device):
_reset()
bsize = 4
embed_dim = 32
@@ -117,7 +118,6 @@ def test_int8_linear_training(self, compile, device):
nn.Linear(embed_dim * 2, n_classes),
).to(device)
model_int8 = copy.deepcopy(model_fp32)
# don't set inductor flags to speed up CI time
quantize_(model_int8, int8_weight_only_quantized_training(), set_inductor_config=False)

if compile:
@@ -144,6 +144,48 @@ def test_int8_linear_training(self, compile, device):
optim_int8.step()
optim_int8.zero_grad()

@parametrize("compile", [False, True])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_int8_mixed_precision_training(self, compile):
_reset()
bsize = 4
embed_dim = 32
device = "cuda"
config = Int8MixedPrecisionConfig(True, True, True)

# only use 1 matmul shape to reduce triton autotune time
model_ref = nn.Sequential(
nn.Linear(embed_dim, embed_dim, bias=False),
nn.GELU(),
nn.Linear(embed_dim, embed_dim),
).to(device)
model_int8mp = copy.deepcopy(model_ref)
quantize_(model_int8mp, int8_mixed_precision_training(config), set_inductor_config=False)

if compile:
model_ref.compile()
model_int8mp.compile()

optim_ref = torch.optim.AdamW(model_ref.parameters())
optim_int8mp = torch.optim.AdamW(model_int8mp.parameters())

for i in range(5):
inputs = torch.randn(bsize, embed_dim, device=device)
labels = torch.randint(embed_dim, size=(bsize,), device=device)
loss_ref = F.cross_entropy(model_ref(inputs), labels)
loss_int8mp = F.cross_entropy(model_int8mp(inputs), labels)

rel_error = abs(loss_int8mp.item() - loss_ref.item()) / abs(loss_ref.item())
assert rel_error < 3e-2, (i, rel_error)

loss_ref.backward()
optim_ref.step()
optim_ref.zero_grad()

loss_int8mp.backward()
optim_int8mp.step()
optim_int8mp.zero_grad()


class TestFSDP2(FSDPTest):
@property
@@ -152,17 +194,24 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
# FSDP2 + compiled quantized training fails with PyTorch 2.4
compile_layer_choices = [False]
if TORCH_VERSION_AFTER_2_4:
compile_layer_choices.append(True)
# due to stochastic rounding, use a pretty large tolerance here
self.run_subtests(
dict(),
self._test_fsdp2,
quantize_fn=int8_weight_only_quantized_training(),
tolerance=0.05,
)

# triton autotune takes too long. only test with compile_layer=False
# and apply INT8 matmul to forward pass only.
self.run_subtests(
{"compile_layer": compile_layer_choices},
dict(),
self._test_fsdp2,
quantize_fn=int8_mixed_precision_training(Int8MixedPrecisionConfig(True, False, False)),
tolerance=1e-6,
)

def _test_fsdp2(self, compile_layer):
def _test_fsdp2(self, quantize_fn, tolerance):
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.distributed._tensor.common_dtensor import ModelArgs, Transformer
@@ -178,19 +227,14 @@ def _test_fsdp2(self, compile_layer):
vocab_size=vocab_size,
max_seq_len=seq_len,
dropout_p=0,
weight_tying=False, # INT8 mixed-precision will fail if weight_tying=True
)
torch.manual_seed(42)
base_model = Transformer(model_args).cuda()
quantize_(base_model, int8_weight_only_quantized_training(), set_inductor_config=False)
quantize_(base_model, quantize_fn, set_inductor_config=False)
fsdp_model = copy.deepcopy(base_model)

if compile_layer:
for layer in base_model.layers:
layer.compile()

for layer in fsdp_model.layers:
if compile_layer:
layer.compile()
fully_shard(layer)
fully_shard(fsdp_model)

@@ -213,9 +257,8 @@ def _test_fsdp2(self, compile_layer):
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)
base_optim.step()

# due to stochastic rounding, use a pretty large tolerance here
rel_error = (fsdp_loss - base_loss).abs() / base_loss.abs()
assert rel_error < 0.05, rel_error
assert rel_error < tolerance, (iter_idx, rel_error)


instantiate_parametrized_tests(TestQuantizedTraining)
44 changes: 41 additions & 3 deletions torchao/prototype/quantized_training/README.md
@@ -2,7 +2,7 @@

This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of the model weights. We take inspiration from:
- Q-GaLore: [[paper](https://arxiv.org/abs/2407.08296)] [[code](https://github.com/VITA-Group/Q-GaLore)]
- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)]
- JetFire: [[paper](https://arxiv.org/abs/2403.12422)] [[code](https://github.com/thu-ml/Jetfire-INT8Training)]

Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be rounded down to zero. To tackle this problem, we use **stochastic rounding** for the weight update. In simple terms, stochastic rounding rounds up or down randomly, with a higher probability of rounding toward the nearer value. For example, 0.8 has an 80% chance of rounding up and a 20% chance of rounding down. It follows that, in expectation, stochastic rounding recovers the floating-point value exactly.
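
For intuition, here is a minimal sketch of INT8 stochastic rounding (an illustrative helper with assumed names and scale handling, not the actual torchao kernel):

```python
import torch

def stochastic_round_to_int8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # map to the INT8 grid using a precomputed scale
    x_scaled = x / scale
    floor = x_scaled.floor()
    # round up with probability equal to the fractional part, so the
    # quantizer is unbiased: E[result] == x_scaled
    round_up = torch.rand_like(x_scaled) < (x_scaled - floor)
    return (floor + round_up).clamp(-128, 127).to(torch.int8)
```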

@@ -23,7 +23,7 @@ Usage
```python
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.prototype.low_bit_optim import _AdamW
from torchao.quantization.quant_api import quantize_
from torchao.quantization import quantize_

model = ...
quantize_(model, int8_weight_only_quantized_training())
@@ -46,8 +46,46 @@ BF16 compile | 10.16915
INT8 QT eager | 10.11437
INT8 QT compile | 10.03365

## INT8 mixed-precision

On NVIDIA GPUs, INT8 Tensor Cores can be up to 3x faster than their BF16/FP16 counterparts. In mixed-precision training, we dynamically down-cast activations and weights to INT8 to leverage the faster matmuls. However, since INT8 has a very limited range [-128, 127], we perform row-wise quantization, similar to how INT8 post-training quantization (PTQ) is done (a minimal sketch is shown after the references below). The weights themselves are still stored in the original precision. This is inspired by prior work:

- AQT: [[related paper](https://arxiv.org/abs/2105.03536)] [[code](https://github.com/google/aqt)]
- SwitchBack: [[paper](https://arxiv.org/abs/2304.13013)]
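
A minimal sketch of the row-wise INT8 quantization described above (illustrative only; the actual helper in this folder is `quantize_int8_rowwise`, whose exact signature may differ):

```python
import torch

def quantize_int8_rowwise_sketch(x: torch.Tensor):
    # one scale per row so that the row's absolute maximum maps to 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    x_int8 = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return x_int8, scale  # dequantize with x_int8.float() * scale
```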

Usage

```python
from torchao.prototype.quantized_training import int8_mixed_precision_training, Int8MixedPrecisionConfig
from torchao.quantization import quantize_

model = ...
config = Int8MixedPrecisionConfig(
forward=True,
backward_grad_input=True,
backward_grad_weight=True,
)
quantize_(model, int8_mixed_precision_training(config))

# train model as usual
```

During training, there are 3 matmuls involved in each `nn.Linear` layer:
- 1 in forward: `output = input @ weight.T`
- 2 in backward:
- `grad_input = grad_output @ weight`
- `grad_weight = grad_output.T @ input`

You can configure which of these matmuls are computed with INT8 mixed-precision via the `Int8MixedPrecisionConfig` shown above. If convergence is an issue, we recommend leaving `backward_grad_weight` in the original matmul precision, and also `backward_grad_input` if the issue persists.
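
To make this concrete, below is a simplified, self-contained sketch of where each flag takes effect. The names are illustrative, bias handling and leading-dim reshaping are omitted, and the actual implementation in this PR uses a Triton matmul kernel rather than the plain PyTorch fallback shown here:

```python
import torch

def _rowwise_int8(x: torch.Tensor):
    # per-row symmetric quantization, same idea as the sketch earlier in this README
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    return (x / scale).round().clamp(-128, 127).to(torch.int8), scale

def int8_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # dynamic INT8 matmul stand-in: quantize A row-wise and B column-wise,
    # do the matmul, then rescale the output
    a_i8, a_scale = _rowwise_int8(a)      # (M, K) and (M, 1)
    b_i8, b_scale = _rowwise_int8(b.T)    # (N, K) and (N, 1)
    return (a_i8.float() @ b_i8.float().T) * a_scale * b_scale.T

class Int8MixedPrecisionLinearSketch(torch.autograd.Function):
    # 2D input assumed for clarity
    @staticmethod
    def forward(ctx, input, weight, config):
        ctx.save_for_backward(input, weight)
        ctx.config = config
        if config.forward:
            return int8_mm(input, weight.T)   # output = input @ weight.T
        return input @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        config = ctx.config
        grad_input = (
            int8_mm(grad_output, weight)      # grad_input = grad_output @ weight
            if config.backward_grad_input
            else grad_output @ weight
        )
        grad_weight = (
            int8_mm(grad_output.T, input)     # grad_weight = grad_output.T @ input
            if config.backward_grad_weight
            else grad_output.T @ input
        )
        return grad_input, grad_weight, None
```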

Note:
- When we only apply INT8 mixed-precision in the forward pass, this can be considered QAT.
  > **Review comment (Contributor):** Is this true? QAT does fake quantization (it doesn't cast dtypes to INT8), but here we're actually casting.
  >
  > **Reply (gau-nernst, author):** If we consider dtype casting an implementation detail, and only think of QAT as doing quantization in the forward pass during training (in terms of numerics), then it should be the same 😄.

- When we only apply INT8 mixed-precision to `forward` and `backward_grad_input`, this is similar to SwitchBack. However, SwitchBack uses tensor-wise scaling for the weight, while for simplicity we only support row-wise scaling.

TODO: add some benchmarks

## Future ideas

- INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which are 2x faster than FP16/BF16 Tensor Cores.
- Tile-wise INT8 quantization to keep the quantized weights for both forward and backward passes (similar to JetFire).
- INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels).
- FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy.
11 changes: 10 additions & 1 deletion torchao/prototype/quantized_training/__init__.py
@@ -1 +1,10 @@
from .int8 import Int8QTLinearWeight, int8_weight_only_quantized_training
from .int8 import (
Int8QTLinearWeight,
int8_weight_only_quantized_training,
quantize_int8_rowwise,
)
from .int8_mixed_precision import (
Int8MixedPrecisionConfig,
Int8MixedPrecisionLinearWeight,
int8_mixed_precision_training,
)