Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon No.8】 add gumbel distribution api #46255

Merged
merged 69 commits into from
Oct 17, 2022
Merged
Changes from 1 commit
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
d4263c2
init gumbel api
PureNatural Sep 18, 2022
a29f111
commit: update test file
dasenCoding Sep 18, 2022
f5d62e1
fix:bug
PureNatural Sep 19, 2022
0b8faec
update Gumbel API
dasenCoding Sep 28, 2022
e6a8c1b
upgrade distribution/gumbel.py
dasenCoding Oct 4, 2022
2791493
add tests/test_distribution_gumbel.py
dasenCoding Oct 4, 2022
1541ecb
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 4, 2022
ef3dc50
fix:code style
PureNatural Oct 4, 2022
ddcd86a
fix:code style
PureNatural Oct 4, 2022
4e40718
fix:code style
PureNatural Oct 4, 2022
fff33ad
fix:code style
PureNatural Oct 4, 2022
517d053
fix bug
dasenCoding Oct 5, 2022
cedc871
fix:code style
dasenCoding Oct 5, 2022
72cd09b
fix:code style
PureNatural Oct 5, 2022
8e5bdc4
fix:rollback uniform
PureNatural Oct 5, 2022
cc3f783
fix:delete invalid code
PureNatural Oct 5, 2022
b6416cb
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 9, 2022
7f603a6
fix:bug and add static test
PureNatural Oct 9, 2022
6a06603
fix:code style
PureNatural Oct 9, 2022
06b9dc4
fix:code style
PureNatural Oct 9, 2022
d83d484
fix:delete init transforms
PureNatural Oct 9, 2022
db490e3
fix:bug
PureNatural Oct 9, 2022
381d059
fix:bug
PureNatural Oct 9, 2022
931f572
fix:code style
PureNatural Oct 9, 2022
b95dc13
fix:code style
PureNatural Oct 9, 2022
78b1b5b
fix:add transforms
PureNatural Oct 9, 2022
67047a2
fix:code style
PureNatural Oct 9, 2022
554a813
fix:code style
PureNatural Oct 9, 2022
c786d25
fix:bug
PureNatural Oct 9, 2022
c398fe3
fix:bug
PureNatural Oct 9, 2022
c713d81
fix:code style
PureNatural Oct 9, 2022
983a3f8
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 9, 2022
aea7df7
fix:code style
PureNatural Oct 9, 2022
33c780b
fix:bug
PureNatural Oct 9, 2022
da166ee
fix:code style
PureNatural Oct 9, 2022
6891ad8
fix:code style
PureNatural Oct 9, 2022
e10fd27
fix:bug for gumbel.py / add:judge transforms'len for transformed_dist…
dasenCoding Oct 10, 2022
8d5a83c
Merge branch 'gumbel_api' of https://github.com/PureNatural/Paddle in…
dasenCoding Oct 10, 2022
4e328be
update gumbel.py
dasenCoding Oct 11, 2022
9d89aac
fix:bug for test_distribution_gumbel.py
dasenCoding Oct 11, 2022
a0c357d
fix:bug for test_distribution_gumbel_static.py
dasenCoding Oct 11, 2022
db1cbfd
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
c6a2292
fix:code style
PureNatural Oct 11, 2022
c735592
fix:code style
PureNatural Oct 11, 2022
38530db
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
fc57abe
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
4bab5d1
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
98f8ed6
fix:coverage
PureNatural Oct 11, 2022
33a83fc
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 12, 2022
f2fa6dc
fix:bug
PureNatural Oct 12, 2022
a20a723
fix:bug
PureNatural Oct 12, 2022
0289b74
fix:code style
PureNatural Oct 12, 2022
fb972c3
fix:bug
PureNatural Oct 12, 2022
2f017a0
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 12, 2022
0e2892d
delete:no use package for gumbel.py
dasenCoding Oct 12, 2022
3261480
add:coverage transforms'len judge for test_distribution_gumbel.py
dasenCoding Oct 12, 2022
1ecfcc6
fix:code style for test_distribution_gumbel.py
dasenCoding Oct 12, 2022
6a9245e
fix:coverage
PureNatural Oct 12, 2022
e593fb4
fix:code style
PureNatural Oct 12, 2022
8c57748
fix:code style
PureNatural Oct 12, 2022
f7a0c36
fix:code style
PureNatural Oct 12, 2022
444454e
fix:code style
PureNatural Oct 12, 2022
3598ed5
fix:code style
PureNatural Oct 12, 2022
017f66c
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 14, 2022
e7108f0
fix:en doc
PureNatural Oct 14, 2022
069cadf
Merge branch 'gumbel_api' of github.com:PureNatural/Paddle into gumbe…
PureNatural Oct 14, 2022
e93ff40
fix:param
PureNatural Oct 14, 2022
6172d98
fix:copyright
PureNatural Oct 16, 2022
c957ab4
fixSample; test=document_fix
dasenCoding Oct 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix:code style
PureNatural committed Oct 12, 2022
commit 0289b742c480739000c8a296b3ba1fbda4f8b3e7
5 changes: 3 additions & 2 deletions python/paddle/distribution/gumbel.py
Original file line number Diff line number Diff line change
@@ -231,8 +231,9 @@ def rsample(self, shape):
"""
exp_trans = paddle.distribution.ExpTransform()
affine_trans_1 = paddle.distribution.AffineTransform(
paddle.full(shape=self.scale.shape, fill_value=0, dtype=self.loc.dtype),
-paddle.ones_like(self.scale))
paddle.full(shape=self.scale.shape,
fill_value=0,
dtype=self.loc.dtype), -paddle.ones_like(self.scale))
affine_trans_2 = paddle.distribution.AffineTransform(
self.loc, -self.scale)

9 changes: 6 additions & 3 deletions python/paddle/distribution/transformed_distribution.py
Original file line number Diff line number Diff line change
@@ -61,7 +61,8 @@ def __init__(self, base, transforms):
raise TypeError("All element of transforms must be Transform type.")

if not transforms:
super(TransformedDistribution, self).__init__(base.batch_shape, base.event_shape)
super(TransformedDistribution,
self).__init__(base.batch_shape, base.event_shape)
else:
chain = transform.ChainTransform(transforms)
base_shape = base.batch_shape + base.event_shape
@@ -80,8 +81,10 @@ def __init__(self, base, transforms):
transformed_event_rank = chain._codomain.event_rank + \
max(len(base.event_shape) - chain._domain.event_rank, 0)
super(TransformedDistribution, self).__init__(
transformed_shape[:len(transformed_shape) - transformed_event_rank],
transformed_shape[len(transformed_shape) - transformed_event_rank:])
transformed_shape[:len(transformed_shape) -
transformed_event_rank],
transformed_shape[len(transformed_shape) -
transformed_event_rank:])

def sample(self, shape=()):
"""Sample from ``TransformedDistribution``.