Skip to content

Commit

Permalink
align minibatches (pymc-devs#2760)
Browse files Browse the repository at this point in the history
* align minibatches

* add simple test

* align specific minibatches
  • Loading branch information
ferrine authored and jordan-melendez committed Feb 6, 2018
1 parent eb8edd7 commit 1664574
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
40 changes: 30 additions & 10 deletions pymc3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io
import os
import pkgutil

import collections
import numpy as np
import pymc3 as pm
import theano.tensor as tt
Expand All @@ -11,7 +11,8 @@
__all__ = [
'get_data',
'GeneratorAdapter',
'Minibatch'
'Minibatch',
'align_minibatches'
]


Expand Down Expand Up @@ -221,6 +222,9 @@ class Minibatch(tt.TensorVariable):
>>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
"""

RNG = collections.defaultdict(list)

@theano.configparser.change_flags(compute_test_value='raise')
def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
random_seed=42, update_shared_f=None, in_memory_size=None):
Expand All @@ -244,17 +248,21 @@ def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='M
inputs=[self.minibatch], outputs=[self])
self.tag.test_value = copy(self.minibatch.tag.test_value)

@staticmethod
def rslice(total, size, seed):
def rslice(self, total, size, seed):
if size is None:
return slice(None)
elif isinstance(size, int):
return (pm.tt_rng(seed)
rng = pm.tt_rng(seed)
Minibatch.RNG[id(self)].append(rng)
return (rng
.uniform(size=(size, ), low=0.0, high=pm.floatX(total) - 1e-16)
.astype('int64'))
else:
raise TypeError('Unrecognized size type, %r' % size)

def __del__(self):
    """Remove this instance's generators from the ``Minibatch.RNG`` registry.

    ``rslice`` only stores a generator under ``id(self)`` when it is asked
    for an integer-sized random slice, so an instance may have no entry
    at all (for example when every requested size is ``None``).
    """
    # ``RNG`` is a defaultdict, but the default factory is only used on
    # lookup: ``del`` on a missing key would still raise KeyError.  ``pop``
    # with a default keeps the finalizer silent in that case.
    Minibatch.RNG.pop(id(self), None)

@staticmethod
def make_static_slices(user_size):
if user_size is None:
Expand All @@ -278,12 +286,11 @@ def make_static_slices(user_size):
else:
raise TypeError('Unrecognized size type, %r' % user_size)

@classmethod
def make_random_slices(cls, in_memory_shape, batch_size, default_random_seed):
def make_random_slices(self, in_memory_shape, batch_size, default_random_seed):
if batch_size is None:
return [Ellipsis]
elif isinstance(batch_size, int):
slc = [cls.rslice(in_memory_shape[0], batch_size, default_random_seed)]
slc = [self.rslice(in_memory_shape[0], batch_size, default_random_seed)]
elif isinstance(batch_size, (list, tuple)):
def check(t):
if t is Ellipsis or t is None:
Expand Down Expand Up @@ -334,10 +341,10 @@ def check(t):
else:
shp_end = np.asarray([])
shp_begin = shape[:len(begin)]
slc_begin = [cls.rslice(shp_begin[i], t[0], t[1])
slc_begin = [self.rslice(shp_begin[i], t[0], t[1])
if t is not None else tt.arange(shp_begin[i])
for i, t in enumerate(begin)]
slc_end = [cls.rslice(shp_end[i], t[0], t[1])
slc_end = [self.rslice(shp_end[i], t[0], t[1])
if t is not None else tt.arange(shp_end[i])
for i, t in enumerate(end)]
slc = slc_begin + mid + slc_end
Expand All @@ -359,3 +366,16 @@ def clone(self):
ret.name = self.name
ret.tag = copy(self.tag)
return ret


def align_minibatches(batches=None):
    """Call ``seed()`` on the random generators behind minibatches.

    Every generator stored in ``Minibatch.RNG`` for the selected
    minibatches has its ``seed()`` method called with no arguments.
    NOTE(review): presumably this puts each generator back to the seed it
    was created with, so minibatches built with equal seeds draw the same
    indices again -- confirm against the generator's documentation.

    Parameters
    ----------
    batches : iterable of Minibatch, optional
        Minibatches whose generators should be reseeded.  When ``None``
        (the default) every generator in ``Minibatch.RNG`` is reseeded.

    Raises
    ------
    TypeError
        If any element of ``batches`` is not a ``Minibatch``.
    """
    if batches is None:
        for rngs in Minibatch.RNG.values():
            for rng in rngs:
                rng.seed()
        return

    # Materialize once so the input can be checked before anything is
    # reseeded; a bad element must not leave a half-aligned set behind.
    batches = list(batches)
    for b in batches:
        if not isinstance(b, Minibatch):
            # The placeholder has to be filled in explicitly; a plain
            # string literal would show the text "{b}" verbatim.
            raise TypeError('{b} is not a Minibatch'.format(b=b))
    for b in batches:
        for rng in Minibatch.RNG[id(b)]:
            rng.seed()
18 changes: 18 additions & 0 deletions pymc3/tests/test_minibatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,21 @@ def test_cloning_available(self):
res1 = theano.clone(res, {gop: shared})
f = theano.function([], res1)
assert f() == np.array([100])

def test_align(self):
    """Exercise ``pm.align_minibatches`` with and without an explicit list.

    Two minibatches are built with the same ``random_seed``.  One of them
    is evaluated on its own to put the pair out of step, then alignment is
    requested and the joint draws are compared.
    """
    first = pm.Minibatch(np.arange(1000), 1, random_seed=1)
    second = pm.Minibatch(np.arange(1000), 1, random_seed=1)
    draw_both = theano.function([], [first, second])

    def collect():
        # 1000 joint evaluations, regrouped into one sequence per minibatch.
        return zip(*(draw_both() for _ in range(1000)))

    # An extra evaluation of only one minibatch: the pair is not aligned.
    second.eval()
    from_first, from_second = collect()
    assert from_first != from_second

    # Aligning everything brings the two streams back together.
    pm.align_minibatches()
    from_first, from_second = collect()
    assert from_first == from_second

    # Desynchronize again; aligning just one of the two is not enough.
    second.eval()
    pm.align_minibatches([first])
    from_first, from_second = collect()
    assert from_first != from_second

    # Aligning both explicitly restores identical draws.
    pm.align_minibatches([first, second])
    from_first, from_second = collect()
    assert from_first == from_second

0 comments on commit 1664574

Please sign in to comment.