-
Notifications
You must be signed in to change notification settings - Fork 334
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #395 from sony/feature/20190209_random_choice_func…
…tion Add random_choice function and tests.
- Loading branch information
Showing
6 changed files
with
459 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef NBLA_FUNCTION_RANDOM_CHOICE_HPP | ||
#define NBLA_FUNCTION_RANDOM_CHOICE_HPP | ||
|
||
#include <nbla/cpu.hpp> | ||
#include <nbla/function.hpp> | ||
#include <nbla/function_registry.hpp> | ||
#include <random> | ||
|
||
namespace nbla { | ||
|
||
NBLA_REGISTER_FUNCTION_HEADER(RandomChoice, const vector<int> &, bool, int); | ||
|
||
/** Generate random samples from population `x` with selection probabilities | ||
determined by the relative weights `w`. The number of samples to draw is | ||
given by the product of `shape`s dimensions, and the samples are returned | ||
with the given `shape`. By default, samples are drawn with replacement, | ||
i.e. selection of a specific population member is solely determined by | ||
its associated weight. Sampling without replacement, where any population | ||
member may be drawn only once, is used if `replace` is set to False. | ||
For both `x` and `w` the innermost dimension corresponds to the individual | ||
populations and their weights from which samples are returned with the | ||
requested `shape` following all outermost dimensions of the input. | ||
.. code-block:: python | ||
import nnabla as nn | ||
import nnabla.functions as F | ||
import numpy as np | ||
nn.set_auto_forward(True) | ||
# x holds two populations | ||
x = nn.Variable.from_numpy_array(np.array([[11, 22, 33], [110, 220, 330]])) | ||
# w holds the weights for each population | ||
w = nn.Variable.from_numpy_array(np.array([[10, 20, 70], [70, 20, 10]])) | ||
# draw one sample from each population | ||
y = F.random_choice(x, w) # y.shape => (2, 1) | ||
# draw 12 samples with shape (3, 4) from each population | ||
y = F.random_choice(x, w, shape=(3, 4)) # y.shape => (2, 3, 4) | ||
Note that weights must not be less than zero and for each population the | ||
sum of weights must be greater than zero. Additionally, sampling without | ||
replacement requires that the number of non-zero weights is not less than | ||
the number of samples to be drawn. These conditions are verified in "cpu" | ||
computation context but not when using "cuda" or "cudnn" acceleration | ||
(this would require additional device synchronization steps penalizing | ||
performance). | ||
Random sampling from an implicit array of index values (like categorical | ||
or multinomial) can be realized with input `x` constructed as indices. | ||
.. code-block:: python | ||
w = nn.Variable.from_numpy_array(np.array([1, 2, 3, 2, 1])) | ||
y = F.random_choice(F.arange(0, 5), w) | ||
Inputs: | ||
- x: N-D array from which a random sample is generated. | ||
- w: N-D array of associated weights of elements in `x`. | ||
Outputs: | ||
- N-D array | ||
@tparam T Data type for computation. | ||
@param shape Number and shape of generated samples. | ||
@param replace Whether sampling is with or without replacement. | ||
@param seed: Random seed. | ||
\ingroup FunctionImplGrp | ||
*/ | ||
template <typename T> | ||
class RandomChoice : public BaseFunction<const vector<int> &, bool, int> { | ||
protected: | ||
const vector<int> shape_; | ||
bool replace_; | ||
int seed_; | ||
std::mt19937 rgen_; | ||
Variable idxbuf_; // stores chosen indices for backward | ||
Size_t outer_loop_; // product of batch dimensions | ||
Size_t inner_loop_; // product of shape dimensions | ||
|
||
public: | ||
RandomChoice(const Context &ctx, const vector<int> &shape, bool replace, | ||
int seed) | ||
: BaseFunction(ctx, shape, replace, seed), shape_(shape), | ||
replace_(replace), seed_(seed) {} | ||
virtual ~RandomChoice() {} | ||
virtual shared_ptr<Function> copy() const { | ||
return create_RandomChoice(ctx_, shape_, replace_, seed_); | ||
} | ||
virtual int min_inputs() { return 2; } | ||
virtual int min_outputs() { return 1; } | ||
virtual vector<dtypes> in_types() { | ||
return vector<dtypes>{get_dtype<T>(), get_dtype<T>()}; | ||
} | ||
virtual vector<dtypes> out_types() { return vector<dtypes>{get_dtype<T>()}; } | ||
virtual vector<string> allowed_array_classes() { | ||
return SingletonManager::get<Cpu>()->array_classes(); | ||
} | ||
virtual string name() { return "RandomChoice"; } | ||
|
||
protected: | ||
NBLA_API virtual void setup_impl(const Variables &inputs, | ||
const Variables &outputs); | ||
NBLA_API virtual void forward_impl(const Variables &inputs, | ||
const Variables &outputs); | ||
NBLA_API virtual void backward_impl(const Variables &inputs, | ||
const Variables &outputs, | ||
const vector<bool> &propagate_down, | ||
const vector<bool> &accum); | ||
}; | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) 2017 Sony Corporation. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
import numpy as np | ||
import nnabla as nn | ||
import nnabla.functions as F | ||
from nbla_test_utils import list_context | ||
|
||
ctxs = list_context('RandomChoice') | ||
|
||
|
||
@pytest.mark.parametrize("ctx, func_name", ctxs) | ||
@pytest.mark.parametrize("seed", [-1, 313, 999]) | ||
def test_random_choice_with_replacement(ctx, func_name, seed): | ||
trials = 1000000 | ||
x = nn.Variable([100], need_grad=True) | ||
x.d = np.random.random(x.size).astype(np.float32) | ||
w = nn.Variable([x.size], need_grad=True) | ||
w.d = np.random.randint(1, 100, w.size) | ||
with nn.context_scope(ctx), nn.auto_forward(True): | ||
y = F.random_choice(x, w, shape=[trials], replace=True, seed=seed) | ||
hist_nn, _ = np.histogram(y.d) | ||
hist_np, _ = np.histogram(np.random.choice( | ||
x.d, trials, True, w.d / w.d.sum())) | ||
assert np.allclose(hist_nn / trials, hist_np / trials, atol=1e-2) | ||
x.g = w.g = 0 | ||
y.backward() | ||
assert np.allclose(x.g / trials, w.d / w.d.sum(), atol=1e-2) | ||
assert np.allclose(w.g / trials, w.d / w.d.sum(), atol=1e-2) | ||
|
||
x = nn.Variable.from_numpy_array(np.array([[1, 2, 3], [-1, -2, -3]])) | ||
w = nn.Variable.from_numpy_array(np.array([[1, 1, 1], [10, 10, 10]])) | ||
with nn.context_scope(ctx), nn.auto_forward(): | ||
y = F.random_choice(x, w, shape=(10,), replace=True, seed=seed) | ||
assert y.shape == (2, 10) and np.all(y.d[0] > 0) and np.all(y.d[1] < 0) | ||
|
||
return | ||
x = nn.Variable((3, 3), need_grad=True) | ||
w = nn.Variable((3, 3), need_grad=True) | ||
w.d = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) | ||
with nn.context_scope(ctx), nn.auto_forward(True): | ||
y = F.random_choice(x, w, shape=[10], replace=True, seed=seed) | ||
x.g = w.g = 0 | ||
y.backward(1) | ||
assert np.all(x.g == np.array([[10, 0, 0], [0, 10, 0], [0, 0, 10]])) | ||
assert np.all(w.g == np.array([[10, 0, 0], [0, 10, 0], [0, 0, 10]])) | ||
|
||
|
||
@pytest.mark.parametrize("ctx, func_name", ctxs) | ||
@pytest.mark.parametrize("seed", [-1, 313, 999]) | ||
def test_random_choice_without_replacement(ctx, func_name, seed): | ||
x = nn.Variable.from_numpy_array(np.array([0, 1, 2]).astype(np.int32)) | ||
w = nn.Variable.from_numpy_array(np.array([5, 5, 90]).astype(np.int32)) | ||
x.need_grad = True | ||
w.need_grad = True | ||
repeats = 1000 | ||
with nn.context_scope(ctx): | ||
y = F.random_choice(x, w, shape=[w.size], replace=False, seed=seed) | ||
r = np.zeros((repeats, w.size)).astype(np.int32) | ||
for i in range(repeats): | ||
y.forward() | ||
r[i] = y.d | ||
assert np.all(np.bincount(r.flatten()) == x.size * [repeats]) |
Oops, something went wrong.