Cpu lstm inference #9977

Merged
merged 26 commits into from
Mar 10, 2018
Merged
14 changes: 9 additions & 5 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,6 +23,7 @@
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']

+from ...autograd import is_training
from ... import ndarray
from .. import Block
from . import rnn_cell
@@ -185,15 +186,17 @@ def forward(self, inputs, states=None):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu':
-            out = self._forward_gpu(inputs, states)
+        if inputs.context.device_type == 'gpu' or \
+                (not is_training() and self._mode == 'lstm'):
+            out = self._forward_kernel(inputs, states)
         else:
-            out = self._forward_cpu(inputs, states)
+            out = self._forward(inputs, states)
 
         # out is (output, state)
         return out[0] if skip_states else out

-    def _forward_cpu(self, inputs, states):
+    def _forward(self, inputs, states):
+        """forward using gluon cell"""
         ns = len(states)
         axis = self._layout.find('T')
         states = sum(zip(*((j for j in i) for i in states)), ())
@@ -207,7 +210,8 @@ def _forward_cpu(self, inputs, states):

         return outputs, new_states
 
-    def _forward_gpu(self, inputs, states):
+    def _forward_kernel(self, inputs, states):
+        """forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
             inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
         ctx = inputs.context
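Note: with this change, `forward` dispatches to the fused `_forward_kernel` either on GPU or for pure-inference LSTM calls on CPU, and falls back to the unrolled Gluon cells (`_forward`) otherwise. A minimal usage sketch of the new CPU inference path, assuming the default TNC layout (illustrative only, mirroring the tests added below):

import mxnet as mx
from mxnet import gluon

layer = gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)
layer.collect_params().initialize(ctx=mx.cpu(0))

x = mx.nd.uniform(shape=(10, 16, 30))       # (seq_len, batch, input_size) = TNC
states = layer.begin_state(batch_size=16)

# Outside autograd.record(), is_training() is False, so the fused
# CPU kernel handles this call instead of the per-timestep cells.
out, new_states = layer(x, states)
print(out.shape)                            # (10, 16, 200): 2 directions x 100 hidden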
210 changes: 210 additions & 0 deletions src/operator/rnn-inl.h
@@ -34,7 +34,11 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
 #include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"

namespace mxnet {
namespace op {
@@ -153,6 +157,212 @@ class RNNOp : public Operator {
  RNNParam param_;
};  // class RNNOp

template<typename DType>
class RNNOp<cpu, DType> : public Operator {
 public:
  explicit RNNOp(RNNParam param) {
    this->param_ = param;
    // RNN Mode
    param_.lstm_q_ = false;
    switch (param_.mode) {
      case rnn_enum::kLstm:
        param_.lstm_q_ = true;
        break;
      default:
        LOG(FATAL) << "only LSTM is implemented on CPU";
    }
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    // Layout TNC
    CHECK(!ctx.is_train) << "only inference mode is available "
                            "for cpu at the moment.";
    size_t in_expected = param_.lstm_q_ ? 4 : 3;
    size_t out_expected = param_.lstm_q_ ? 3 : 2;

    if (!param_.state_outputs)
      LOG(FATAL) << "disabling state outputs is currently not supported for cpu.";

    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), in_expected);
    CHECK_EQ(out_data.size(), out_expected);

    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
    // get input + output tensors
    // w layout: i2h_w, h2h_w, i2h_b, h2h_b
    Tensor<cpu, 3, DType> x =
        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
    Tensor<cpu, 3, DType> hx =
        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
    Tensor<cpu, 3, DType> y =
        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
    int64_t seq_len = x.shape_[0];
    int64_t num_layers = hx.shape_[0];
    int64_t batch_size = x.shape_[1];
    int64_t h_channel = hx.shape_[2];
    int64_t in_channel = x.shape_[2];
    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
        .get_with_shape<cpu, 2, DType>(
            mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
        .get_with_shape<cpu, 2, DType>(
            mshadow::Shape2(
                y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C

    CHECK(x.CheckContiguous());
    CHECK(w.CheckContiguous());
    CHECK(hx.CheckContiguous());
    CHECK(y.CheckContiguous());

    if (param_.lstm_q_) {
      const size_t kNumMat = 4;
      int64_t fused_h_ch = kNumMat * h_channel;
      int64_t h_size = batch_size * fused_h_ch;
      int64_t num_dir = 1 + param_.bidirectional;
      int64_t h2h_w_size = h_channel * fused_h_ch;

      Tensor<cpu, 3, DType> cx =
          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
      CHECK(cx.CheckContiguous());

      Tensor<cpu, 3, DType> cy =
          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
      Tensor<cpu, 3, DType> hy =
          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
      CHECK(cy.CheckContiguous());
      CHECK(hy.CheckContiguous());

      DType* workspace_addr =
          static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
              .get_host_space_internal(sizeof(DType) *
                                       (seq_len * h_size + h_size
                                        + y.shape_[0] * y.shape_[1] * y.shape_[2])));
      Tensor<cpu, 3, DType> i2h_y(
          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
      Tensor<cpu, 2, DType> i2h_y_flatten(
          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
      Tensor<cpu, 2, DType> h2h_y(workspace_addr
          + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
      Tensor<cpu, 3, DType> y_tmp(workspace_addr
          + (seq_len + 1) * h_size, y.shape_);
      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
          + (seq_len + 1) * h_size, y_flatten.shape_);
      CHECK(i2h_y.CheckContiguous());
      CHECK(h2h_y.CheckContiguous());
      CHECK(y_tmp.CheckContiguous());

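Note: the single temp workspace requested above is carved into three regions: the precomputed input-to-hidden gate activations for all timesteps (i2h_y), a one-timestep hidden-to-hidden buffer (h2h_y), and a ping-pong output buffer between layers (y_tmp). A Python sketch of the size computation, under those assumptions (the helper name is illustrative, not from the source):

def lstm_workspace_elems(seq_len, batch_size, h_channel, out_channels):
    """Element count of the temp buffer: i2h_y + h2h_y + y_tmp regions."""
    h_size = batch_size * 4 * h_channel           # one timestep of fused gates
    i2h_y = seq_len * h_size                      # gates for every timestep
    h2h_y = h_size                                # recurrent gates, one timestep
    y_tmp = seq_len * batch_size * out_channels   # intermediate layer output
    return i2h_y + h2h_y + y_tmp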
      for (int64_t layer = 0; layer < num_layers; layer++) {
        int reverse_dir = 0;
        int out_tmp = 0;
        if (param_.bidirectional && layer % 2)
          reverse_dir = 1;
        if (layer / num_dir % 2 == 0)
          out_tmp = 1;
        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
            (layer < num_dir) ? in_channel : num_dir * h_channel);
        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
        int64_t start = layer < num_dir ?
            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
            (num_dir * (in_channel * fused_h_ch + h2h_w_size)
             + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
        start += layer < num_dir ?
            in_channel * fused_h_ch : h2h_w_size * num_dir;
        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
        start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
            + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
            + layer * fused_h_ch * 2;
        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
        start += fused_h_ch;
        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
        if (out_tmp) {
          linalg_gemm(layer < num_dir ? x_flatten : y_flatten, i2h_w,
                      i2h_y_flatten, false, true, s);
        } else {
          linalg_gemm(layer < num_dir ? x_flatten : y_flatten_tmp, i2h_w,
                      i2h_y_flatten, false, true, s);
        }
        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
        for (int64_t t = 0; t < seq_len; t++) {
          int64_t timestep = t;
          if (reverse_dir)
            timestep = seq_len - 1 - t;
          linalg_gemm(t == 0 ? hx[layer] : hy[layer], h2h_w, h2h_y,
                      false, true, s);
          h2h_y += repmat(h2h_b, batch_size);
          // fused element-wise ops
          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
                                     y[timestep], out_tmp ? y_tmp[timestep] : y[timestep],
                                     hy[layer], cy[layer], batch_size, h_channel, t,
                                     reverse_dir, out_tmp && (layer == num_layers - 1));
        }
      }
    } else {
      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
    }
  }
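Note: the `start` arithmetic above unpacks MXNet's flat RNN parameter vector: for each layer-direction pair, an i2h weight matrix followed by an h2h weight matrix, with all biases packed after the last weight matrix. A Python mirror of the offset computation, as a sketch (the helper is hypothetical; `num_layers` here counts layer-direction pairs, exactly as in the loop):

def lstm_param_offsets(layer, num_layers, num_dir, in_channel, h_channel):
    """Offsets of i2h_w, h2h_w, i2h_b, h2h_b in the flat weight vector (a sketch)."""
    fused_h = 4 * h_channel                  # i, f, c, o gates fused along rows
    h2h_size = h_channel * fused_h
    if layer < num_dir:                      # input-facing layer(s)
        i2h_start = layer * (in_channel * fused_h + h2h_size)
        i2h_cols = in_channel
    else:                                    # deeper layers read num_dir concatenated outputs
        i2h_start = (num_dir * (in_channel * fused_h + h2h_size)
                     + (layer - num_dir) * (h2h_size * num_dir + h2h_size))
        i2h_cols = num_dir * h_channel
    h2h_start = i2h_start + fused_h * i2h_cols
    bias_base = (num_dir * (in_channel * fused_h + h2h_size)
                 + (num_layers - num_dir) * (h2h_size * (num_dir + 1)))
    i2h_b_start = bias_base + layer * fused_h * 2
    h2h_b_start = i2h_b_start + fused_h
    return i2h_start, h2h_start, i2h_b_start, h2h_b_start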

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
  }

 private:
  RNNParam param_;

  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
                                  const Tensor<cpu, 2, DType> &cx,
                                  const Tensor<cpu, 2, DType> &h2h_y,
                                  const Tensor<cpu, 2, DType> &y,
                                  // holding intermediate layer output
                                  const Tensor<cpu, 2, DType> &tmp,
                                  const Tensor<cpu, 2, DType> &hy,
                                  const Tensor<cpu, 2, DType> &cy,
                                  const int64_t batch_size,
                                  const int64_t h_channel,
                                  const int64_t t,
                                  const int reverse_dir,
                                  const int copy_tmp2y) {
    int64_t length = batch_size * h_channel;
    #pragma omp parallel for
    for (int64_t ji = 0; ji < length; ++ji) {
      int64_t j = ji / h_channel;  // batch dim
      int64_t i = ji % h_channel;
      int64_t f = i + h_channel;
      int64_t c = i + h_channel * 2;
      int64_t o = i + h_channel * 3;
      int64_t j_pos = j * h_channel * 4;
      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i] : cy[j][i])
                 + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
      if (copy_tmp2y) {
        y[j][i] = tmp[j][i];
        if (reverse_dir)
          y[j][i + h_channel] = tmp[j][i + h_channel];
      }
    }
  }
};  // class RNNOp

template<typename xpu>
Operator* CreateOp(RNNParam param, int dtype);

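Note: `LSTMFusedElementWiseCPUOps` implements the standard LSTM cell update, with gates packed in the order [i, f, c, o] along the fused axis (offsets 0, h, 2h, 3h above). An equivalent NumPy sketch of one timestep for a single direction (illustrative only; the real kernel additionally handles the bidirectional ping-pong copies):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_elementwise_step(i2h_y, h2h_y, c_prev):
    """i2h_y, h2h_y: (batch, 4*h) gate pre-activations; c_prev: (batch, h)."""
    h = c_prev.shape[1]
    g = i2h_y + h2h_y                   # sum the two GEMM outputs
    i = sigmoid(g[:, 0*h:1*h])          # input gate
    f = sigmoid(g[:, 1*h:2*h])          # forget gate
    c_cand = np.tanh(g[:, 2*h:3*h])     # candidate cell state
    o = sigmoid(g[:, 3*h:4*h])          # output gate
    c = f * c_prev + i * c_cand         # new cell state
    h_out = o * np.tanh(c)              # new hidden state
    return h_out, c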
1 change: 0 additions & 1 deletion src/operator/rnn.cc
@@ -30,7 +30,6 @@ namespace mxnet {
 namespace op {
 template<>
 Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  LOG(FATAL) << "RNN is only available for gpu at the moment.";
   Operator *op = NULL;
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     op = new RNNOp<cpu, DType>(param);
19 changes: 18 additions & 1 deletion tests/python/gpu/test_operator_gpu.py
@@ -1523,6 +1523,23 @@ def check_rnn_layer(layer):
    for g, c in zip(gs, cs):
        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

def check_rnn_layer_w_rand_inputs(layer):
    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
    x = mx.nd.uniform(shape=(10, 16, 30))
    with mx.gpu(0):
        x = x.copyto(mx.gpu(0))
        states = layer.begin_state(16)
        go, gs = layer(x, states)

    with mx.cpu(0):
        x = x.copyto(mx.cpu(0))
        states = layer.begin_state(16)
        co, cs = layer(x, states)

    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
    for g, c in zip(gs, cs):
        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)

@with_seed()
def test_rnn_layer():
    check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
@@ -1531,7 +1548,7 @@ def test_rnn_layer():
    check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))

     check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))

@with_seed()
def test_sequence_reverse():
14 changes: 14 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,20 @@ def test_lstm_forget_bias():
                               forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
    assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)

def test_lstm_cpu_inference():
    # should behave the same as lstm cell
    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
                                      [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
                                     [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
                                      [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
    x = mx.nd.ones(shape=(2, 2, 2))
    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
    model.initialize(mx.init.One())
    y = model(x).asnumpy()

    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                      rtol=1e-3, atol=1e-5)


def test_gru():
    cell = gluon.rnn.GRUCell(100, prefix='rnn_')