From 890586b9d3c792fd10818ef795ad12b00548d60c Mon Sep 17 00:00:00 2001 From: Stephen Tiedemann Date: Sat, 9 Feb 2019 16:31:26 +0100 Subject: [PATCH] Add random_choice function and tests. --- .../code_generator/function_types.yaml | 3 + build-tools/code_generator/functions.yaml | 73 ++++++++ doc/python/api/function.rst | 1 + include/nbla/function/random_choice.hpp | 134 ++++++++++++++ python/test/function/test_random_choice.py | 75 ++++++++ src/nbla/function/generic/random_choice.cpp | 173 ++++++++++++++++++ 6 files changed, 459 insertions(+) create mode 100644 include/nbla/function/random_choice.hpp create mode 100644 python/test/function/test_random_choice.py create mode 100644 src/nbla/function/generic/random_choice.cpp diff --git a/build-tools/code_generator/function_types.yaml b/build-tools/code_generator/function_types.yaml index fc02f99a0..f32cf6da3 100644 --- a/build-tools/code_generator/function_types.yaml +++ b/build-tools/code_generator/function_types.yaml @@ -315,6 +315,9 @@ Randint: Randn: float: [float] half: [Half] +RandomChoice: + float: [float] + half: [Half] RandomCrop: float: [float] half: [Half] diff --git a/build-tools/code_generator/functions.yaml b/build-tools/code_generator/functions.yaml index 42575873b..64489dc4f 100755 --- a/build-tools/code_generator/functions.yaml +++ b/build-tools/code_generator/functions.yaml @@ -3525,6 +3525,79 @@ Stochasticity: function_ids: ffiIi: 91 c_runtime: not support + RandomChoice: + snake_name: random_choice + doc: |2 + + Generate random samples from population `x` with selection probabilities + determined by the relative weights `w`. The number of samples to draw is + given by the product of `shape`s dimensions, and the samples are returned + with the given `shape`. By default, samples are drawn with replacement, + i.e. selection of a specific population member is solely determined by + its associated weight. 
Sampling without replacement, where any population + member may be drawn only once, is used if `replace` is set to False. + + For both `x` and `w` the innermost dimension corresponds to the individual + populations and their weights from which samples are returned with the + requested `shape` following all outermost dimensions of the input. + + .. code-block:: python + + import nnabla as nn + import nnabla.functions as F + import numpy as np + nn.set_auto_forward(True) + + # x holds two populations + x = nn.Variable.from_numpy_array(np.array([[11, 22, 33], [110, 220, 330]])) + # w holds the weights for each population + w = nn.Variable.from_numpy_array(np.array([[10, 20, 70], [70, 20, 10]])) + + # draw one sample from each population + y = F.random_choice(x, w) # y.shape => (2, 1) + + # draw 12 samples with shape (3, 4) from each population + y = F.random_choice(x, w, shape=(3, 4)) # y.shape => (2, 3, 4) + + Note that weights must not be less than zero and for each population the + sum of weights must be greater than zero. Additionally, sampling without + replacement requires that the number of non-zero weights is not less than + the number of samples to be drawn. These conditions are verified in "cpu" + computation context but not when using "cuda" or "cudnn" acceleration + (this would require additional device synchronization steps penalizing + performance). + + Random sampling from an implicit array of index values (like categorical + or multinomial) can be realized with input `x` constructed as indices. + + .. code-block:: python + + w = nn.Variable.from_numpy_array(np.array([1, 2, 3, 2, 1])) + y = F.random_choice(F.arange(0, 5), w) + inputs: + x: + doc: N-D array from which a random sample is generated. + w: + doc: N-D array of associated weights of elements in `x`. + arguments: + shape: + doc: Number and shape of generated samples. + type: Shape + default: '[]' + replace: + doc: Whether sampling is with or without replacement. 
+ type: bool + default: 'True' + seed: + doc: Random seed. + type: int64 + default: '-1' + outputs: + y: + doc: N-D array + c_runtime: not support + function_ids: + iIBi: 246 RandomCrop: snake_name: random_crop doc: |2 diff --git a/doc/python/api/function.rst b/doc/python/api/function.rst index 9c18482b0..becb96a76 100644 --- a/doc/python/api/function.rst +++ b/doc/python/api/function.rst @@ -191,6 +191,7 @@ Stochasticity .. autofunction:: dropout .. autofunction:: top_k_data .. autofunction:: top_k_grad +.. autofunction:: random_choice .. autofunction:: random_crop .. autofunction:: random_flip .. autofunction:: random_shift diff --git a/include/nbla/function/random_choice.hpp b/include/nbla/function/random_choice.hpp new file mode 100644 index 000000000..4fc3ea4a2 --- /dev/null +++ b/include/nbla/function/random_choice.hpp @@ -0,0 +1,134 @@ +// Copyright (c) 2017 Sony Corporation. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef NBLA_FUNCTION_RANDOM_CHOICE_HPP +#define NBLA_FUNCTION_RANDOM_CHOICE_HPP + +#include +#include +#include +#include + +namespace nbla { + +NBLA_REGISTER_FUNCTION_HEADER(RandomChoice, const vector &, bool, int); + +/** Generate random samples from population `x` with selection probabilities + determined by the relative weights `w`. The number of samples to draw is + given by the product of `shape`s dimensions, and the samples are returned + with the given `shape`. 
By default, samples are drawn with replacement, + i.e. selection of a specific population member is solely determined by + its associated weight. Sampling without replacement, where any population + member may be drawn only once, is used if `replace` is set to False. + + For both `x` and `w` the innermost dimension corresponds to the individual + populations and their weights from which samples are returned with the + requested `shape` following all outermost dimensions of the input. + + .. code-block:: python + + import nnabla as nn + import nnabla.functions as F + import numpy as np + nn.set_auto_forward(True) + + # x holds two populations + x = nn.Variable.from_numpy_array(np.array([[11, 22, 33], [110, 220, 330]])) + # w holds the weights for each population + w = nn.Variable.from_numpy_array(np.array([[10, 20, 70], [70, 20, 10]])) + + # draw one sample from each population + y = F.random_choice(x, w) # y.shape => (2, 1) + + # draw 12 samples with shape (3, 4) from each population + y = F.random_choice(x, w, shape=(3, 4)) # y.shape => (2, 3, 4) + + Note that weights must not be less than zero and for each population the + sum of weights must be greater than zero. Additionally, sampling without + replacement requires that the number of non-zero weights is not less than + the number of samples to be drawn. These conditions are verified in "cpu" + computation context but not when using "cuda" or "cudnn" acceleration + (this would require additional device synchronization steps penalizing + performance). + + Random sampling from an implicit array of index values (like categorical + or multinomial) can be realized with input `x` constructed as indices. + + .. code-block:: python + + w = nn.Variable.from_numpy_array(np.array([1, 2, 3, 2, 1])) + y = F.random_choice(F.arange(0, 5), w) + +Inputs: + +- x: N-D array from which a random sample is generated. +- w: N-D array of associated weights of elements in `x`. 
+
+Outputs:
+
+- N-D array
+
+@tparam T Data type for computation.
+
+@param shape Number and shape of generated samples.
+
+@param replace Whether sampling is with or without replacement.
+
+@param seed Random seed. A value of -1 seeds the generator from
+       std::random_device.
+
+\ingroup FunctionImplGrp
+ */
+// NOTE(review): the template argument lists in this class were lost when the
+// patch was extracted and have been restored from usage -- confirm them
+// against the repository.
+template <typename T>
+class RandomChoice : public BaseFunction<const vector<int> &, bool, int> {
+protected:
+  const vector<int> shape_; // requested sample shape (may be empty)
+  bool replace_;            // true: sample with replacement
+  int seed_;                // user seed; -1 requests a random seed
+  std::mt19937 rgen_;       // generator, seeded in setup_impl
+  Variable idxbuf_;         // stores chosen indices for backward
+  Size_t outer_loop_;       // product of batch dimensions
+  Size_t inner_loop_;       // product of shape dimensions
+
+public:
+  RandomChoice(const Context &ctx, const vector<int> &shape, bool replace,
+               int seed)
+      : BaseFunction(ctx, shape, replace, seed), shape_(shape),
+        replace_(replace), seed_(seed) {}
+  virtual ~RandomChoice() {}
+  // Creates a new function instance with the same context and arguments.
+  virtual shared_ptr<Function> copy() const {
+    return create_RandomChoice(ctx_, shape_, replace_, seed_);
+  }
+  // Two inputs are required: population `x` and weights `w`.
+  virtual int min_inputs() { return 2; }
+  virtual int min_outputs() { return 1; }
+  // Both inputs and the output use the computation type T.
+  virtual vector<dtypes> in_types() {
+    return vector<dtypes>{get_dtype<T>(), get_dtype<T>()};
+  }
+  virtual vector<dtypes> out_types() { return vector<dtypes>{get_dtype<T>()}; }
+  virtual vector<string> allowed_array_classes() {
+    return SingletonManager::get<Cpu>()->array_classes();
+  }
+  virtual string name() { return "RandomChoice"; }
+
+protected:
+  NBLA_API virtual void setup_impl(const Variables &inputs,
+                                   const Variables &outputs);
+  NBLA_API virtual void forward_impl(const Variables &inputs,
+                                     const Variables &outputs);
+  NBLA_API virtual void backward_impl(const Variables &inputs,
+                                      const Variables &outputs,
+                                      const vector<bool> &propagate_down,
+                                      const vector<bool> &accum);
+};
+}
+#endif
diff --git a/python/test/function/test_random_choice.py b/python/test/function/test_random_choice.py
new file mode 100644
index 000000000..c4bfcdee3
--- /dev/null
+++ b/python/test/function/test_random_choice.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import numpy as np
+import nnabla as nn
+import nnabla.functions as F
+from nbla_test_utils import list_context
+
+# Contexts in which a RandomChoice implementation is available.
+ctxs = list_context('RandomChoice')
+
+
+# Statistical test of sampling with replacement: forward distribution,
+# accumulated gradients, and per-population (batched) sampling.
+@pytest.mark.parametrize("ctx, func_name", ctxs)
+@pytest.mark.parametrize("seed", [-1, 313, 999])
+def test_random_choice_with_replacement(ctx, func_name, seed):
+    trials = 1000000
+    x = nn.Variable([100], need_grad=True)
+    x.d = np.random.random(x.size).astype(np.float32)
+    w = nn.Variable([x.size], need_grad=True)
+    w.d = np.random.randint(1, 100, w.size)
+    with nn.context_scope(ctx), nn.auto_forward(True):
+        y = F.random_choice(x, w, shape=[trials], replace=True, seed=seed)
+    # Compare the histogram of the drawn values with the histogram of a
+    # NumPy reference drawn with the same normalized weights.
+    hist_nn, _ = np.histogram(y.d)
+    hist_np, _ = np.histogram(np.random.choice(
+        x.d, trials, True, w.d / w.d.sum()))
+    assert np.allclose(hist_nn / trials, hist_np / trials, atol=1e-2)
+    # Gradients are zeroed first, so after backward the per-element gradient
+    # divided by the number of draws is compared with the normalized weights.
+    x.g = w.g = 0
+    y.backward()
+    assert np.allclose(x.g / trials, w.d / w.d.sum(), atol=1e-2)
+    assert np.allclose(w.g / trials, w.d / w.d.sum(), atol=1e-2)
+
+    # Two populations: the first holds only positive values, the second only
+    # negative ones, so each output row must keep the sign of its population.
+    x = nn.Variable.from_numpy_array(np.array([[1, 2, 3], [-1, -2, -3]]))
+    w = nn.Variable.from_numpy_array(np.array([[1, 1, 1], [10, 10, 10]]))
+    with nn.context_scope(ctx), nn.auto_forward():
+        y = F.random_choice(x, w, shape=(10,), replace=True, seed=seed)
+    assert y.shape == (2, 10) and np.all(y.d[0] > 0) and np.all(y.d[1] < 0)
+
+    # NOTE(review): this unconditional return makes the one-hot gradient
+    # check below unreachable -- confirm whether it was disabled on purpose.
+    return
+    x = nn.Variable((3, 3), need_grad=True)
+    w = nn.Variable((3, 3), need_grad=True)
+    w.d = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    with nn.context_scope(ctx), nn.auto_forward(True):
+        y = F.random_choice(x, w, shape=[10], replace=True, seed=seed)
+    x.g = w.g = 0
+    y.backward(1)
+    assert np.all(x.g == np.array([[10, 0, 0], [0, 10, 0], [0, 0, 10]]))
+    assert np.all(w.g == np.array([[10, 0, 0], [0, 10, 0], [0, 0, 10]]))
+
+
+# Sampling without replacement: with as many draws as population members,
+# every member must appear exactly once in each forward pass.
+@pytest.mark.parametrize("ctx, func_name", ctxs)
+@pytest.mark.parametrize("seed", [-1, 313, 999])
+def test_random_choice_without_replacement(ctx, func_name, seed):
+    x = nn.Variable.from_numpy_array(np.array([0, 1, 2]).astype(np.int32))
+    w = nn.Variable.from_numpy_array(np.array([5, 5, 90]).astype(np.int32))
+    x.need_grad = True
+    w.need_grad = True
+    repeats = 1000
+    with nn.context_scope(ctx):
+        y = F.random_choice(x, w, shape=[w.size], replace=False, seed=seed)
+    r = np.zeros((repeats, w.size)).astype(np.int32)
+    for i in range(repeats):
+        y.forward()
+        r[i] = y.d
+    # Over all repeats each of the values 0, 1, 2 is counted `repeats` times.
+    assert np.all(np.bincount(r.flatten()) == x.size * [repeats])
diff --git a/src/nbla/function/generic/random_choice.cpp b/src/nbla/function/generic/random_choice.cpp
new file mode 100644
index 000000000..9d1e80a43
--- /dev/null
+++ b/src/nbla/function/generic/random_choice.cpp
@@ -0,0 +1,173 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE(review): the include targets and the template argument lists in this
+// file were lost when the patch was extracted. They are reconstructed from
+// what the code uses (the element type `int` of idxbuf_ is an assumption) --
+// confirm them against the repository.
+#include <nbla/array.hpp>
+#include <nbla/common.hpp>
+#include <nbla/function/random_choice.hpp>
+#include <nbla/variable.hpp>
+#include <algorithm>
+#include <numeric>
+
+namespace nbla {
+
+NBLA_REGISTER_FUNCTION_SOURCE(RandomChoice, const vector<int> &, bool, int);
+
+/** Validate the inputs, derive the output shape and seed the generator.
+
+    The output shape is the input shape without its last (population) axis,
+    followed by `shape_`, or by a single 1 when `shape_` is empty. The index
+    buffer used by backward gets the same shape as the output.
+ */
+template <typename T>
+void RandomChoice<T>::setup_impl(const Variables &inputs,
+                                 const Variables &outputs) {
+  NBLA_CHECK(inputs[0]->shape() == inputs[1]->shape(), error_code::value,
+             "Dimensions of inputs must match. "
+             "inputs[0]: %s != inputs[1]: %s.",
+             string_join(inputs[0]->shape(), string(", ")).c_str(),
+             string_join(inputs[1]->shape(), string(", ")).c_str());
+
+  Shape_t ishape(inputs[0]->shape());
+  // The population axis is stripped below, so at least one axis must exist.
+  NBLA_CHECK(ishape.size() > 0, error_code::value,
+             "Inputs must have at least one dimension.");
+  Shape_t oshape(ishape.begin(), ishape.end() - 1);
+
+  if (shape_.size() > 0)
+    oshape.insert(oshape.end(), shape_.begin(), shape_.end());
+  else
+    oshape.push_back(1);
+
+  outer_loop_ = ndi::outer_size(oshape, ishape.size() - 1);
+  inner_loop_ = ndi::inner_size(oshape, ishape.size() - 1);
+
+  if (!replace_) {
+    // Both values are 64-bit; cast explicitly so they match the %lld
+    // conversion on every platform.
+    NBLA_CHECK(inner_loop_ <= ishape.back(), error_code::value,
+               "Can not sample more values than population without replacement."
+               " product of shape %lld > last dim of inputs %lld",
+               static_cast<long long>(inner_loop_),
+               static_cast<long long>(ishape.back()));
+  }
+
+  idxbuf_.reshape(oshape, true);
+  outputs[0]->reshape(oshape, true);
+
+  // A seed of -1 requests a non-deterministic seed.
+  rgen_ = std::mt19937((seed_ == -1 ? std::random_device()() : seed_));
+}
+
+/** Draw `inner_loop_` samples for each of the `outer_loop_` populations.
+
+    A sample is chosen by drawing a uniform number in [0, sum of weights) and
+    locating the first cumulative weight that exceeds it. The chosen index is
+    stored in `idxbuf_` for use by backward. Weights are validated per
+    population: negative weights are rejected in both modes.
+ */
+template <typename T>
+void RandomChoice<T>::forward_impl(const Variables &inputs,
+                                   const Variables &outputs) {
+  using std::count_if;
+  using std::partial_sum;
+  using std::uniform_real_distribution;
+  using std::upper_bound;
+  using std::vector;
+
+  auto x_data = inputs[0]->get_data_pointer<T>(this->ctx_);
+  auto w_data = inputs[1]->get_data_pointer<T>(this->ctx_);
+  auto y_data = outputs[0]->cast_data_and_get_pointer<T>(this->ctx_, true);
+  auto idxbuf = idxbuf_.cast_data_and_get_pointer<int>(this->ctx_, true);
+  const auto w_size = inputs[0]->shape().back(); // size of each weight vector
+  const auto is_negative = [](T v) { return v < (T)0; };
+  const auto is_positive = [](T v) { return v > (T)0; };
+
+  if (replace_) {
+    vector<T> w_sum(w_size);
+    for (Size_t b = 0; b < this->outer_loop_; b++) {
+      NBLA_CHECK(std::none_of(w_data, w_data + w_size, is_negative),
+                 error_code::value, "Negative weights are not allowed.");
+      partial_sum(w_data, w_data + w_size, w_sum.begin());
+      NBLA_CHECK(w_sum.back() > (T)0, error_code::value,
+                 "At least one weight must be greater zero.");
+      uniform_real_distribution<> uniform(0, w_sum.back());
+      for (Size_t s = 0; s < this->inner_loop_; s++) {
+        T u = uniform(this->rgen_);
+        // First cumulative weight greater than u. If none is found (u was
+        // rounded up to the total weight) the last member is used.
+        auto hit = upper_bound(w_sum.begin(), w_sum.end(), u);
+        auto index = (hit == w_sum.end()) ? w_size - 1 : hit - w_sum.begin();
+        *y_data++ = x_data[index];
+        *idxbuf++ = static_cast<int>(index);
+      }
+      w_data += w_size;
+      x_data += w_size;
+    }
+  } else {
+    vector<T> w_vec(w_size), w_sum(w_size);
+    for (Size_t b = 0; b < this->outer_loop_; b++) {
+      NBLA_CHECK(std::none_of(w_data, w_data + w_size, is_negative),
+                 error_code::value, "Negative weights are not allowed.");
+      auto positive_weights = count_if(w_data, w_data + w_size, is_positive);
+      NBLA_CHECK(positive_weights >= this->inner_loop_, error_code::value,
+                 "insufficient positive weights for sampling w/o replacement");
+      // Working copy: the weight of a drawn member is set to zero.
+      w_vec.assign(w_data, w_data + w_size);
+      Size_t have = 0;
+      while (have < this->inner_loop_) {
+        // Cumulative weights of the members still available in this round.
+        partial_sum(w_vec.begin(), w_vec.end(), w_sum.begin());
+        uniform_real_distribution<> uniform(0, w_sum.back());
+        // Try once per missing sample; rejected tries are repeated in the
+        // next round with refreshed cumulative weights.
+        for (Size_t need = this->inner_loop_ - have; need > 0; need--) {
+          T u = uniform(this->rgen_);
+          auto hit = upper_bound(w_sum.begin(), w_sum.end(), u);
+          if (hit == w_sum.end())
+            continue; // nothing selected by this draw
+          auto i = hit - w_sum.begin();
+          // w_sum is not refreshed inside a round, so a member drawn earlier
+          // in the same round can be hit again; such draws are rejected.
+          if (w_vec[i] > 0) {
+            *y_data++ = x_data[i];
+            *idxbuf++ = static_cast<int>(i);
+            w_vec[i] = 0;
+            have++;
+          }
+        }
+      }
+      w_data += w_size;
+      x_data += w_size;
+    }
+  }
+}
+
+/** Scatter-add the output gradient to the elements chosen in forward.
+
+    For every input with `propagate_down` set, each output gradient value is
+    added to the gradient element at the index recorded in `idxbuf_`. The same
+    accumulation is applied to `x` (input 0) and to `w` (input 1). A gradient
+    that is not accumulated is zeroed first.
+ */
+template <typename T>
+void RandomChoice<T>::backward_impl(const Variables &inputs,
+                                    const Variables &outputs,
+                                    const vector<bool> &propagate_down,
+                                    const vector<bool> &accum) {
+  if (!(propagate_down[0] || propagate_down[1])) {
+    return;
+  }
+
+  if ((propagate_down[0]) && (!accum[0]))
+    inputs[0]->grad()->zero();
+
+  if ((propagate_down[1]) && (!accum[1]))
+    inputs[1]->grad()->zero();
+
+  const auto w_size = inputs[0]->shape().back();
+
+  for (int k = 0; k < 2; k++) {
+    if (!propagate_down[k])
+      continue;
+    auto in_grad = inputs[k]->cast_grad_and_get_pointer<T>(this->ctx_, false);
+    auto y_grad = outputs[0]->get_grad_pointer<T>(this->ctx_);
+    auto idxbuf = idxbuf_.get_data_pointer<int>(this->ctx_);
+    for (Size_t b = 0; b < this->outer_loop_; b++) {
+      for (Size_t s = 0; s < this->inner_loop_; s++) {
+        in_grad[*idxbuf++] += *y_grad++;
+      }
+      in_grad += w_size; // next population
+    }
+  }
+}
+}