-
Notifications
You must be signed in to change notification settings - Fork 208
/
rms_norm.py
323 lines (274 loc) · 9.52 KB
/
rms_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
)
if compare_version("triton", operator.ge, "3.0.0"):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
# Integer codes identifying the supported casting modes. They are wrapped in
# `tl.constexpr` because the kernels compare them against their own
# `casting_mode: tl.constexpr` argument to select a branch.
_CASTING_MODE_NONE = tl.constexpr(-1)
_CASTING_MODE_LLAMA = tl.constexpr(0)
_CASTING_MODE_GEMMA = tl.constexpr(1)
@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    r_ptr,
    r_row_stride,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
    BLOCK_SIZE: tl.constexpr,
):
    """
    RMSNorm forward pass for one row.

    y_i = (x_i / RMS) * (offset + w_i), RMS = sqrt(sum(x_i^2) / N + eps)

    The program instance selected by `tl.program_id(0)` handles a single row. It loads the
    row in one block of BLOCK_SIZE lanes (lanes at or beyond `n_cols` are masked and read
    as 0, so they do not contribute to the sum), writes the weighted, normalized row to Y
    and writes the scalar 1 / RMS for that row to r.

    Args:
        Y_ptr, Y_row_stride: output pointer and the offset between consecutive rows.
        X_ptr, X_row_stride: input pointer and the offset between consecutive rows.
        W_ptr: weight pointer; indexed by column only, so every row uses the same weights.
        W_row_stride: accepted but not used in this kernel body.
        r_ptr, r_row_stride: pointer/offset of the per-row cache that receives 1 / RMS.
        n_cols: number of valid columns (N in the formula above).
        eps: added to the mean square before the reciprocal square root.
        offset: constant added to the weights before the multiplication.
        casting_mode: one of the `_CASTING_MODE_*` codes; selects where casts to fp32 happen.
        BLOCK_SIZE: number of lanes loaded; there is no loop over columns, so it has to be
            at least `n_cols` for the whole row to be processed.

    Reference:
    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
    3. https://arxiv.org/pdf/1910.07467
    """

    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance the row-wise pointers to this program's row. W_ptr is left untouched.
    Y_ptr += row_idx * Y_row_stride
    X_ptr += row_idx * X_row_stride
    r_ptr += row_idx * r_row_stride

    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    # Remember the input dtype so the Llama path can cast back before applying the weight.
    X_row_dtype = X_row.dtype
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    # On Llama, only inv_rms is computed on fp32
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)

    # Gemma computes everything on fp32, and then casts back the output to the original dtype
    if casting_mode == _CASTING_MODE_GEMMA:
        W_row = W_row.to(tl.float32)
        X_row = X_row.to(tl.float32)

    # In the remaining ("none") mode no explicit cast is inserted.
    mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
    inv_rms = rsqrt(mean_square + eps)

    # We can save time by caching the inverse rms with minimal memory overhead
    # because it is a single scalar per row, much smaller than X_row.
    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
    tl.store(r_ptr, inv_rms)

    X_row = X_row * inv_rms

    # On Llama, the multiplication with the weight is done on the original dtype
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(X_row_dtype)

    Y_row = X_row * (offset + W_row)

    tl.store(Y_ptr + col_offsets, Y_row, mask=mask)
