-
Notifications
You must be signed in to change notification settings - Fork 208
/
fused_linear_cross_entropy.py
248 lines (214 loc) · 9.08 KB
/
fused_linear_cross_entropy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import torch
import triton
from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
from liger_kernel.ops.utils import (
amp_custom_bwd,
amp_custom_fwd,
element_mul_kernel,
is_hip,
)
# Triton's hard limit on the number of elements in a tensor (TRITON_MAX_TENSOR_NUMEL) is 1048576:
# https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# Much smaller blocks are faster in practice because they cause less register spilling.
# The Triton LayerNorm tutorial caps blocks at 65536 elements; this module uses half of
# that (32768). The best cap depends on the hardware, the kernel, and the dtype.
MAX_FUSED_SIZE = 65536 // 2
def fused_linear_cross_entropy_forward(
    _input,
    weight,
    target,
    bias=None,
    ignore_index=-100,
    lse_square_scale=0.0,
    label_smoothing=0.0,
    reduction="mean",
    softcap=None,
):
    """Compute the fused linear + cross-entropy loss and all gradients in one pass.

    The logits ``_input @ weight.T (+ bias)`` are never materialized for the
    whole batch: the BT rows are processed in chunks, and for each chunk the
    Triton kernel overwrites the logits with their gradient in place. Those
    gradients are then folded into ``grad_input``, ``grad_weight`` and
    ``grad_bias``.

    Args:
        _input: (BT, H) hidden states.
        weight: (V, H) projection weight.
        target: (BT,) class indices; rows equal to ``ignore_index`` are skipped.
        bias: optional (V,) projection bias.
        ignore_index: target value excluded from the loss.
        lse_square_scale: forwarded to the cross-entropy kernel.
        label_smoothing: forwarded to the cross-entropy kernel.
        reduction: with "mean", per-chunk results are rescaled so the result is
            normalized by the number of non-ignored tokens in the whole batch;
            any other value uses a scale of 1.0.
        softcap: optional logit soft-capping value forwarded to the kernel.

    Returns:
        ``(loss, grad_input, grad_weight, grad_bias)``. ``loss`` is the fp32 sum
        of the per-row losses. ``grad_weight`` is None when ``weight`` does not
        require grad, and ``grad_bias`` is None when ``bias`` is None.
    """
    device = _input.device

    # Materializing all logits would cost BT x V on top of the BT x H input.
    # Splitting the BT rows into chunks bounds that cost; to keep it close to
    # BT x H the chunk size is chosen as
    #   inc_factor = ceil(V / H), chunk_size = next_pow2(ceil(BT / inc_factor))
    # e.g. BT = 4096 * 4, V = 32000, H = 4096 -> inc_factor = 8, chunk_size = 2048
    BT, H = _input.shape
    V = weight.shape[0]
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

    inc_factor = triton.cdiv(V, H)
    # triton.next_power_of_2(0) is 0; clamp so an empty batch (BT == 0) does not
    # divide by zero when computing num_chunks.
    chunk_size = max(1, triton.next_power_of_2(triton.cdiv(BT, inc_factor)))
    num_chunks = triton.cdiv(BT, chunk_size)

    # The gradient buffers are forced to be contiguous. zeros_like would otherwise
    # copy the strides of a transposed/permuted input, but
    # fused_linear_cross_entropy_backward rescales them row by row with
    # element_mul_kernel, which is only given the row stride.
    grad_weight = (
        torch.zeros_like(weight, device=device, memory_format=torch.contiguous_format)
        if weight.requires_grad
        else None
    )
    grad_input = torch.zeros_like(
        _input, device=device, memory_format=torch.contiguous_format
    )
    grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None

    # Per-row losses are accumulated in fp32.
    loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)

    # Every chunk below already needs its non-ignored count as a host-side int
    # (it is passed to the kernel as a scalar), so reading the total once here
    # costs at most one extra sync. It keeps `alpha` a plain Python float. A
    # device tensor would instead force a sync on every `total > 0` check.
    total_n_non_ignore = (target != ignore_index).sum().item()

    for chunk_id in range(num_chunks):
        start_idx = chunk_id * chunk_size
        end_idx = min(start_idx + chunk_size, BT)
        _input_chunk = _input[start_idx:end_idx]  # chunk_size x H

        # The matmul runs in the original precision of _input / weight.
        logits_chunk = _input_chunk @ weight.t()  # chunk_size x V
        if bias is not None:
            logits_chunk = logits_chunk + bias
        # The kernel addresses rows through a single row stride.
        logits_chunk = logits_chunk.contiguous()
        target_chunk = target[start_idx:end_idx].contiguous()  # chunk_size

        n_rows = logits_chunk.shape[0]
        loss_1d_slice = loss_1d[start_idx:end_idx]  # view into loss_1d
        n_non_ignore = (target_chunk != ignore_index).sum().item()

        # Writes per-row losses into loss_1d_slice and replaces logits_chunk with
        # the gradient of the loss w.r.t. the logits, in place, to save memory.
        liger_cross_entropy_kernel[(n_rows,)](
            X_ptr=logits_chunk,
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
            loss_ptr=loss_1d_slice,
            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used (RETURN_Z_LOSS=0)
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=n_non_ignore,
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap if softcap is not None else 0.0,
            RETURN_Z_LOSS=0,  # False
            HAS_SOFTCAPPING=softcap is not None,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
        )

        # The kernel only sees this chunk, so its results are relative to the
        # chunk's own n_non_ignore tokens. For "mean", rescale by
        # n_non_ignore / total_n_non_ignore so the result is normalized over
        # the whole batch.
        if reduction == "mean":
            alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0
        else:
            alpha = 1.0

        if alpha != 1.0:
            # Scale in place. logits_chunk is a fresh buffer produced by this
            # iteration's matmul, so no second chunk_size x V temporary is needed.
            loss_1d_slice.mul_(alpha)
            logits_chunk.mul_(alpha)

        # logits_chunk now holds the (scaled) logits gradient: chunk_size x V.
        grad_input[start_idx:end_idx] = logits_chunk @ weight  # chunk_size x H

        if grad_weight is not None:
            # grad_weight += logits_chunk^T @ _input_chunk
            grad_weight.addmm_(logits_chunk.t(), _input_chunk)

        if grad_bias is not None:
            grad_bias.add_(logits_chunk.sum(dim=0))

    loss = torch.sum(loss_1d)
    return loss, grad_input, grad_weight, grad_bias
