-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmyutils.py
42 lines (37 loc) · 1.07 KB
/
myutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import numpy as np
import sys
def calc_mi(pzcx, px):
    """Compute the mutual information I(Z;X) in nats.

    Parameters
    ----------
    pzcx : array_like, shape (n_z, n_x)
        Conditional probabilities p(z|x); column ``j`` holds p(z | x_j).
    px : array_like, shape (n_x,)
        Marginal distribution p(x).

    Returns
    -------
    numpy.float64
        sum_{z,x} p(x) p(z|x) * log(p(z|x) / p(z)), where p(z) = pzcx @ px.

    Notes
    -----
    Terms whose joint mass p(x) p(z|x) is zero contribute exactly 0 (the
    usual 0 * log 0 = 0 convention). Consequently an output symbol z with
    zero marginal p(z), or an input symbol x with p(x) = 0, no longer turns
    the result into NaN.
    """
    pzcx = np.asarray(pzcx, dtype=float)
    px = np.asarray(px, dtype=float)
    pz = pzcx @ px
    joint = pzcx * px[None, :]
    # Only entries with positive joint mass contribute. On this support
    # pz[z] >= joint[z, x] > 0, so the ratio is finite and strictly positive
    # and no epsilon is needed inside the log (an epsilon would bias the MI).
    support = joint > 0
    # Entries off the support may produce inf/NaN here; they are discarded by
    # np.where below, so silence the corresponding warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        log_ratio = np.log(pzcx / pz[:, None])
        terms = np.where(support, joint * log_ratio, 0.0)
    return np.sum(terms)
def checkAlgArgs(**kwargs):
    """Verify that every hyper-parameter required by the IB algorithms is supplied.

    The required keyword arguments are ``qlevel``, ``conv_thres``, ``beta``
    and ``max_iter``. Any additional keyword arguments are ignored.

    Returns
    -------
    bool
        True when every required argument is present.

    Raises
    ------
    SystemExit
        Via ``sys.exit``, with a message naming the first missing argument
        (checked in the order listed above).

    An argument counts as missing only when it is absent or ``None``. Falsy
    but valid values such as ``beta=0`` or ``conv_thres=0.0`` are accepted.
    """
    required = ('qlevel', 'conv_thres', 'beta', 'max_iter')
    for name in required:
        # Test for presence, not truthiness: a zero value is a legitimate setting.
        if kwargs.get(name) is None:
            sys.exit('ERROR: the argument {} required by ib algorithms is missing'.format(name))
    return True
def pxy2allprob(pxy):
    """Derive both marginals and both conditionals from a joint distribution.

    Parameters
    ----------
    pxy : array_like, shape (n_x, n_y)
        Joint probabilities, with rows indexed by x and columns by y.

    Returns
    -------
    dict
        ``'px'``   : row sums, shape (n_x,).
        ``'py'``   : column sums, shape (n_y,).
        ``'pxcy'`` : shape (n_x, n_y); column j is p(x | y_j).
        ``'pycx'`` : shape (n_y, n_x); column i is p(y | x_i).

    A zero row or column sum leads to a division by zero, so the matching
    conditional entries will be non-finite.
    """
    px, py = np.sum(pxy, axis=1), np.sum(pxy, axis=0)
    inv_px = 1 / px
    inv_py = 1 / py
    return {
        'px': px,
        'py': py,
        # scale each column by 1/p(y)
        'pxcy': pxy * inv_py[None, :],
        # scale each row by 1/p(x), then put y on the first axis
        'pycx': (inv_px[:, None] * pxy).T,
    }
def genOutName(**kwargs):
    """Build the base name of a result file as ``'<method>_<dataset>_result'``.

    Keyword arguments
    -----------------
    method : one of ``'orig'``, ``'gd'``, ``'alm'``.
    dataset : any value; it is formatted into the name. It is only looked up
        when ``method`` is recognised.

    Terminates through ``sys.exit`` when ``method`` is not one of the names above.
    """
    method = kwargs['method']
    for known in ('orig', 'gd', 'alm'):
        if method == known:
            return '{}_{}_result'.format(known, kwargs['dataset'])
    sys.exit('undefined method {}'.format(method))