Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor distribution.draw_values #2902

Merged
merged 12 commits into from
Mar 21, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from theano import function
import theano
from ..memoize import memoize
from ..model import Model, get_named_nodes, FreeRV, ObservedRV
from ..model import Model, get_named_nodes_and_relations, FreeRV, ObservedRV
from ..vartypes import string_types

__all__ = ['DensityDist', 'Distribution', 'Continuous', 'Discrete',
Expand Down Expand Up @@ -230,16 +230,65 @@ def draw_values(params, point=None):
"""
# Distribution parameters may be nodes which have named node-inputs
# specified in the point. Need to find the node-inputs to replace them.
givens = {}

# Issue #2900 describes a situation in which, the named node-inputs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This information is too detailed for a code comment here; it would be better to put it in the release notes.

# do not have a random method, while some intermediate node may have
# it. This means that if the named node-input at the leaf of the
# graph does not have a fixed value, theano will try to compile it
# and fail to find inputs, raising a theano.gof.fg.MissingInputError.
# To deal with this problem, we have to try the leaf nodes
# _draw_value, and try it for the parents of the leaf nodes which
# fail with a theano.gof.fg.MissingInputError. This will fill in the
# givens dictionary for the final _draw_value

# Init named nodes dictionary
leaf_nodes = {}
named_nodes_parents = {}
named_nodes_children = {}
for param in params:
if hasattr(param, 'name'):
named_nodes = get_named_nodes(param)
if param.name in named_nodes:
named_nodes.pop(param.name)
for name, node in named_nodes.items():
if not isinstance(node, (tt.sharedvar.SharedVariable,
tt.TensorConstant)):
givens[name] = (node, _draw_value(node, point=point))
# Get the named nodes under the `param` node
nn, nnp, nnc = get_named_nodes_and_relations(param)
leaf_nodes.update(nn)
# Update the discovered parental relationships
for k in nnp.keys():
if k not in named_nodes_parents.keys():
named_nodes_parents[k] = nnp[k]
else:
named_nodes_parents[k].update(nnp[k])
# Update the discovered child relationships
for k in nnc.keys():
if k not in named_nodes_children.keys():
named_nodes_children[k] = nnc[k]
else:
named_nodes_children[k].update(nnc[k])

# Init givens and the stack of nodes to try to `_draw_value` from
givens = {}
stack = list(leaf_nodes.values()) # A queue would be more appropriate
while stack:
next_ = stack.pop(0)
if next_ in givens.keys(): # If the node already has a givens value, skip it
continue
else:
# If the node does not have a givens value, try to draw it
# The named node's children givens values must also be taken
# into account
children = named_nodes_children[next_]
temp_givens = [givens[k] for k in givens.keys() if k in children]
if not temp_givens:
temp_givens = None
try:
# This may fail for autotransformed RVs, which don't
# have the random method
givens[next_.name] = (next_, _draw_value(next_, point=point, givens=temp_givens))
except theano.gof.fg.MissingInputError:
# The node failed, so we must add the node's parents to
# the stack of nodes to try to draw from. We exclude the
# nodes in the `params` list.
stack.extend([node for node in named_nodes_parents[next_]
if node is not None and node.name not in givens.keys()
and node not in params])
values = []
for param in params:
values.append(_draw_value(param, point=point, givens=givens.values()))
Expand Down
70 changes: 56 additions & 14 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,29 +78,71 @@ def incorporate_methods(source, destination, methods, default=None,
else:
setattr(destination, method, None)


def get_named_nodes(graph):
"""Get the named nodes in a theano graph
(i.e., nodes whose name attribute is not None).
def get_named_nodes_and_relations(graph):
    """Get the named nodes in a theano graph and their relationships.

    Named nodes are those whose ``name`` attribute is not None. The
    relationships are each node's named parents and named children;
    unnamed intermediate nodes are skipped when linking them.

    Parameters
    ----------
    graph : a theano node
        Root of the graph to walk.

    Returns
    -------
    leaf_nodes : dict
        name:node pairs of the named nodes that are also leaves of the
        graph (their ``owner`` is None).
    node_parents : dict
        node:set([parents]) pairs. Each key is a named node and the
        value is the set of named nodes that are its parents. These
        parental relations skip unnamed intermediate nodes.
    node_children : dict
        node:set([children]) pairs. Each key is a named node and the
        value is the set of named nodes that are its children. These
        child relations skip unnamed intermediate nodes.
    """
    # A named root has no parent to register it, so seed its entries
    # here; the recursive helper is started with ``parent=None``.
    root_is_named = graph.name is not None
    node_parents = {graph: set()} if root_is_named else {}
    node_children = {graph: set()} if root_is_named else {}
    return _get_named_nodes_and_relations(graph, None, {},
                                          node_parents, node_children)

def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
node_parents, node_children):
if graph.owner is None: # Leaf node
if graph.name is not None: # Named leaf node
leaf_nodes.update({graph.name: graph})
if parent is not None: # Is None for the root node
try:
node_parents[graph].add(parent)
except KeyError:
node_parents[graph] = set([parent])
node_children[parent].add(graph)
# Flag that the leaf node has no children
node_children[graph] = set()
else: # Intermediate node
if graph.name is not None: # Intermediate named node
if parent is not None: # Is only None for the root node
try:
node_parents[graph].add(parent)
except KeyError:
node_parents[graph] = set([parent])
node_children[parent].add(graph)
# The current node will be set as the parent of the next
# nodes only if it is a named node
parent = graph
# Init the nodes children to an empty set
node_children[graph] = set()
for i in graph.owner.inputs:
nodes.update(_get_named_nodes(i, nodes))
return nodes
temp_nodes, temp_inter, temp_tree = \
_get_named_nodes_and_relations(i, parent, leaf_nodes,
node_parents, node_children)
leaf_nodes.update(temp_nodes)
node_parents.update(temp_inter)
node_children.update(temp_tree)
return leaf_nodes, node_parents, node_children


class Context(object):
Expand Down
24 changes: 14 additions & 10 deletions pymc3/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,17 @@ def test_dep_vars(self):
point = {'a': np.array([1., 2.])}
npt.assert_equal(draw_values([a], point=point), [point['a']])

with pytest.raises(theano.gof.MissingInputError):
draw_values([a])

# We need the untransformed vars
with pytest.raises(theano.gof.MissingInputError):
draw_values([a], point={'sd': np.array([2., 3.])})

val1 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
val2 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
assert np.all(val1 != val2)
# After #2900 theano.gof.MissingInputError should not be raised
# with a plain draw_values
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need comments here.

val1 = draw_values([a])[0]

# After #2900 theano.gof.MissingInputError should not be raised
# even when using the untransformed var sd
val2 = draw_values([a], point={'sd': np.array([2., 3.])})[0]

val3 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]
val4 = draw_values([a], point={'sd_log__': np.array([2., 3.])})[0]

assert all([np.all(val1 != val2), np.all(val1 != val3),
np.all(val1 != val4), np.all(val2 != val3),
np.all(val2 != val4), np.all(val3 != val4)])