def fused_linear_cross_entropy_backward(
    grad_output, grad_input, grad_weight, grad_bias
):
    """Rescale the gradients precomputed in the forward pass by ``grad_output``.

    The three gradient tensors are multiplied by ``grad_output`` in place with
    ``element_mul_kernel``, one program per row, and then returned.
    ``grad_weight`` and ``grad_bias`` may be None, in which case they are
    skipped.

    Returns:
        ``(grad_input, grad_weight, grad_bias)`` — the same objects that were
        passed in.
    """
    # When the loss is the last node of the graph, grad_output is exactly 1.0
    # and the precomputed gradients are already final; skip the kernel launches.
    if not torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
        return grad_input, grad_weight, grad_bias

    # A Triton kernel is used instead of a PyTorch op: modifying the saved inputs
    # in place and running backward several times misbehaves with PyTorch ops
    # but not with Triton.
    n_tokens, hidden = grad_input.shape
    block_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden))
    warps = 16 if is_hip() else 32

    # Each entry: (tensor, row stride, number of rows, elements per row).
    launches = [(grad_input, grad_input.stride(-2), n_tokens, hidden)]
    if grad_weight is not None:
        vocab, weight_hidden = grad_weight.shape
        launches.append((grad_weight, grad_weight.stride(-2), vocab, weight_hidden))
    if grad_bias is not None:
        # One single-element row per bias entry.
        launches.append((grad_bias, grad_bias.stride(-1), grad_bias.shape[0], 1))

    for tensor, row_stride, rows, cols in launches:
        element_mul_kernel[(rows,)](
            tensor,
            row_stride,
            grad_output,
            cols,
            BLOCK_SIZE=block_size,
            num_warps=warps,
        )

    return grad_input, grad_weight, grad_bias
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    """Autograd function fusing the final linear layer with cross-entropy loss.

    The gradients are computed eagerly during forward, so only they are
    saved; backward just rescales them by the incoming ``grad_output``.
    """

    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
    ):
        """
        Fuse the last linear layer with the cross-entropy loss.
        Reference: https://github.com/mgmalek/efficient_cross_entropy

        The large logits tensor is never fully materialized. Because the
        cross-entropy loss is the last layer, the gradients can be computed
        during the forward pass, so _input and target do not have to be kept
        for backward; only the computed gradients are saved.

        _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension.
        target: (B*T) where each value is in [0, V-1]
        weight: (V, H) where V is the number of classes
        bias: (V) where V is the number of classes
        ignore_index: the index to ignore in the target
        lse_square_scale: forwarded to the cross-entropy kernel
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        reduction: reduction to apply
        softcap: optional logit soft-capping value (None disables it)
        """
        loss, *grads = fused_linear_cross_entropy_forward(
            _input,
            weight,
            target,
            bias,
            ignore_index,
            lse_square_scale,
            label_smoothing,
            reduction,
            softcap,
        )
        input_grad, weight_grad, bias_grad = grads

        # Save the precomputed gradients (detached, dtype unchanged) for backward.
        ctx.save_for_backward(
            input_grad.detach(),
            None if weight_grad is None else weight_grad.detach(),
            None if bias is None else bias_grad.detach(),
        )
        return loss

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        saved_input_grad, saved_weight_grad, saved_bias_grad = ctx.saved_tensors
        d_input, d_weight, d_bias = fused_linear_cross_entropy_backward(
            grad_output, saved_input_grad, saved_weight_grad, saved_bias_grad
        )
        # One entry per forward() argument after ctx. target and the five
        # trailing hyper-parameters receive no gradient.
        hyperparameter_grads = (None,) * 5
        return (d_input, d_weight, None, d_bias) + hyperparameter_grads