-
Notifications
You must be signed in to change notification settings - Fork 182
/
__init__.py
105 lines (87 loc) · 3.47 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torchao.sparsity.training.autograd import semi_structured_sparsify
from torchao.sparsity.training.pointwise_ops import CUTLASS_POINTWISE_OP_DISPATCH_TABLE
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
# Load pointwise op support, which exists only for CUTLASS.
# The CUTLASS sparse tensor class is imported (and its dispatch table extended
# with CUTLASS_POINTWISE_OP_DISPATCH_TABLE) only when TORCH_VERSION_AT_LEAST_2_3
# is true; otherwise this module performs no registration at import time.
if TORCH_VERSION_AT_LEAST_2_3:
    from torch.sparse import SparseSemiStructuredTensorCUTLASS

    SparseSemiStructuredTensorCUTLASS._load_dispatch_table(
        CUTLASS_POINTWISE_OP_DISPATCH_TABLE
    )
# Public names of this module (what `from <module> import *` exports):
# the two Linear replacements and the two model-swapping helpers.
__all__ = [
    "SemiSparseLinear",
    "SemiSparseActivationLinear",
    "swap_linear_with_semi_sparse_linear",
    "swap_semi_sparse_linear_with_linear",
]
class SemiSparseLinear(torch.nn.Linear):
    """
    Replacement nn.Linear that supports runtime weight sparsity.

    On every forward pass the weight is passed through
    `semi_structured_sparsify(..., backend="cusparselt")` and the result is
    used in place of the dense weight; the input and the bias are handed to
    `torch.nn.functional.linear` unchanged.
    """

    def forward(self, x):
        """Apply the linear layer to `x` using a freshly sparsified weight."""
        sparse_weight = semi_structured_sparsify(self.weight, backend="cusparselt")
        return torch.nn.functional.linear(x, sparse_weight, self.bias)

    @classmethod
    def from_dense(cls, linear):
        """
        Build an instance of `cls` that shares its parameters with `linear`.

        Args:
            linear: module providing `in_features`, `out_features`, `weight`
                and `bias` (`bias` may be None).

        Returns:
            A new `cls` module whose `weight` and `bias` are the very same
            objects as `linear.weight` and `linear.bias`; nothing is copied.
        """
        # The parameters created by the constructor are replaced right below,
        # so create them on the meta device (no storage allocated or filled)
        # and only ask for a bias slot when the source layer has one.
        mod = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device="meta",
        )
        mod.weight = linear.weight
        mod.bias = linear.bias
        return mod

    @classmethod
    def to_dense(cls, semi_sparse_linear):
        """
        Build a plain `torch.nn.Linear` sharing parameters with `semi_sparse_linear`.

        Args:
            semi_sparse_linear: module providing `in_features`, `out_features`,
                `weight` and `bias` (`bias` may be None).

        Returns:
            A new `torch.nn.Linear` whose `weight` and `bias` are the same
            objects as those of `semi_sparse_linear`; nothing is copied.
        """
        # Placeholder parameters live on the meta device because they are
        # discarded as soon as the real ones are attached.
        mod = torch.nn.Linear(
            semi_sparse_linear.in_features,
            semi_sparse_linear.out_features,
            bias=semi_sparse_linear.bias is not None,
            device="meta",
        )
        mod.weight = semi_sparse_linear.weight
        mod.bias = semi_sparse_linear.bias
        return mod
