-
Notifications
You must be signed in to change notification settings - Fork 183
/
autograd.py
156 lines (133 loc) · 4.89 KB
/
autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from enum import Enum
import torch
from torch.sparse import SparseSemiStructuredTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
if TORCH_VERSION_AT_LEAST_2_3:
    # The backend-specific semi-structured subclasses are only imported when the
    # torch version flag is set (presumably torch >= 2.3, per the flag name).
    # On older versions these names stay undefined, so the CUTLASS/cuSPARSELt
    # code paths in this module cannot be used there.
    from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
    from torch.sparse import SparseSemiStructuredTensorCUTLASS

    # Register both subclasses with TorchDynamo so they are placed in the
    # traced graph as-is rather than being traced into.
    torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUSPARSELT)
    torch._dynamo.allow_in_graph(SparseSemiStructuredTensorCUTLASS)
class GRADIENT_TYPE(Enum):
    """How `_SparsifyLikeFunc.backward` treats the incoming gradient."""

    # Mask the incoming gradient and return it as a dense tensor.
    DENSE = 1
    # Mask the incoming gradient and return it as a semi-structured tensor.
    SPARSE = 2
    # Straight-through estimation: pass the incoming gradient unchanged.
    STE = 3
class _SparsifyFunc(torch.autograd.Function):
    """Autograd function that converts a dense tensor into a semi-structured tensor.

    The backward pass is a straight-through estimator: the incoming gradient
    is returned unchanged for ``x``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, algo: str, backend: str):  # type: ignore[override]
        """Sparsify ``x``.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            x: Tensor to sparsify. If it is already a
                ``SparseSemiStructuredTensor`` it is only detached.
            algo: Algorithm name forwarded as ``algorithm=`` to
                ``torch._sparse_semi_structured_tile``.
            backend: ``"cutlass"`` builds a ``SparseSemiStructuredTensorCUTLASS``;
                any other value builds a ``SparseSemiStructuredTensorCUSPARSELT``.

        Returns:
            The semi-structured sparse version of ``x``.
        """
        use_cutlass = backend == "cutlass"
        if isinstance(x, SparseSemiStructuredTensor):
            # Already sparse: nothing to compress, just cut the autograd link.
            return x.detach()

        # The tile op produces both the packed values and the packed transpose
        # (with their metadata) plus the compressed bitmask in a single call.
        packed, meta, packed_t, meta_t, bitmask = torch._sparse_semi_structured_tile(
            x, algorithm=algo, use_cutlass=use_cutlass
        )
        cls = (
            SparseSemiStructuredTensorCUTLASS
            if use_cutlass
            else SparseSemiStructuredTensorCUSPARSELT
        )
        return cls(
            x.shape,
            packed=packed,
            meta=meta,
            packed_t=packed_t,
            meta_t=meta_t,
            compressed_swizzled_bitmask=bitmask,
            requires_grad=False,
            # NOTE(review): passed for both backends; presumably only
            # meaningful for cuSPARSELt — confirm against torch.sparse.
            fuse_transpose_cusparselt=True,
        )

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):  # type: ignore[override]
        """Straight-through estimation: pass ``grad_out`` through for ``x``.

        Returns one gradient per forward input: ``x`` gets ``grad_out``;
        ``algo`` and ``backend`` are non-tensor inputs and get ``None``.
        """
        return grad_out, None, None
class _SparsifyLikeFunc(torch.autograd.Function):
    """Autograd function that sparsifies ``x`` using the mask of an existing pattern.

    The forward pass reuses ``pattern``'s metadata (``meta``/``meta_t``) and its
    compressed swizzled bitmask, and only re-packs the values of ``x``. Only
    CUTLASS-backed patterns whose bitmask is contiguous are supported.

    The backward pass depends on ``gradient``:

    * ``GRADIENT_TYPE.STE``: the incoming gradient is returned unchanged.
    * ``GRADIENT_TYPE.DENSE``: ``pattern``'s bitmask is applied to the incoming
      gradient and the result is returned as a dense tensor.
    * ``GRADIENT_TYPE.SPARSE``: ``pattern``'s bitmask is applied to the incoming
      gradient and the result is returned as a semi-structured tensor of the
      same class as the forward output.

    An incoming gradient that is already a ``SparseSemiStructuredTensor`` is
    always passed through unchanged.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        pattern: SparseSemiStructuredTensor,
        gradient=GRADIENT_TYPE.SPARSE,
    ):  # type: ignore[override]
        """Pack ``x`` with ``pattern``'s bitmask.

        Raises:
            NotImplementedError: if ``pattern`` is not CUTLASS-backed or its
                bitmask is not contiguous (i.e. transposed).
        """
        assert isinstance(pattern, SparseSemiStructuredTensor)
        if not isinstance(pattern, SparseSemiStructuredTensorCUTLASS):
            raise NotImplementedError(
                "`sparsify_like(x, pattern)` is only implemented for CUTLASS backend"
            )
        if not pattern.compressed_swizzled_bitmask.is_contiguous():
            raise NotImplementedError(
                "`sparsify_like(x, pattern)` is not implemented when `bitmask` is transposed"
            )
        packed, packed_t = torch._sparse_semi_structured_apply(
            x, pattern.compressed_swizzled_bitmask
        )
        # Saved for backward. `dtype` backs the gradient dtype sanity check and
        # `cls` lets the sparse gradient use the same class as the output.
        ctx.meta = pattern.meta
        ctx.meta_t = pattern.meta_t
        ctx.bitmask = pattern.compressed_swizzled_bitmask
        ctx.gradient = gradient
        ctx.dtype = x.dtype
        ctx.cls = pattern.__class__
        return pattern.__class__(
            x.shape,
            packed,
            pattern.meta,
            packed_t,
            pattern.meta_t,
            pattern.compressed_swizzled_bitmask,
            requires_grad=x.requires_grad,
        )

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):  # type: ignore[override]
        """Compute the gradient for ``x`` according to ``ctx.gradient``.

        Returns one entry per forward input (``x``, ``pattern``, ``gradient``);
        only ``x`` receives a gradient.
        """
        if ctx.gradient == GRADIENT_TYPE.STE or isinstance(
            grad_out, SparseSemiStructuredTensor
        ):
            return grad_out, None, None
        assert grad_out.dtype == ctx.dtype
        if ctx.gradient == GRADIENT_TYPE.DENSE:
            # forward already rejected non-contiguous bitmasks; keep as invariant.
            assert ctx.bitmask.is_contiguous()
            return (
                torch._sparse_semi_structured_apply_dense(grad_out, ctx.bitmask),
                None,
                None,
            )
        assert ctx.gradient == GRADIENT_TYPE.SPARSE
        # Apply the same mask as forward; the metadata is shared with the
        # pattern, so only the packed values need to be recomputed.
        packed, packed_t = torch._sparse_semi_structured_apply(grad_out, ctx.bitmask)
        return (
            ctx.cls(
                grad_out.shape,
                packed,
                ctx.meta,
                packed_t,
                ctx.meta_t,
                ctx.bitmask,
                requires_grad=grad_out.requires_grad,
            ),
            None,
            None,
        )