@triton.jit
def _rms_norm_backward_kernel(
    dY_ptr,
    dY_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    W_row_stride,
    r_ptr,
    r_row_stride,
    dW_ptr,
    dW_row_stride,
    n_cols,
    eps,
    offset,
    casting_mode: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    RMSNorm backward pass for one row.

    dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]
    dw = sum(dy * (x / RMS))

    `*` is element-wise multiplication and `dot` is the dot product over the row. The sum in
    `dw` runs over the rows; this kernel only writes each row's term `dy * (x / RMS)` to
    `dW_ptr`, it does not perform that reduction itself.

    The program instance selected by `tl.program_id(0)` handles a single row. dX is written
    back through `dY_ptr`, i.e. it overwrites the incoming gradient in place.

    Args:
        dY_ptr, dY_row_stride: upstream gradient (read) / dX (written) and its row offset.
        X_ptr, X_row_stride: forward input and its row offset.
        W_ptr: weight pointer; indexed by column only.
        W_row_stride: accepted but not used in this kernel body.
        r_ptr, r_row_stride: per-row cached 1 / RMS and its offset.
        dW_ptr, dW_row_stride: per-row weight-gradient buffer and its row offset.
        n_cols: number of valid columns (N in the formula above).
        eps: accepted but not used in this kernel body (1 / RMS is read from the cache).
        offset: constant added to the weights.
        casting_mode: one of the `_CASTING_MODE_*` codes.
        BLOCK_SIZE: number of lanes loaded; there is no loop over columns, so it has to be
            at least `n_cols` for the whole row to be processed.
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # Advance the row-wise pointers to this program's row. W_ptr is left untouched.
    dY_ptr += row_idx * dY_row_stride
    X_ptr += row_idx * X_row_stride
    r_ptr += row_idx * r_row_stride
    dW_ptr += row_idx * dW_row_stride

    # Masked lanes read 0 so they contribute nothing to the row reduction below.
    dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
    W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
    original_x_dtype = X_row.dtype

    # Get cached inverse rms
    inv_rms_row = tl.load(r_ptr)

    W_row = W_row + offset

    # Different backward graphs for different casting modes. The branches are mutually
    # exclusive and every mode must define dX_row, hence if / elif / else.
    if casting_mode == _CASTING_MODE_LLAMA:
        X_row = X_row.to(tl.float32)
        m = (dY_row * W_row).to(tl.float32)
        dX_row = inv_rms_row * m

        dX_row += (inv_rms_row) * (
            -(1 / n_cols)
            * inv_rms_row
            * inv_rms_row
            * tl.sum(m * X_row, axis=0)
            * X_row
        )
    elif casting_mode == _CASTING_MODE_GEMMA:
        dY_row, W_row, X_row = (
            dY_row.to(tl.float32),
            W_row.to(tl.float32),
            X_row.to(tl.float32),
        )
        dX_row = inv_rms_row * dY_row * W_row

        dX_row += (inv_rms_row) * (
            -(1 / n_cols)
            * inv_rms_row
            * inv_rms_row
            * tl.sum(dY_row * W_row * X_row, axis=0)
            * X_row
        )
    else:
        # "none" casting mode: same formula, but no explicit cast is inserted.
        m = dY_row * W_row
        dX_row = inv_rms_row * m

        dX_row += (inv_rms_row) * (
            -(1 / n_cols)
            * inv_rms_row
            * inv_rms_row
            * tl.sum(m * X_row, axis=0)
            * X_row
        )

    # calculate the gradient of W
    if casting_mode == _CASTING_MODE_LLAMA:
        # Llama: the normalized input is cast back to the input dtype before the product.
        dW_row = dY_row * (X_row * inv_rms_row).to(original_x_dtype)
    else:
        # Gemma: X_row was cast to fp32 in the branch above.
        # None: X_row is still in the dtype it was loaded with.
        dW_row = dY_row * (X_row * inv_rms_row)

    # dX reuses the dY buffer; dW holds this row's term of the weight gradient.
    tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
    tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