class SemiSparseActivationLinear(torch.nn.Linear):
    """
    Replacement nn.Linear that supports runtime activation sparsity.

    On every forward pass the input is passed through
    `semi_structured_sparsify(..., backend="cusparselt")` and the result is
    used in place of the dense input; the weight and the bias are handed to
    `torch.nn.functional.linear` unchanged.
    """

    def forward(self, x):
        """Apply the linear layer to a freshly sparsified version of `x`."""
        sparse_x = semi_structured_sparsify(x, backend="cusparselt")
        return torch.nn.functional.linear(sparse_x, self.weight, self.bias)

    @classmethod
    def from_dense(cls, linear):
        """
        Build an instance of `cls` that shares its parameters with `linear`.

        Args:
            linear: module providing `in_features`, `out_features`, `weight`
                and `bias` (`bias` may be None).

        Returns:
            A new `cls` module whose `weight` and `bias` are the very same
            objects as `linear.weight` and `linear.bias`; nothing is copied.
        """
        # The parameters created by the constructor are replaced right below,
        # so create them on the meta device (no storage allocated or filled)
        # and only ask for a bias slot when the source layer has one.
        mod = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device="meta",
        )
        mod.weight = linear.weight
        mod.bias = linear.bias
        return mod

    @classmethod
    def to_dense(cls, semi_sparse_linear):
        """
        Build a plain `torch.nn.Linear` sharing parameters with `semi_sparse_linear`.

        Args:
            semi_sparse_linear: module providing `in_features`, `out_features`,
                `weight` and `bias` (`bias` may be None).

        Returns:
            A new `torch.nn.Linear` whose `weight` and `bias` are the same
            objects as those of `semi_sparse_linear`; nothing is copied.
        """
        # Placeholder parameters live on the meta device because they are
        # discarded as soon as the real ones are attached.
        mod = torch.nn.Linear(
            semi_sparse_linear.in_features,
            semi_sparse_linear.out_features,
            bias=semi_sparse_linear.bias is not None,
            device="meta",
        )
        mod.weight = semi_sparse_linear.weight
        mod.bias = semi_sparse_linear.bias
        return mod
def swap_linear_with_semi_sparse_linear(model, config, current=""):
    """
    Public API for replacing nn.Linear with SemiSparseLinear.

    Walks the children of `model` recursively and, in place, replaces every
    `torch.nn.Linear` child whose fully-qualified name is a key of `config`.

    Args:
        model: module whose children are visited; it is modified in place.
        config: mapping from fully-qualified child name (child names joined
            with ".", prefixed by `current` when that is non-empty) to an
            object exposing `from_dense(linear)`; the value it returns is
            installed in place of the child.
        current: fully-qualified-name prefix of `model`; empty for the
            outermost call.

    Returns:
        None.
    """
    # Snapshot the children up front: setattr below rewrites the module's
    # child registry while we are walking it.
    for child_name, submodule in list(model.named_children()):
        qualified_name = child_name if not current else f"{current}.{child_name}"

        if not isinstance(submodule, torch.nn.Linear):
            # Not a Linear: descend and look for Linear layers further down.
            swap_linear_with_semi_sparse_linear(
                submodule, config, current=qualified_name
            )
            continue

        # Linear layers that are not listed in the config are left untouched.
        if qualified_name in config:
            replacement = config[qualified_name].from_dense(submodule)
            setattr(model, child_name, replacement)
def swap_semi_sparse_linear_with_linear(model, current=""):
    """
    Public API for replacing instances of SemiSparseLinear/SemiSparseActivationLinear with nn.Linear.

    Walks the children of `model` recursively and, in place, replaces every
    child that is a `SemiSparseLinear` or `SemiSparseActivationLinear` with
    the module returned by that child's `to_dense`.

    Args:
        model: module whose children are visited; it is modified in place.
        current: fully-qualified-name prefix of `model`; it is only threaded
            through the recursion and does not affect which children are
            replaced.

    Returns:
        None.
    """
    # Snapshot the children up front: setattr below rewrites the module's
    # child registry while we are walking it.
    for child_name, submodule in list(model.named_children()):
        if isinstance(submodule, (SemiSparseLinear, SemiSparseActivationLinear)):
            setattr(model, child_name, submodule.to_dense(submodule))
            continue

        # Any other module: descend and look for sparse layers further down.
        qualified_name = child_name if not current else f"{current}.{child_name}"
        swap_semi_sparse_linear_with_linear(submodule, current=qualified_name)