Skip to content

Commit

Permalink
Cpu lstm inference (apache#9977)
Browse files Browse the repository at this point in the history
* fix autograd import path

* cpu lstm working

* remove fatal log

* add simple unittest
remove redundant log
enable openmp

* fused input2hidden gemm

* fix lint

* fix pylint

* fix windows build error

* fix gluon rnn interface

* Update dataloader.py

* address cr

* address cr

* fix import

* revert some cosmetic change

* fix typo

* remove newline

* rm virtual
mv hardcoded number to constant

* address cr
add tests

* simplify test

* fix test

* fix tests

* change magic number scope
  • Loading branch information
Jerryzcn authored and szha committed Mar 10, 2018
1 parent 4411326 commit 52b5196
Show file tree
Hide file tree
Showing 5 changed files with 251 additions and 7 deletions.
14 changes: 9 additions & 5 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

from ...autograd import is_training
from ... import ndarray
from .. import Block
from . import rnn_cell
Expand Down Expand Up @@ -185,15 +186,17 @@ def forward(self, inputs, states=None):
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
if inputs.context.device_type == 'gpu':
out = self._forward_gpu(inputs, states)
if inputs.context.device_type == 'gpu' or \
(not is_training() and self._mode == 'lstm'):
out = self._forward_kernel(inputs, states)
else:
out = self._forward_cpu(inputs, states)
out = self._forward(inputs, states)

# out is (output, state)
return out[0] if skip_states else out

def _forward_cpu(self, inputs, states):
def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
Expand All @@ -207,7 +210,8 @@ def _forward_cpu(self, inputs, states):

return outputs, new_states

def _forward_gpu(self, inputs, states):
def _forward_kernel(self, inputs, states):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
Expand Down
210 changes: 210 additions & 0 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
#include <vector>
#include <string>
#include <utility>
#include "./math.h"
#include "./math_functions-inl.h"
#include "./operator_common.h"
#include "./mshadow_op.h"
#include "./linalg.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -153,6 +157,212 @@ class RNNOp : public Operator {
RNNParam param_;
}; // class RNNOp

template<typename DType>
class RNNOp<cpu, DType> : public Operator {
public:
explicit RNNOp(RNNParam param) {
this->param_ = param;
// RNN Mode
param_.lstm_q_ = false;
switch (param_.mode) {
case rnn_enum::kLstm:
param_.lstm_q_ = true;
break;
default:
LOG(FATAL) << "only LSTM is implmented on CPU";
}
}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
// Layout TNC
CHECK(!ctx.is_train) << "only inference mode is available"
"for cpu at the moment.";
size_t in_expected = param_.lstm_q_ ? 4 : 3;
size_t out_expected = param_.lstm_q_ ? 3 : 2;

if (!param_.state_outputs)
LOG(FATAL) << "no state outputs is currently not supported for cpu.";

CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
CHECK_EQ(in_data.size(), in_expected);
CHECK_EQ(out_data.size(), out_expected);

mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
// get input + output tensors
// w layout i2h_w, h2h_w, i2h_b, h2h_b
Tensor<cpu, 3, DType> x =
in_data[rnn_enum::kData].get<cpu, 3, DType>(s); // TNC
Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
Tensor<cpu, 3, DType> hx =
in_data[rnn_enum::kState].get<cpu, 3, DType>(s); // LNC
Tensor<cpu, 3, DType> y =
out_data[rnn_enum::kOut].get<cpu, 3, DType>(s); // TNC
int64_t seq_len = x.shape_[0];
int64_t num_layers = hx.shape_[0];
int64_t batch_size = x.shape_[1];
int64_t h_channel = hx.shape_[2];
int64_t in_channel = x.shape_[2];
Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
.get_with_shape<cpu, 2, DType>(
mshadow::Shape2(seq_len * batch_size, in_channel), s); // (T*N)C
Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
.get_with_shape<cpu, 2, DType>(
mshadow::Shape2(
y.shape_[0] * y.shape_[1], y.shape_[2]), s); // (T*N)C

CHECK(x.CheckContiguous());
CHECK(w.CheckContiguous());
CHECK(hx.CheckContiguous());
CHECK(y.CheckContiguous());

if (param_.lstm_q_) {
const size_t kNumMat = 4;
int64_t fused_h_ch = kNumMat * h_channel;
int64_t h_size = batch_size * fused_h_ch;
int64_t num_dir = 1 + param_.bidirectional;
int64_t h2h_w_size = h_channel * fused_h_ch;

Tensor<cpu, 3, DType> cx =
in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
CHECK(cx.CheckContiguous());

Tensor<cpu, 3, DType> cy =
out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
Tensor<cpu, 3, DType> hy =
out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
CHECK(cy.CheckContiguous());
CHECK(hy.CheckContiguous());

DType* workspace_addr =
static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
.get_host_space_internal(sizeof(DType) *
(seq_len * h_size + h_size
+ y.shape_[0] * y.shape_[1] * y.shape_[2])));
Tensor<cpu, 3, DType> i2h_y(
workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
Tensor<cpu, 2, DType> i2h_y_flatten(
workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
Tensor<cpu, 2, DType> h2h_y(workspace_addr
+ seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
Tensor<cpu, 3, DType> y_tmp(workspace_addr
+ (seq_len + 1) * h_size, y.shape_);
Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
+ (seq_len + 1) * h_size, y_flatten.shape_);
CHECK(i2h_y.CheckContiguous());
CHECK(h2h_y.CheckContiguous());
CHECK(y_tmp.CheckContiguous());

for (int64_t layer = 0; layer < num_layers; layer++) {
int reverse_dir = 0;
int out_tmp = 0;
if (param_.bidirectional && layer % 2)
reverse_dir = 1;
if (layer / num_dir % 2 == 0)
out_tmp = 1;
mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
(layer < num_dir) ? in_channel : num_dir * h_channel);
mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
int64_t start = layer < num_dir ?
(layer * (in_channel * fused_h_ch + h2h_w_size)) : // input layer
(num_dir * (in_channel * fused_h_ch + h2h_w_size)
+ (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
start += layer < num_dir ?
in_channel * fused_h_ch : h2h_w_size * num_dir;
Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
+ (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
+ layer * fused_h_ch * 2;
Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
start += fused_h_ch;
Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
if (out_tmp) {
linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
i2h_y_flatten, false, true, s);
} else {
linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
i2h_y_flatten, false, true, s);
}
i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
for (int64_t t = 0; t < seq_len; t++) {
int64_t timestep = t;
if (reverse_dir)
timestep = seq_len - 1 - t;
linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
false, true, s);
h2h_y += repmat(h2h_b, batch_size);
// fused element-wise ops
LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
hy[layer], cy[layer], batch_size, h_channel, t,
reverse_dir, out_tmp && (layer == num_layers - 1));
}
}
} else {
LOG(FATAL) << "only LSTM is available for cpu at the moment.";
}
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
}