# Lookup table from the user-facing casting-mode name to the plain integer
# code (the `.value` of the corresponding `tl.constexpr` constant).
_str_to_casting_mode = {
    mode_name: mode_const.value
    for mode_name, mode_const in (
        ("llama", _CASTING_MODE_LLAMA),
        ("gemma", _CASTING_MODE_GEMMA),
        ("none", _CASTING_MODE_NONE),
    )
}
def rms_norm_forward(X, W, eps, offset, casting_mode):
    """
    Launch the RMSNorm forward kernel, one program per row of the flattened input.

    Args:
        X: input tensor; it is viewed as 2D `(-1, X.shape[-1])` before the launch.
        W: weight tensor; `W.shape[0]` must equal the last dimension of `X`.
        eps: value forwarded to the kernel as `eps`.
        offset: value forwarded to the kernel as `offset`.
        casting_mode: either a key of `_str_to_casting_mode` or one of its integer values.

    Returns:
        A tuple `(Y, X, r, BLOCK_SIZE, num_warps, casting_mode)` where `Y` has the shape of
        the original `X`, the returned `X` is the 2D view, `r` is the per-row buffer filled
        by the kernel, and `casting_mode` is the resolved integer code.

    Raises:
        AssertionError: on an unknown casting mode or a hidden-size mismatch between X and W.
    """
    # Resolve the casting mode to its integer code, rejecting anything unknown.
    if isinstance(casting_mode, int):
        assert (
            casting_mode in _str_to_casting_mode.values()
        ), f"Invalid casting mode: {casting_mode}"
    else:
        assert (
            casting_mode in _str_to_casting_mode
        ), f"Invalid casting mode: {casting_mode}"
        casting_mode = _str_to_casting_mode[casting_mode]

    # Flatten every leading dimension into a single row axis.
    original_shape = X.shape
    X = X.view(-1, original_shape[-1])
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)

    # r caches (1/rms) for each row. It is kept in fp32 for the Llama and Gemma
    # casting modes and in the input dtype otherwise.
    fp32_cache_modes = (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
    if casting_mode in fp32_cache_modes:
        r_dtype = torch.float32
    else:
        r_dtype = X.dtype
    r = torch.empty(n_rows, dtype=r_dtype, device=X.device)

    # Check constraints.
    assert (
        X.shape[1] == W.shape[0]
    ), "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"

    grid = (n_rows,)
    _rms_norm_forward_kernel[grid](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        r,
        r.stride(0),
        n_cols,
        eps,
        offset,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    Y = Y.view(*original_shape)
    return Y, X, r, BLOCK_SIZE, num_warps, casting_mode
def rms_norm_backward(dY, X, W, r, eps, offset, casting_mode, BLOCK_SIZE, num_warps):
    """
    Launch the RMSNorm backward kernel, one program per row of the flattened gradient.

    Args:
        dY: upstream gradient; viewed as 2D `(-1, dY.shape[-1])`. The kernel writes dX into
            this same storage, so its contents are overwritten.
        X: forward input; used as-is (it is not reshaped here) and as the template for the
            per-row dW buffer.
        W: weight tensor.
        r: per-row buffer produced by the forward pass.
        eps, offset: values forwarded to the kernel.
        casting_mode: integer casting-mode code.
        BLOCK_SIZE, num_warps: launch settings forwarded to the kernel.

    Returns:
        `(dX, dW)`: `dX` is a view of `dY` in its original shape; `dW` is the per-row buffer
        summed over dim 0 and cast to `W.dtype`.
    """
    grad_shape = dY.shape
    dY = dY.view(-1, grad_shape[-1])
    n_rows, n_cols = dY.shape

    # The per-row dW buffer is fp32 in Gemma mode and uses the weight dtype otherwise.
    if casting_mode == _CASTING_MODE_GEMMA.value:
        dW_dtype = torch.float32
    else:
        dW_dtype = W.dtype
    dW_rows = torch.empty_like(X, dtype=dW_dtype)

    # Here we use dY to store the value of dX to save memory
    grid = (n_rows,)
    _rms_norm_backward_kernel[grid](
        dY,
        dY.stride(0),
        X,
        X.stride(0),
        W,
        W.stride(0),
        r,
        r.stride(0),
        dW_rows,
        dW_rows.stride(0),
        n_cols,
        eps,
        offset,
        casting_mode,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    # Reduce the per-row contributions into the final weight gradient.
    dW = torch.sum(dW_rows, dim=0).to(W.dtype)
    dX = dY.view(*grad_shape)
    return dX, dW
class LigerRMSNormFunction(torch.autograd.Function):
    """
    Autograd function for RMSNorm (Root Mean Square Normalization) of an input `X` with a
    weight `W`, an optional weight offset and a selectable casting mode.

    The offset shifts the weight by a constant before it is applied. Gemma, for example,
    uses an offset of 1.0, giving `(X / RMS(X)) * (W + 1.0)` rather than the usual
    `(X / RMS(X)) * W`. The offset is passed as an argument to `forward`.

    Models also differ in where they cast to fp32 during the computation: Gemma casts
    everything to fp32 before starting, while Llama casts only the inverse RMS. The
    `casting_mode` argument selects this behaviour. The supported modes (which match
    HuggingFace Transformers' implementations) are:
    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
        """
        X: (B, T, H) or (BxT, H)
        W: (H,)

        Returns Y with the same shape as X and stashes on `ctx` everything the backward
        pass needs.
        """
        forward_outputs = rms_norm_forward(X, W, eps, offset, casting_mode)
        # The launch settings and the resolved casting mode go straight onto ctx.
        Y, X, r, ctx.BLOCK_SIZE, ctx.num_warps, ctx.casting_mode = forward_outputs
        ctx.eps, ctx.offset = eps, offset
        ctx.save_for_backward(X, W, r)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """
        dY: (B, T, H) or (BxT, H)

        Returns the gradients for (X, W, eps, offset, casting_mode); only X and W receive
        a gradient, the remaining three are None.
        """
        # saved_tensors holds (X, W, r) in the order they were saved in forward.
        dX, dW = rms_norm_backward(
            dY,
            *ctx.saved_tensors,
            eps=ctx.eps,
            offset=ctx.offset,
            casting_mode=ctx.casting_mode,
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        return dX, dW, None, None, None