-
Notifications
You must be signed in to change notification settings - Fork 6
/
CKA.py
61 lines (42 loc) · 1.62 KB
/
CKA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# ORIGINAL SOURCE: https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment/blob/master/CKA.py
import math
import numpy as np
def centering(K):
    """Double-centre a kernel (Gram) matrix.

    Returns ``H @ K @ H`` where ``H = I - (1/n) * ones((n, n))`` is the
    centering matrix, so every row and every column of the result sums to
    zero.  Rather than materialising ``H`` and performing two O(n^3)
    matrix products, the algebraically identical O(n^2) form
    ``K - column_means - row_means + grand_mean`` is used.

    Note that one-sided centering (``K @ H`` or ``H @ K``) is a different
    matrix and does NOT in general give the same value once plugged into
    ``sum(centering(K) * centering(L))``, so both sides are centred here.

    Args:
        K: 2-D array, normally the square (n, n) kernel matrix.

    Returns:
        A new array with the same shape as ``K``; ``K`` itself is not
        modified.
    """
    K = np.asarray(K)
    # (1/n) * ones @ K  -> each column's mean, broadcast down the rows.
    col_means = K.mean(axis=0, keepdims=True)
    # K @ (1/n) * ones  -> each row's mean, broadcast across the columns.
    row_means = K.mean(axis=1, keepdims=True)
    # The grand mean is subtracted twice by the two terms above; add it back.
    return K - col_means - row_means + K.mean()
def rbf(X, sigma=None):
    """Gaussian (RBF) kernel matrix over the rows of ``X``.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``.

    Args:
        X: 2-D array; each row is one sample.  Integer arrays are
            accepted (the result is always floating point).
        sigma: kernel bandwidth.  When ``None``, the median heuristic is
            used: ``sigma = sqrt(median of the non-zero squared pairwise
            distances)``.  If every pairwise distance is zero (a single
            sample, or all rows identical) the kernel is all ones for any
            bandwidth, so ``sigma = 1.0`` is used instead of the NaN that
            the median of an empty selection would produce.

    Returns:
        Square float array with one row/column per sample.
    """
    GX = np.dot(X, X.T)
    # Squared distances: ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    KX = np.diag(GX) - GX + (np.diag(GX) - GX).T
    if sigma is None:
        nonzero = KX[KX != 0]
        sigma = math.sqrt(np.median(nonzero)) if nonzero.size else 1.0
    # Out-of-place multiply: an in-place ``*=`` on an integer array (from
    # integer X) cannot store the float result and raises a casting error.
    return np.exp(KX * (-0.5 / (sigma * sigma)))
def kernel_HSIC(X, Y, sigma):
    """Unnormalised HSIC between ``X`` and ``Y`` with Gaussian kernels.

    Each input is turned into an RBF kernel matrix by ``rbf`` using the
    same ``sigma`` argument (``None`` makes each call choose its own
    median-heuristic bandwidth).  Both matrices are double-centred by
    ``centering``, and the sum of their element-wise product is returned,
    i.e. ``sum(HKH * HLH)``.
    """
    centred_kx = centering(rbf(X, sigma))
    centred_ky = centering(rbf(Y, sigma))
    return (centred_kx * centred_ky).sum()
def linear_HSIC(X, Y):
    """Unnormalised HSIC with linear kernels: ``sum(HKH * HLH)``.

    Here ``K = X @ X.T`` and ``L = Y @ Y.T`` are the linear Gram matrices
    and ``H`` is the centering matrix.  Because
    ``H @ X @ X.T @ H == (H @ X) @ (H @ X).T`` and ``H @ X`` is just ``X``
    with each column's mean subtracted, the features are centred first.
    This avoids building the n x n centering matrix and the O(n^3)
    products needed to double-centre each Gram matrix.

    Args:
        X: 2-D array, one row per sample.
        Y: 2-D array with the same number of rows as ``X``; the number of
            columns may differ.

    Returns:
        The scalar HSIC value.  It is not divided by ``(n - 1)**2``; that
        scale cancels in CKA.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Column-centring the data equals double-centring its Gram matrix.
    X_c = X - X.mean(axis=0)
    Y_c = Y - Y.mean(axis=0)
    L_X = np.dot(X_c, X_c.T)
    L_Y = np.dot(Y_c, Y_c.T)
    return np.sum(L_X * L_Y)
def linear_CKA(X, Y):
    """Linear Centered Kernel Alignment between two representations.

    ``CKA = HSIC(X, Y) / (sqrt(HSIC(X, X)) * sqrt(HSIC(Y, Y)))``, with
    every HSIC term computed by ``linear_HSIC``.  If either input has a
    zero self-HSIC (e.g. it is constant across samples), the result is
    0/0, i.e. NaN, with a NumPy RuntimeWarning.
    """
    norm_x, norm_y = (np.sqrt(linear_HSIC(Z, Z)) for Z in (X, Y))
    return linear_HSIC(X, Y) / (norm_x * norm_y)
def kernel_CKA(X, Y, sigma=None):
    """RBF-kernel Centered Kernel Alignment between two representations.

    ``CKA = HSIC(X, Y) / (sqrt(HSIC(X, X)) * sqrt(HSIC(Y, Y)))``, with
    every HSIC term computed by ``kernel_HSIC`` using the same ``sigma``.
    ``None`` defers bandwidth selection to ``rbf``.
    """
    norm_x, norm_y = (np.sqrt(kernel_HSIC(Z, Z, sigma)) for Z in (X, Y))
    return kernel_HSIC(X, Y, sigma) / (norm_x * norm_y)
if __name__ == '__main__':
    # Demo on two independent standard-normal samples of shape (100, 64).
    X = np.random.randn(100, 64)
    Y = np.random.randn(100, 64)
    # (message template, CKA variant, second argument).  Comparing X with
    # itself should print 1 (up to rounding) for both variants.
    demos = (
        ('Linear CKA, between X and Y: {}', linear_CKA, Y),
        ('Linear CKA, between X and X: {}', linear_CKA, X),
        ('RBF Kernel CKA, between X and Y: {}', kernel_CKA, Y),
        ('RBF Kernel CKA, between X and X: {}', kernel_CKA, X),
    )
    for template, cka, other in demos:
        print(template.format(cka(X, other)))