@torch._dynamo.allow_in_graph
def semi_structured_sparsify(
    x: torch.Tensor,
    algo: str = "",
    backend: str = "cutlass",
) -> SparseSemiStructuredTensor:
    """Convert ``x`` into a semi-structured sparse tensor.

    Args:
        x: Tensor to sparsify. If it is already a semi-structured tensor it is
            only detached.
        algo: Algorithm name forwarded to the tiling/sparsification kernel.
        backend: ``"cutlass"`` selects the CUTLASS subclass; any other value
            selects the cuSPARSELt subclass.

    Returns:
        The sparsified tensor. Gradients flow back to ``x`` unchanged
        (straight-through estimation).
    """
    sparsify_args = (x, algo, backend)
    return _SparsifyFunc.apply(*sparsify_args)
@torch._dynamo.allow_in_graph
def semi_structured_sparsify_like(
    x: torch.Tensor,
    pattern: SparseSemiStructuredTensor,
    gradient: GRADIENT_TYPE = GRADIENT_TYPE.SPARSE,
) -> SparseSemiStructuredTensor:
    """Sparsify ``x`` by reusing the sparsity mask stored in ``pattern``.

    Args:
        x: Tensor to be packed with ``pattern``'s bitmask.
        pattern: Semi-structured tensor supplying the bitmask and metadata.
            It must be CUTLASS-backed with a contiguous bitmask, otherwise
            ``NotImplementedError`` is raised.
        gradient: Selects how the backward pass treats the incoming gradient
            (``GRADIENT_TYPE.DENSE``, ``GRADIENT_TYPE.SPARSE`` or
            ``GRADIENT_TYPE.STE``).

    Returns:
        A semi-structured tensor of the same class as ``pattern``.
    """
    result = _SparsifyLikeFunc.apply(x, pattern, gradient)
    return result