```python
import aesara.tensor as at
from aeppl import joint_logprob

# Using a multivariate to force our Dimshuffle rewrite
x = at.random.dirichlet([1, 1, 1])[None, ...]
y = at.random.dirichlet([1, 1, 1], size=(9,))
z = at.concatenate([x, y], axis=0)
z_vv = z.clone()
joint_logprob({z: z_vv})  # ValueError: Cannot drop a non-broadcastable dimension: [False, False], [1]
```
The logp of the `concatenate` splits the value `z_vv` in two, and the first piece is not statically inferred to be broadcastable, even though its leading dimension has length 1 at runtime. We could fix this by adding a `specify_shape` of 1 for the dimensions we are dropping in the `MeasurableDimshuffle` logp here: `aeppl/aeppl/tensor.py`, line 211 in 0f06fe9.
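To make the failure and the proposed fix concrete, here is a pure-Python toy model. It is not Aesara code: `ToyTensorType`, `toy_dimshuffle` and `toy_specify_shape` are hypothetical names, and the model assumes, as a simplification, that a dimension counts as broadcastable exactly when its static length is known to be 1. It only models the graph-construction check; the runtime assertion that a real `specify_shape` would add is left out.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class ToyTensorType:
    """Hypothetical stand-in for a symbolic tensor type.

    Each entry of `shape` is the static length of that dimension, or None
    when it is unknown at graph-construction time.
    """

    shape: Tuple[Optional[int], ...]

    @property
    def broadcastable(self) -> Tuple[bool, ...]:
        # Simplification: broadcastable means "statically known to be 1".
        return tuple(s == 1 for s in self.shape)


def toy_dimshuffle(t: ToyTensorType, new_order: Sequence[int]) -> ToyTensorType:
    """Keep the dimensions listed in `new_order`; every other one is dropped."""
    for i, b in enumerate(t.broadcastable):
        if i not in new_order and not b:
            # Only dimensions known to have length 1 may be dropped.
            raise ValueError(
                "Cannot drop a non-broadcastable dimension: "
                f"{list(t.broadcastable)}, {list(new_order)}"
            )
    return ToyTensorType(tuple(t.shape[i] for i in new_order))


def toy_specify_shape(t: ToyTensorType, shape: Sequence[Optional[int]]) -> ToyTensorType:
    """Refine unknown static lengths; None leaves a dimension unconstrained."""
    refined = []
    for current, wanted in zip(t.shape, shape):
        if wanted is None:
            refined.append(current)
        elif current is not None and current != wanted:
            raise ValueError(f"static length {current} conflicts with {wanted}")
        else:
            refined.append(wanted)
    return ToyTensorType(tuple(refined))


# The first piece obtained by splitting z_vv: its leading length is 1 at
# runtime, but nothing about it is known statically.
x_value = ToyTensorType((None, None))

try:
    toy_dimshuffle(x_value, [1])  # undo x's [None, ...] by dropping dimension 0
except ValueError as err:
    print(err)  # Cannot drop a non-broadcastable dimension: [False, False], [1]

# Asserting a length of 1 for the dimension being dropped makes it droppable.
fixed = toy_specify_shape(x_value, (1, None))
print(toy_dimshuffle(fixed, [1]))  # ToyTensorType(shape=(None,))
```

In the toy, the only change needed for the inverse `Dimshuffle` to succeed is the extra shape assertion on the dropped dimension, which is the role the proposed `specify_shape` would play in the logp.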
More generally, is there a reason why we don't always add a `specify_shape` when dropping dimensions in Aesara via `Dimshuffle`, instead of raising that error?
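The question is essentially a trade-off between rejecting the graph up front and deferring the length check to runtime. The eager NumPy function below is a hypothetical illustration of the second option, not Aesara's implementation: `dimshuffle_checked_at_runtime` is an invented name, and it only supports reordering and dropping axes (no new broadcastable axes).

```python
import numpy as np


def dimshuffle_checked_at_runtime(value: np.ndarray, new_order: list) -> np.ndarray:
    """Hypothetical eager analogue of a Dimshuffle that, instead of refusing
    to drop a dimension not known to be broadcastable, checks its length when
    the value arrives (what an inserted specify_shape of 1 would enforce).

    `new_order` lists the input axes to keep, in their output order.
    """
    dropped = [i for i in range(value.ndim) if i not in new_order]
    for i in dropped:
        if value.shape[i] != 1:
            raise ValueError(
                f"cannot drop dimension {i}: runtime length is {value.shape[i]}, not 1"
            )
    # Remove the length-1 axes; the surviving axes keep their relative order.
    squeezed = np.squeeze(value, axis=tuple(dropped)) if dropped else value
    survivors = [i for i in range(value.ndim) if i not in dropped]
    # Permute the survivors into the requested output order.
    return np.transpose(squeezed, [survivors.index(i) for i in new_order])


print(dimshuffle_checked_at_runtime(np.ones((1, 3)), [1]).shape)  # (3,)

try:
    dimshuffle_checked_at_runtime(np.ones((2, 3)), [1])
except ValueError as err:
    print(err)  # cannot drop dimension 0: runtime length is 2, not 1
```

The cost of this design is visible in the second call: a wrong length is only reported when a mismatched value actually reaches the operation, rather than when the graph is built.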
ricardoV94 changed the title from "Dimshuffle rewrite fails when information about value broadcastable dimension is lost" to "MeasurableDimshuffle logp fails when value broadcastable information is lost" on Oct 8, 2022.