import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Module-level counter of high_order_residual calls; kept only for
# debugging/inspection, it does not affect the math.
index = 0


@torch.no_grad()
def part_mean(tensor, op='-'):
    # Row-wise mean helper (the `op` argument is currently unused).
    # Note: zeroing the masked-out entries does not remove them from the
    # denominator, so this is a mean over all columns of each row.
    non_zero = tensor * (tensor != 0)
    mean_val = non_zero.mean(-1).view(-1, 1)
    return mean_val


@torch.no_grad()
def high_order_residual(x, mask, order=2):
    """Residual binarization: approximate the masked entries of `x` by a sum
    of `order` binary terms. Each pass binarizes the remaining residual
    row-wise as mean + scale * sign(residual - mean), where `scale` is the
    mean absolute deviation of the centered residual."""
    sum_order = torch.zeros_like(x)
    new_matrix = x.clone()
    new_matrix = new_matrix * mask
    global index
    index += 1
    for od in range(order):
        residual = new_matrix - sum_order
        # NaN out the entries excluded by the mask so the nanmean reductions
        # below only see the selected weights.
        masked_x_tensor = torch.where(mask, residual, torch.full_like(residual, float('nan')))
        mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
        mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
        masked_x_tensor -= mean_tensor_all[:, None]
        scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
        scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
        # Binary reconstruction of this pass: mean + scale * sign(centered residual).
        binary = torch.sign(masked_x_tensor)
        binary *= scale_tensor_all[:, None]
        binary += mean_tensor_all[:, None]
        # Accumulate via torch.where rather than `binary * mask`: the off-mask
        # entries of `binary` are NaN, and NaN * 0 is still NaN, so a plain
        # multiply would leak NaNs into the off-mask entries of the result.
        sum_order = sum_order + torch.where(mask, binary, torch.zeros_like(binary))
    return sum_order
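

# A minimal sketch (not part of the original file) showing how the residual
# binarization above behaves: each additional order re-binarizes what the
# previous passes missed, so the reconstruction error on the masked entries
# typically shrinks.
def _demo_high_order_residual():
    torch.manual_seed(0)
    w = torch.randn(4, 16)
    mask = torch.rand(4, 16) > 0.3  # keep roughly 70% of the entries
    for k in (1, 2, 3):
        approx = high_order_residual(w, mask, order=k)
        err = ((w - approx)[mask] ** 2).sum().item()  # error on masked entries only
        print(f"order={k}: squared error = {err:.4f}")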


@torch.no_grad()
def normal_quantize(x, scale, zero, maxq):
    # Uniform asymmetric quantization: map to the integer grid [0, maxq]
    # with offset `zero`, then dequantize back to float.
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)
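

# A quick numeric check (illustration only, not part of the original file).
# With bits=2 the grid is maxq = 3; a row spanning [-1.0, 2.0] gives
# scale = (2 - (-1)) / 3 = 1.0 and zero = round(1 / 1) = 1, so:
#   -1.0 -> q=0 -> -1.0,  0.4 -> q=1 -> 0.0,  1.3 -> q=2 -> 1.0,  2.0 -> q=3 -> 2.0
def _demo_normal_quantize():
    x = torch.tensor([-1.0, 0.4, 1.3, 2.0])
    out = normal_quantize(x, torch.tensor(1.0), torch.tensor(1.0), torch.tensor(3))
    print(out)  # tensor([-1., 0., 1., 2.])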


class Binarization(nn.Module):
    def __init__(self, weight, method="2bit", groupsize=-1):
        super().__init__()
        oc, ic = weight.shape
        if groupsize == -1:
            groupsize = ic
        self.groupsize = groupsize
        self.n_groups = math.ceil(ic / groupsize)
        self.method = method
        # NOTE: the per-group statistics used below (self.scale, and self.mean
        # for 'xnor') are expected to be assigned by the caller before
        # quantize() is invoked; only a placeholder is set here.
        self.mean = 0

    def quantize(self, w, mask, order=2, groupi=0):
        if self.method == "xnor":
            # XNOR-style binarization: center, take signs, rescale, re-center.
            w_mean = self.mean[groupi]
            w = w - w_mean  # oc, ic
            w = w.sign()
            w = w * self.scale[groupi]
            w += w_mean
        elif self.method == "braq":  # binary residual approximation, the method used in the paper
            w = high_order_residual(w, mask, order=order)
        elif self.method == "sign":
            # 0/1 quantization scaled per group.
            w = (w > 0).float()
            w *= self.scale[groupi]
        elif self.method == "rtn":
            # Round-to-nearest onto the two levels {0, scale}.
            w = F.relu(w)
            w_int = (w / self.scale[groupi]).round().clamp(0, 1)
            w = w_int * self.scale[groupi]
        elif self.method in ['2bit', '4bit']:
            # GPTQ-style uniform min-max quantizer, configured inline for
            # per-channel weight quantization.
            bits = int(self.method[0])
            perchannel = True
            weight = True
            dev = w.device
            maxq = torch.tensor(2 ** bits - 1)
            scale = torch.zeros(1)
            zero = torch.zeros(1)
            if dev != scale.device:
                scale = scale.to(dev)
                zero = zero.to(dev)
                maxq = maxq.to(dev)
            x = w.clone()
            shape = x.shape
            if perchannel:
                if weight:
                    x = x.flatten(1)
                else:
                    # Activation layouts (unreachable here since weight=True).
                    if len(shape) == 4:
                        x = x.permute([1, 0, 2, 3])
                        x = x.flatten(1)
                    if len(shape) == 3:
                        x = x.reshape((-1, shape[-1])).t()
                    if len(shape) == 2:
                        x = x.t()
            else:
                x = x.flatten().unsqueeze(0)
            # Per-row min/max range, forced to include zero.
            tmp = torch.zeros(x.shape[0], device=dev)
            xmin = torch.minimum(x.min(1)[0], tmp)
            xmax = torch.maximum(x.max(1)[0], tmp)
            # Guard against an all-zero row producing a degenerate range.
            tmp = (xmin == 0) & (xmax == 0)
            xmin[tmp] = -1
            xmax[tmp] = +1
            scale = (xmax - xmin) / maxq
            zero = torch.round(-xmin / scale)
            if not perchannel:
                if weight:
                    tmp = shape[0]
                else:
                    tmp = shape[1] if len(shape) != 3 else shape[2]
                scale = scale.repeat(tmp)
                zero = zero.repeat(tmp)
            if weight:
                # Reshape for broadcasting over the original weight shape.
                shape = [-1] + [1] * (len(shape) - 1)
                scale = scale.reshape(shape)
                zero = zero.reshape(shape)
            w = normal_quantize(w, scale, zero, maxq)
        elif self.method == "prune":
            # Pruning baseline: zero out the selected weights entirely.
            return torch.zeros_like(w)
        return w
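

# A minimal usage sketch (not part of the original file; the shapes and the
# mask heuristic are assumptions). It binarizes a random weight matrix with
# the paper's 'braq' method, using a stand-in salience mask that keeps the
# larger-magnitude half of each row.
if __name__ == "__main__":
    torch.manual_seed(0)
    W = torch.randn(8, 32)
    quantizer = Binarization(W, method="braq")
    mask = W.abs() >= W.abs().median(dim=1, keepdim=True).values
    W_bin = quantizer.quantize(W, mask, order=2)
    rel_err = (W - W_bin)[mask].norm() / W[mask].norm()
    print(f"relative reconstruction error on masked weights: {rel_err:.3f}")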