-
Notifications
You must be signed in to change notification settings - Fork 0
/
interp_feature.py
151 lines (112 loc) · 5.73 KB
/
interp_feature.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
### Adaptive Feature Interpolation
# Create a set of new features from old features:
#     new_feature = near_interp(old_feature, k, augment_prob)
# k and augment_prob can be defined by the user or generated with
#     augment_prob, k = dynamic_prob(old_feature)
# (note that dynamic_prob returns the probability first, then k).
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.manifold import MDS
import random
def near_interp(embeddings, k, augment_prob):
    """Randomly replace embeddings by interpolations of their nearest neighbours.

    For every row, the ``k`` nearest rows according to ``pairwise_distances``
    are mixed with convex weights drawn from a Dirichlet distribution. The
    concentration of neighbour ``j`` is ``1 / (1 + d)``, where ``d`` is the
    max-normalised distance between the row's first (closest) neighbour --
    normally the row itself -- and neighbour ``j``. The mixture is then
    L2-normalised. Each row is independently replaced by its mixture with
    probability ``augment_prob``; the remaining rows are passed through
    unchanged (they are NOT re-normalised).

    Parameters
    ----------
    embeddings : torch.Tensor
        2-D tensor with one embedding per row.
    k : int
        Number of nearest neighbours to mix; capped at the number of rows.
    augment_prob : float
        Probability of replacing a row by its interpolated version.

    Returns
    -------
    torch.Tensor
        ``embeddings`` itself when nothing can or should be interpolated
        (fewer than two usable neighbours, or ``augment_prob <= 0``);
        otherwise a new tensor with the same shape as ``embeddings``.

    Notes
    -----
    The replace/keep decision uses Python's ``random`` module (one draw per
    row, in row order); the mixing weights use torch's random generator and
    are drawn for all rows in a single batched call.
    """
    batch_size = embeddings.size(0)
    k = min(k, batch_size)
    # With fewer than two neighbours there is nothing to mix. Checking after
    # the cap also covers a batch of one row, where the distance
    # normalisation below would be 0 / 0.
    if k <= 1 or augment_prob <= 0:
        return embeddings

    # The mixing weights are sampled rather than differentiated, so the
    # neighbour search and the Dirichlet parameters need no autograd graph.
    with torch.no_grad():
        dist = pairwise_distances(embeddings, embeddings)
        max_dist = dist.max()
        if max_dist > 0:  # identical rows would otherwise produce NaNs
            dist = dist / max_dist
        similarity = 1 / (1 + dist)

        # Indices of the k closest rows for every row: (batch_size, k).
        neighbours = torch.topk(dist, k, largest=False).indices

        # alpha[b, j] = similarity[neighbours[b, 0], neighbours[b, j]] ** t
        temperature = 1
        alpha = similarity[neighbours[:, :1], neighbours] ** temperature
        weights = torch.distributions.dirichlet.Dirichlet(alpha).sample()

    # (batch, 1, k) @ (batch, k, dim) -> (batch, 1, dim). Gradients still
    # flow into the selected neighbour rows of `embeddings`.
    mixed = torch.bmm(weights.unsqueeze(1), embeddings[neighbours]).squeeze(1)
    mixed = F.normalize(mixed)

    # Output the interpolated feature with probability augment_prob.
    replace = torch.tensor(
        [random.random() < augment_prob for _ in range(batch_size)],
        dtype=torch.bool,
        device=embeddings.device,
    )
    return torch.where(replace.unsqueeze(1), mixed, embeddings)
def dynamic_prob(embeddings):
    """Derive an augmentation probability and neighbour count from a batch.

    The rows are L2-normalised, their pairwise distances are scaled to a
    maximum of 1 and passed to ``eigen_mds``. The number of leading values
    that are at least 10% of the largest one is counted as the number of
    "large" components; ``k`` is the batch size minus that count and
    ``p = (k - 1) / batch_size``.

    Parameters
    ----------
    embeddings : torch.Tensor
        2-D tensor with one embedding per row.

    Returns
    -------
    p : float
        Augmentation probability.
    k : int
        Number of neighbours; always at least 1.

    Notes
    -----
    Degenerate inputs (fewer than two rows, or all rows identical after
    normalisation) carry no usable geometry, so ``(0.0, 1)`` -- i.e. "do not
    augment" -- is returned instead of dividing by a zero maximum distance.
    The same result is produced when every value is "large".
    """
    embeddings = F.normalize(embeddings)
    batch_size = embeddings.size(0)
    if batch_size < 2:
        return 0.0, 1

    D = pairwise_distances(embeddings, embeddings).detach().cpu().numpy()
    max_dist = np.amax(D)
    if max_dist == 0:
        return 0.0, 1
    D = D / max_dist

    l_sorted = eigen_mds(D)

    # Calculate k, p based on the number of large values: the index of the
    # first value below the threshold equals the count of large ones. If no
    # value is below the threshold, all of them count as large.
    threshold = 0.1 * l_sorted[0]
    n_large = next(
        (i for i, value in enumerate(l_sorted) if value < threshold),
        batch_size,
    )
    k = max(batch_size - n_large, 1)
    p = (k - 1) / batch_size
    return p, k
def cmdscale(D):
    """
    Eigenvalue spectrum of classical multidimensional scaling (MDS)

    Builds the double-centred matrix ``B = -H (D**2) H / 2`` with
    ``H = I - 11^T / n`` and returns its eigenvalues. Only the eigenvalues
    are returned; no point configuration is computed.

    Parameters
    ----------
    D : (n, n) array_like
        Symmetric distance matrix.

    Returns
    -------
    e : (n,) ndarray
        Eigenvalues of B, sorted in descending order.
    """
    D = np.asarray(D)
    # Number of points
    n = len(D)
    # Centering matrix
    H = np.eye(n) - np.ones((n, n)) / n
    # YY^T
    B = -H.dot(D ** 2).dot(H) / 2
    # eigvalsh returns the eigenvalues in ascending order without computing
    # eigenvectors; reverse them to get descending order.
    return np.linalg.eigvalsh(B)[::-1]
def eigen_mds(pd):
    """Singular values of an MDS embedding of a precomputed distance matrix.

    ``pd`` is fitted with scikit-learn's ``MDS`` using as many components as
    ``pd`` has rows, and the singular values of the resulting coordinate
    matrix are returned as produced by ``np.linalg.svd`` (largest first).
    """
    n_points = len(pd)
    embedder = MDS(n_components=n_points, dissimilarity='precomputed')
    coords = embedder.fit_transform(pd)
    # Only the middle element (the singular values) of the SVD is needed.
    singular_values = np.linalg.svd(coords)[1]
    return singular_values
def pairwise_distances(x, y=None):
    '''
    Euclidean distance between every row of x and every row of y.

    Input: x is a Nxd matrix
           y is an optional Mxd matrix; if y is not given then 'y=x' is used.
    Output: dist is a NxM matrix where dist[i,j] is the Euclidean norm of
            the difference between x[i,:] and y[j,:],
            i.e. dist[i,j] = ||x[i,:]-y[j,:]||  (not squared)

    The squared distances are computed with the expansion
    ||a||^2 + ||b||^2 - 2 a.b and clamped at zero before the square root,
    because round-off can make them slightly negative.
    NOTE: when x is compared with itself, round-off can likewise leave the
    diagonal slightly above zero; it is not forced to zero here.
    '''
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is None:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)
    else:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)

    sq_dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.sqrt(torch.clamp(sq_dist, min=0.0))