This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
Add contrib.rand_zipfian #9747
Merged
Changes from 17 commits
Commits
18 commits
65fcf2f draft (ZiyueHuang)
4ed3ba4 move to contrib (ZiyueHuang)
0792162 rename op (ZiyueHuang)
6162c18 CR comments (ZiyueHuang)
105c212 Update contrib.py (eric-haibin-lin)
6be919c Update contrib.py (eric-haibin-lin)
136defb Update random.py (eric-haibin-lin)
4d128a7 update example in the doc (eric-haibin-lin)
436543b update example in symbol doc (eric-haibin-lin)
1cee16f CR comments (ZiyueHuang)
2533576 Merge branch 'master' into log-uniform (eric-haibin-lin)
a765d2f update op name (eric-haibin-lin)
c93af4a update op name (eric-haibin-lin)
c53c2a9 update op name in test (eric-haibin-lin)
c17c215 Merge remote-tracking branch 'upstream/master' into log-uniform
add866d update test
c48aebe Update contrib.py (eric-haibin-lin)
90d684d Merge remote-tracking branch 'upstream/master' into log-uniform (ZiyueHuang)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,81 @@ | |
# coding: utf-8 | ||
# pylint: disable=wildcard-import, unused-wildcard-import | ||
"""Contrib NDArray API of MXNet.""" | ||
import math | ||
from ..context import current_context | ||
from ..random import uniform | ||
try: | ||
from .gen_contrib import * | ||
except ImportError: | ||
pass | ||
|
||
__all__ = [] | ||
__all__ = ["rand_zipfian"] | ||
|
||
# pylint: disable=line-too-long | ||
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): | ||
"""Draw random samples from an approximately log-uniform or Zipfian distribution. | ||
This operation randomly samples *num_sampled* candidates from the range of integers [0, range_max). | ||
The elements of sampled_candidates are drawn with replacement from the base distribution. | ||
The base distribution for this operator is an approximately log-uniform or Zipfian distribution: | ||
P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1) | ||
This sampler is useful when the true classes approximately follow such a distribution. | ||
For example, if the classes represent words in a lexicon sorted in decreasing order of \ | ||
frequency. If your classes are not ordered by decreasing frequency, do not use this op. | ||
Additionally, it returns the number of times each of the \ | ||
true classes and the sampled classes is expected to occur. | ||
Parameters | ||
---------- | ||
true_classes : NDArray | ||
A 1-D NDArray of the target classes. | ||
num_sampled: int | ||
The number of classes to randomly sample. | ||
range_max: int | ||
The number of possible classes. | ||
ctx : Context | ||
Device context of output. Default is current context. | ||
Returns | ||
------- | ||
samples: NDArray | ||
The sampled candidate classes in 1-D `int64` dtype. | ||
expected_count_true: NDArray | ||
The expected count for true classes in 1-D `float64` dtype. | ||
expected_count_sample: NDArray | ||
The expected count for sampled candidates in 1-D `float64` dtype. | ||
Examples | ||
-------- | ||
>>> true_cls = mx.nd.array([3]) | ||
>>> samples, exp_count_true, exp_count_sample = mx.nd.contrib.rand_zipfian(true_cls, 4, 5) | ||
>>> samples | ||
[1 3 3 3] | ||
<NDArray 4 @cpu(0)> | ||
>>> exp_count_true | ||
[ 0.12453879] | ||
<NDArray 1 @cpu(0)> | ||
>>> exp_count_sample | ||
[ 0.22629439 0.12453879 0.12453879 0.12453879] | ||
<NDArray 4 @cpu(0)> | ||
""" | ||
if ctx is None: | ||
ctx = current_context() | ||
log_range = math.log(range_max + 1) | ||
rand = uniform(0, log_range, shape=(num_sampled,), dtype='float64', ctx=ctx) | ||
# make sure sampled_classes are in the range of [0, range_max) | ||
sampled_classes = (rand.exp() - 1).astype('int64') % range_max | ||
|
||
true_cls = true_classes.as_in_context(ctx).astype('float64') | ||
expected_count_true = ((true_cls + 2.0) / (true_cls + 1.0)).log() / log_range * num_sampled | ||
# cast sampled classes to fp64 to avoid integer division | ||
sampled_cls_fp64 = sampled_classes.astype('float64') | ||
Review comment: why is the output always float64? Reply: 64-bit is adopted because this sampler is usually used for an extremely large number of classes. Returned samples are always in int64. The fp64 here is used to calculate the probability of a particular class. (Limited precision of fp32 treat |
||
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range | ||
expected_count_sampled = expected_prob_sampled * num_sampled | ||
return sampled_classes, expected_count_true, expected_count_sampled | ||
# pylint: enable=line-too-long |
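The sampling step in the diff above is inverse-transform sampling: a uniform draw on [0, log(range_max + 1)) is exponentiated, shifted by 1, and floored. A minimal pure-Python sketch of the same idea follows, assuming only the standard library; the names `zipfian_prob` and `sample_zipfian` are hypothetical helpers for illustration and are not part of MXNet.

```python
import math
import random


def zipfian_prob(cls, range_max):
    """P(class) as stated in the docstring of the diff."""
    return (math.log(cls + 2) - math.log(cls + 1)) / math.log(range_max + 1)


def sample_zipfian(num_sampled, range_max, rng=random):
    """Draw num_sampled class ids in [0, range_max), with replacement."""
    log_range = math.log(range_max + 1)
    samples = []
    for _ in range(num_sampled):
        u = rng.uniform(0.0, log_range)            # uniform in log space
        # exp(u) - 1 lies in [0, range_max); int() floors it to a class id.
        # The modulo mirrors the guard in the diff against edge cases.
        cls = int(math.exp(u) - 1.0) % range_max
        samples.append(cls)
    return samples


rng = random.Random(0)                             # fixed seed, illustration only
samples = sample_zipfian(1000, 5, rng)
prob_total = sum(zipfian_prob(c, 5) for c in range(5))
```

Because the log terms telescope, `prob_total` comes out to 1 up to floating-point rounding, which is why no explicit normalisation appears in the operator.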
The example output looks suspicious as it does not sum up to 1.
Sorry I've misunderstood the term. It should be correct.
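To make the distinction behind this exchange concrete, here is a small sketch, assuming the formula from the diff: per-draw probabilities over all classes sum to 1, while expected counts sum to `num_sampled`. The helper name `expected_counts` is hypothetical.

```python
import math


def expected_counts(classes, num_sampled, range_max):
    """Expected occurrences of each class over num_sampled draws."""
    log_range = math.log(range_max + 1)
    return [math.log((c + 2.0) / (c + 1.0)) / log_range * num_sampled
            for c in classes]


range_max, num_sampled = 5, 4
probs = expected_counts(range(range_max), 1, range_max)             # per-draw P(class)
counts = expected_counts(range(range_max), num_sampled, range_max)  # scaled by num_sampled
total_prob = sum(probs)      # sums to 1 over all classes
total_count = sum(counts)    # sums to num_sampled over all classes
```

Under this formula, the values printed in the docstring example (0.22629439 for class 1 and 0.12453879 for class 3) agree with `probs[1]` and `probs[3]`, that is, with the per-draw probability for `range_max = 5`. Multiplying by `num_sampled = 4`, as the code in this revision does, would give values four times larger. In either case the returned array covers only the sampled classes, so it has no reason to sum to 1.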
I feel it's suspicious at first glance because the exp_count of 1 is larger than the exp_count of 3. However, the sampling result shows that 3 occurs much more often than 1. We need to sample multiple times and test whether the empirical expectation matches the true expectation.
It's just a coincidence for the first 5 samples. If I sample 50 times, it returns:
0's = 19
1's = 12
2's = 8
3's = 7
4's = 4
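For comparison, the expected counts for 50 draws with `range_max = 5` under the docstring formula are roughly 19.3, 11.3, 8.0, 6.2 and 5.1, which is close to the tallies above. The sketch below computes those expectations and runs one simulated batch in pure Python. It assumes the same inverse-transform step as the diff; the seed is arbitrary, so the observed tallies will differ from the ones reported in this thread.

```python
import math
import random
from collections import Counter

range_max, num_draws = 5, 50
log_range = math.log(range_max + 1)

# Expected number of occurrences of each class over num_draws draws.
expected = [num_draws * math.log((c + 2.0) / (c + 1.0)) / log_range
            for c in range(range_max)]

# One simulated batch of draws (arbitrary seed, illustration only).
rng = random.Random(42)
draws = [int(math.exp(rng.uniform(0.0, log_range)) - 1.0) % range_max
         for _ in range(num_draws)]
observed = Counter(draws)

for c in range(range_max):
    print(f"class {c}: expected {expected[c]:5.2f}, observed {observed[c]}")
```

A single batch of 50 draws is noisy. Averaging the tallies over many batches is the more reliable way to check that the empirical counts converge to the expected ones.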
OK, looks good