-
Notifications
You must be signed in to change notification settings - Fork 24
/
tf_utils.py
79 lines (56 loc) · 2.05 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import tensorflow as tf
import numpy as np
# Default dtype for the helpers below whose `dtype` parameter defaults to it.
_FLOATX = tf.float32
def is_sparse(tensor):
    '''Returns True when `tensor` is an instance of `tf.SparseTensor`.
    '''
    sparse_type = tf.SparseTensor
    return isinstance(tensor, sparse_type)
def to_dense(tensor):
    '''Returns a dense counterpart of `tensor`.

    Sparse tensors are converted with `tf.sparse_tensor_to_dense`;
    any other input is handed back unchanged.
    '''
    if not is_sparse(tensor):
        return tensor
    return tf.sparse_tensor_to_dense(tensor)
def ndim(x):
    '''Returns the number of axes in a tensor, as an integer.

    Returns None when the static shape of a non-sparse tensor carries
    no dimension list.
    '''
    if is_sparse(x):
        # For sparse input the axis count is read from the first static
        # dimension of `x.shape`.
        return int(x.shape.get_shape()[0])
    static_dims = x.get_shape()._dims
    return None if static_dims is None else len(static_dims)
def concatenate(tensors, axis=-1):
    '''Concatenates a list of tensors alongside the specified axis.

    A negative `axis` is resolved against the number of axes of the first
    tensor; when that number is unknown (or zero) axis 0 is used. If every
    input is sparse the result comes from `tf.sparse_concat`; otherwise
    all inputs are densified and joined with `tf.concat`.
    '''
    if axis < 0:
        rank = ndim(tensors[0])
        axis = axis % rank if rank else 0

    if all(is_sparse(t) for t in tensors):
        return tf.sparse_concat(axis, tensors)

    dense_inputs = [to_dense(t) for t in tensors]
    return tf.concat(axis, dense_inputs)
def random_normal(shape, mean=0.0, std=1.0, dtype=_FLOATX, seed=None):
    '''Returns the result of `tf.random_normal` for the given shape.

    `mean` and `std` are forwarded as `mean` and `stddev`. When `seed` is
    None a seed is drawn from the Numpy RNG before the call.
    '''
    op_seed = np.random.randint(10e6) if seed is None else seed
    return tf.random_normal(
        shape, mean=mean, stddev=std, dtype=dtype, seed=op_seed)
def random_uniform(shape, low=0.0, high=1.0, dtype=_FLOATX, seed=None):
    '''Returns the result of `tf.random_uniform` for the given shape.

    `low` and `high` are forwarded as `minval` and `maxval`. When `seed`
    is None a seed is drawn from the Numpy RNG before the call.
    '''
    op_seed = np.random.randint(10e6) if seed is None else seed
    return tf.random_uniform(
        shape, minval=low, maxval=high, dtype=dtype, seed=op_seed)
def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
    '''Returns a tensor of ones and zeros selected by a random threshold test.

    A uniform sample is drawn for every entry with `tf.random_uniform`;
    entries whose sample is `<= p` become one, the rest become zero.

    shape: shape of the tensor to create.
    p: threshold compared against each uniform sample.
    dtype: dtype of the uniform samples and of the returned tensor.
    seed: seed for the sampler; when None one is drawn from the Numpy RNG.
    '''
    if seed is None:
        seed = np.random.randint(10e6)
    samples = tf.random_uniform(shape, dtype=dtype, seed=seed)
    # Build both branches with the requested dtype so the result honours
    # `dtype` instead of falling back to the default dtype of ones/zeros.
    return tf.select(samples <= p,
                     tf.ones(shape, dtype=dtype),
                     tf.zeros(shape, dtype=dtype))
def random_uniform_variable(shape, low=-0.05, high=0.05, dtype=_FLOATX,
                            name=None, seed=None):
    '''Creates a `tf.Variable` initialised from `tf.random_uniform_initializer`.

    Every entry of `shape` is converted to int first. `low` and `high` are
    passed positionally to the initializer. When `seed` is None a seed is
    drawn from the Numpy RNG.
    '''
    int_shape = tuple(int(dim) for dim in shape)
    if seed is None:
        # Draw the seed from Numpy so the randomness is conditioned by
        # the Numpy RNG.
        seed = np.random.randint(10e8)
    initializer = tf.random_uniform_initializer(
        low, high, dtype=dtype, seed=seed)
    initial_value = initializer(int_shape)
    return tf.Variable(initial_value, dtype=dtype, name=name)