Skip to content

Commit

Permalink
Add contrib.rand_zipfian (apache#9747)
Browse files Browse the repository at this point in the history
* draft

* move to contrib

* rename op

* CR comments

* Update contrib.py

* Update contrib.py

* Update random.py

* update example in the doc

* update example in symbol doc

* CR comments

* update op name

* update op name

* update op name in test

* update test

* Update contrib.py
  • Loading branch information
eric-haibin-lin authored and dabraude committed Feb 24, 2018
1 parent 2b4f55b commit a6b345a
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 2 deletions.
74 changes: 73 additions & 1 deletion python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,81 @@
# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib NDArray API of MXNet."""
import math
from ..context import current_context
from ..random import uniform
try:
from .gen_contrib import *
except ImportError:
pass

__all__ = []
__all__ = ["rand_zipfian"]

# pylint: disable=line-too-long
def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
"""Draw random samples from an approximately log-uniform or Zipfian distribution.
This operation randomly samples *num_sampled* candidates the range of integers [0, range_max).
The elements of sampled_candidates are drawn with replacement from the base distribution.
The base distribution for this operator is an approximately log-uniform or Zipfian distribution:
P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
This sampler is useful when the true classes approximately follow such a distribution.
For example, if the classes represent words in a lexicon sorted in decreasing order of \
frequency. If your classes are not ordered by decreasing frequency, do not use this op.
Additionaly, it also returns the number of times each of the \
true classes and the sampled classes is expected to occur.
Parameters
----------
true_classes : NDArray
A 1-D NDArray of the target classes.
num_sampled: int
The number of classes to randomly sample.
range_max: int
The number of possible classes.
ctx : Context
Device context of output. Default is current context. Overridden by
`mu.context` when `mu` is an NDArray.
Returns
-------
samples: NDArray
The sampled candidate classes in 1-D `int64` dtype.
expected_count_true: NDArray
The expected count for true classes in 1-D `float64` dtype.
expected_count_sample: NDArray
The expected count for sampled candidates in 1-D `float64` dtype.
Examples
--------
>>> true_cls = mx.nd.array([3])
>>> samples, exp_count_true, exp_count_sample = mx.nd.contrib.rand_zipfian(true_cls, 4, 5)
>>> samples
[1 3 3 3]
<NDArray 4 @cpu(0)>
>>> exp_count_true
[ 0.12453879]
<NDArray 1 @cpu(0)>
>>> exp_count_sample
[ 0.22629439 0.12453879 0.12453879 0.12453879]
<NDArray 4 @cpu(0)>
"""
if ctx is None:
ctx = current_context()
log_range = math.log(range_max + 1)
rand = uniform(0, log_range, shape=(num_sampled,), dtype='float64', ctx=ctx)
# make sure sampled_classes are in the range of [0, range_max)
sampled_classes = (rand.exp() - 1).astype('int64') % range_max

true_cls = true_classes.as_in_context(ctx).astype('float64')
expected_count_true = ((true_cls + 2.0) / (true_cls + 1.0)).log() / log_range * num_sampled
# cast sampled classes to fp64 to avoid interget division
sampled_cls_fp64 = sampled_classes.astype('float64')
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long
69 changes: 68 additions & 1 deletion python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,76 @@
# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib Symbol API of MXNet."""
import math
from .random import uniform
from .symbol import Symbol
try:
from .gen_contrib import *
except ImportError:
pass

__all__ = []
__all__ = ["rand_zipfian"]

def rand_zipfian(true_classes, num_sampled, range_max):
"""Draw random samples from an approximately log-uniform or Zipfian distribution.
This operation randomly samples *num_sampled* candidates the range of integers [0, range_max).
The elements of sampled_candidates are drawn with replacement from the base distribution.
The base distribution for this operator is an approximately log-uniform or Zipfian distribution:
P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
This sampler is useful when the true classes approximately follow such a distribution.
For example, if the classes represent words in a lexicon sorted in decreasing order of \
frequency. If your classes are not ordered by decreasing frequency, do not use this op.
Additionaly, it also returns the number of times each of the \
true classes and the sampled classes is expected to occur.
Parameters
----------
true_classes : Symbol
The target classes in 1-D.
num_sampled: int
The number of classes to randomly sample.
range_max: int
The number of possible classes.
Returns
-------
samples: Symbol
The sampled candidate classes in 1-D `int64` dtype.
expected_count_true: Symbol
The expected count for true classes in 1-D `float64` dtype.
expected_count_sample: Symbol
The expected count for sampled candidates in 1-D `float64` dtype.
Examples
--------
>>> true_cls = mx.nd.array([3])
>>> samples, exp_count_true, exp_count_sample = mx.nd.contrib.rand_zipfian(true_cls, 4, 5)
>>> samples
[1 3 3 3]
<NDArray 4 @cpu(0)>
>>> exp_count_true
[ 0.12453879]
<NDArray 1 @cpu(0)>
>>> exp_count_sample
[ 0.22629439 0.12453879 0.12453879 0.12453879]
<NDArray 4 @cpu(0)>
"""
assert(isinstance(true_classes, Symbol)), "unexpected type %s" % type(true_classes)
log_range = math.log(range_max + 1)
rand = uniform(0, log_range, shape=(num_sampled,), dtype='float64')
# make sure sampled_classes are in the range of [0, range_max)
sampled_classes = (rand.exp() - 1).astype('int64') % range_max

true_classes = true_classes.astype('float64')
expected_prob_true = ((true_classes + 2.0) / (true_classes + 1.0)).log() / log_range
expected_count_true = expected_prob_true * num_sampled
# cast sampled classes to fp64 to avoid interget division
sampled_cls_fp64 = sampled_classes.astype('float64')
expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
30 changes: 30 additions & 0 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,36 @@ def check_data(a, b):
for j in range(i+1, num_seeds):
check_data(data[i],data[j])

@with_seed()
def test_zipfian_generator():
# dummy true classes
num_true = 5
num_sampled = 1000
range_max = 20

def compute_expected_prob():
# P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
classes = mx.nd.arange(0, range_max)
expected_counts = ((classes + 2).log() - (classes + 1).log()) / np.log(range_max + 1)
return expected_counts

exp_cnt = compute_expected_prob() * num_sampled

# test ndarray
true_classes = mx.nd.random.uniform(0, range_max, shape=(num_true,)).astype('int32')
sampled_classes, exp_cnt_true, exp_cnt_sampled = mx.nd.contrib.rand_zipfian(true_classes, num_sampled, range_max)
mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)

# test symbol
true_classes_var = mx.sym.var('true_classes')
outputs = mx.sym.contrib.rand_zipfian(true_classes_var, num_sampled, range_max)
outputs = mx.sym.Group(outputs)
executor = outputs.bind(mx.context.current_context(), {'true_classes' : true_classes})
executor.forward()
sampled_classes, exp_cnt_true, exp_cnt_sampled = executor.outputs
mx.test_utils.assert_almost_equal(exp_cnt_sampled.asnumpy(), exp_cnt[sampled_classes].asnumpy(), rtol=1e-1, atol=1e-2)
mx.test_utils.assert_almost_equal(exp_cnt_true.asnumpy(), exp_cnt[true_classes].asnumpy(), rtol=1e-1, atol=1e-2)

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit a6b345a

Please sign in to comment.