-
Notifications
You must be signed in to change notification settings - Fork 334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Implement nnabla.experimental.distributions #565
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@takuseno
Sorry for not getting back to you sooner.
Thank you again for your contribution!
I have reviewed your code. Please reflect my comments.
else: | ||
ref_sample = ref_sample_fn(*params, shape=(10000, 10)) | ||
|
||
assert np.allclose(sample.d.mean(), ref_sample.mean(), atol=3e-2, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use nnabla.testing.assert_allclose rather than numpy.allclose.
loc (~nnabla.Variable): N-D array of :math:`\mu` in definition. | ||
scale (~nnabla.Variable): N-D array of diagonal entries of :math:`L` | ||
such that covariance matrix :math:`\Sigma = L L^T`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these two args also accept numpy array, don't they?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If not, I think we should implement as numpy array could be used as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I expect it takes nn.Variable because equations in all methods of this class are differentiable. And, since the purpose of these classes is to easily build differentiable distributions, I don't think it is necessary to take numpy arrays, however it will be good for use. Then, I'll make numpy arrays acceptable in those classes.
loc (~nnabla.Variable): N-D array of :math:`\mu` in definition. | ||
scale (~nnabla.Variable): N-D array of :math:`\sigma` in definition. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same above.
low (~nnabla.Variable): N-D array of :math:`low` in definition. | ||
high (~nnabla.Variable): N-D arraya of :math:`high` in definition. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same above.
class MultivariateNormal(Distribution): | ||
"""Multivariate normal distribution. | ||
|
||
Multivariate normal distribution defined as follows: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now MultivariateNormal only supports a diagonal matrix as covariance, right?
If so, I think "Multivariate normal distribution with diagonal covariance matrix" or something like this is better.
|
||
def scipy_fn(loc, scale): | ||
return stats.multivariate_normal(np.reshape(loc, (-1,)), | ||
np.reshape(scale, (-1,))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we create diagonal matrix as covariance such as cov = scale @ scale.T (@ is matrix multiplication), we can delete following 3 functions and make distribution_test_util simpler, can't we?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure it is true but please try it.
2aa8449
to
d3eed32
Compare
I've noticed that current |
Squashed version of #392 .
cc. @TE-AkioHayakawa