private:
RNNParam param_;

void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
const Tensor<cpu, 2, DType> &cx,
const Tensor<cpu, 2, DType> &h2h_y,
const Tensor<cpu, 2, DType> &y,
// holding intermediate layer output
const Tensor<cpu, 2, DType> &tmp,
const Tensor<cpu, 2, DType> &hy,
const Tensor<cpu, 2, DType> &cy,
const int64_t batch_size,
const int64_t h_channel,
const int64_t t,
const int reverse_dir,
const int copy_tmp2y) {
int64_t length = batch_size * h_channel;
#pragma omp parallel for
for (int64_t ji = 0; ji < length; ++ji) {
int64_t j = ji / h_channel; // batch dim
int64_t i = ji % h_channel;
int64_t f = i + h_channel;
int64_t c = i + h_channel * 2;
int64_t o = i + h_channel * 3;
int64_t j_pos = j * h_channel * 4;
h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i])
+ h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
tmp[j][i + h_channel * reverse_dir] = hy[j][i];
if (copy_tmp2y) {
y[j][i] = tmp[j][i];
if (reverse_dir)
y[j][i + h_channel] = tmp[j][i + h_channel];
}
}
}
}; // class RNNOp

template<typename xpu>
Operator* CreateOp(RNNParam param, int dtype);

Expand Down
1 change: 0 additions & 1 deletion src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(RNNParam param, int dtype) {
LOG(FATAL) << "RNN is only available for gpu at the moment.";
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new RNNOp<cpu, DType>(param);
Expand Down
19 changes: 18 additions & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,23 @@ def check_rnn_layer(layer):
for g, c in zip(gs, cs):
assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

def check_rnn_layer_w_rand_inputs(layer):
layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
x = mx.nd.uniform(shape=(10, 16, 30))
with mx.gpu(0):
x = x.copyto(mx.gpu(0))
states = layer.begin_state(16)
go, gs = layer(x, states)

with mx.cpu(0):
x = x.copyto(mx.cpu(0))
states = layer.begin_state(16)
co, cs = layer(x, states)

assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
for g, c in zip(gs, cs):
assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

@with_seed()
def test_rnn_layer():
check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
Expand All @@ -1531,7 +1548,7 @@ def test_rnn_layer():
check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))

check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))

check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))

@with_seed()
def test_sequence_reverse():
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def test_lstm_forget_bias():
forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)

def test_lstm_cpu_inference():
# should behave the same as lstm cell
EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
[0.72045636, 0.72045636, 0.95215213, 0.95215213]],
[[0.95215213, 0.95215213, 0.72045636, 0.72045636],
[0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
x = mx.nd.ones(shape=(2, 2, 2))
model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
model.initialize(mx.init.One())
y = model(x).asnumpy()

mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
rtol=1e-3, atol=1e-5)


def test_gru():
cell = gluon.rnn.GRUCell(100, prefix='rnn_')
Expand Down

0 comments on commit 52b5196

Please sign in to comment.