Cpu lstm inference #9977
@@ -34,7 +34,11 @@
#include <vector>
#include <string>
#include <utility>
#include "./math.h"
#include "./math_functions-inl.h"
#include "./operator_common.h"
#include "./mshadow_op.h"
#include "./linalg.h"

namespace mxnet {
namespace op {

@@ -78,10 +82,12 @@ inline int rnn_param_size(int layerNum,
  int size = rnn_single_param_size(inputSize, hiddenSize, mode);
  // get size of remaining layers
  if (bidirectional) {
-   size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
+   size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize,
+                                                  hiddenSize, mode);
    size *= 2;
  } else {
-   size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
+   size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize,
+                                                  mode);
  }
  return size;
}

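In equation form, the parameter count computed above can be summarized as follows, where S(i, h) stands for rnn_single_param_size with input width i and hidden width h, L is the layer count, i the input size, and h the hidden size (this is a restatement of the code, not an independent derivation):

\[
\text{size} =
\begin{cases}
2\,\bigl(S(i, h) + (L - 1)\,S(2h, h)\bigr) & \text{bidirectional} \\
S(i, h) + (L - 1)\,S(h, h) & \text{otherwise.}
\end{cases}
\]
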
@@ -114,7 +120,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {

  DMLC_DECLARE_FIELD(p).set_default(0.)
  .set_range(0, 1)
- .describe("Dropout probability, fraction of the input that gets dropped out at training time");
+ .describe("Dropout probability, fraction of the input that gets dropped "
+           "out at training time");

Review comment: Remove this change; the length of this line is less than 100.
Review comment: pkeep_ and lstm_q_ are used in cudnn_rnn-inl.h.

  DMLC_DECLARE_FIELD(state_outputs).set_default(false)
  .describe("Whether to have the states as symbol outputs.");

@@ -132,8 +139,6 @@ class RNNOp : public Operator {
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
-   using namespace mshadow;
-   using namespace mshadow::expr;
    // TODO(sbodenstein): add MShadow implementation
  }

@@ -144,15 +149,224 @@ class RNNOp : public Operator {
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
-   using namespace mshadow;
-   using namespace mshadow::expr;
    // TODO(sbodenstein): add MShadow implementation
  }

 private:
  RNNParam param_;
};  // class RNNOp

template<typename DType>
class RNNOp<cpu, DType> : public Operator {
 public:
  explicit RNNOp(RNNParam param) {
    this->param_ = param;
    // RNN Mode
    switch (param_.mode) {
      case rnn_enum::kLstm:
        break;
      default:
        LOG(FATAL) << "only LSTM is implemented on CPU";
    }
    if (param_.mode == rnn_enum::kLstm)
      param_.lstm_q_ = true;
    else
      param_.lstm_q_ = false;

Review comment: It seems this check can be merged into the switch-case statement above.

  }
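A minimal sketch of what the reviewer's suggestion could look like, with the lstm_q_ assignment folded into the switch (illustrative only, not the code that was merged):

    // Hypothetical restructuring: set lstm_q_ inside the switch itself.
    explicit RNNOp(RNNParam param) {
      this->param_ = param;
      switch (param_.mode) {
        case rnn_enum::kLstm:
          param_.lstm_q_ = true;
          break;
        default:
          param_.lstm_q_ = false;
          LOG(FATAL) << "only LSTM is implemented on CPU";
      }
    }
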

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    // Layout TNC

    size_t in_expected = param_.lstm_q_ ? 4 : 3;
    size_t out_expected = param_.lstm_q_ ? 3 : 2;

    if (!param_.state_outputs)
      LOG(FATAL) << "no state outputs is currently not supported for cpu.";

    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), in_expected);
    CHECK_EQ(out_data.size(), out_expected);

    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
    // get input + output tensors
    // w layout i2h_w, h2h_w, i2h_b, h2h_b
    Tensor<cpu, 3, DType> x =
        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
    Tensor<cpu, 3, DType> hx =
        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
    Tensor<cpu, 3, DType> y =
        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
    int64_t seq_len = x.shape_[0];
    int64_t num_layers = hx.shape_[0];
    int64_t batch_size = x.shape_[1];
    int64_t h_channel = hx.shape_[2];
    int64_t in_channel = x.shape_[2];
    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
      .get_with_shape<cpu, 2, DType>(
          mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
      .get_with_shape<cpu, 2, DType>(
          mshadow::Shape2(
              y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C

    CHECK_EQ(x.CheckContiguous(), true);
    CHECK_EQ(w.CheckContiguous(), true);
    CHECK_EQ(hx.CheckContiguous(), true);
    CHECK_EQ(y.CheckContiguous(), true);

Review comment: CHECK(x.CheckContiguous());

    if (ctx.is_train)
      LOG(FATAL) << "only inference mode is available for cpu at the moment.";

Review comment: You can do CHECK(!ctx.is_train) << "...".
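A small sketch combining the two suggestions above (the message text is taken from the existing LOG(FATAL); otherwise illustrative only):

    CHECK(x.CheckContiguous());
    CHECK(w.CheckContiguous());
    CHECK(hx.CheckContiguous());
    CHECK(y.CheckContiguous());
    CHECK(!ctx.is_train) << "only inference mode is available for cpu at the moment.";
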
    if (param_.lstm_q_) {
      const size_t kNumMat = 4;
      int64_t fused_h_ch = kNumMat * h_channel;
      int64_t h_size = batch_size * fused_h_ch;
      int64_t num_dir = 1 + param_.bidirectional;
      int64_t h2h_w_size = h_channel * fused_h_ch;

      Tensor<cpu, 3, DType> cx =
          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
      CHECK_EQ(cx.CheckContiguous(), true);

      Tensor<cpu, 3, DType> cy =
          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
      Tensor<cpu, 3, DType> hy =
          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
      CHECK_EQ(cy.CheckContiguous(), true);
      CHECK_EQ(hy.CheckContiguous(), true);

      DType* workspace_addr =
          static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
              .get_host_space_internal(sizeof(DType) *
                  (seq_len * h_size + h_size
                   + y.shape_[0] * y.shape_[1] * y.shape_[2])));
      Tensor<cpu, 3, DType> i2h_y(
          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
      Tensor<cpu, 2, DType> i2h_y_flatten(
          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
      Tensor<cpu, 2, DType> h2h_y(workspace_addr
          + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
      Tensor<cpu, 3, DType> y_tmp(workspace_addr
          + (seq_len + 1) * h_size, y.shape_);
      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
          + (seq_len + 1) * h_size, y_flatten.shape_);
      CHECK_EQ(i2h_y.CheckContiguous(), true);
      CHECK_EQ(h2h_y.CheckContiguous(), true);
      CHECK_EQ(y_tmp.CheckContiguous(), true);
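      // Descriptive summary of the workspace request above: one flat DType
      // buffer holds, back to back,
      //   i2h_y  -- seq_len * batch_size * fused_h_ch  (input-to-hidden gate pre-activations, all timesteps)
      //   h2h_y  -- batch_size * fused_h_ch            (hidden-to-hidden gate pre-activations, one timestep)
      //   y_tmp  -- seq_len * batch_size * y channels  (intermediate layer output)
      // matching sizeof(DType) * (seq_len * h_size + h_size + T * N * C).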

      for (int64_t layer = 0; layer < num_layers; layer++) {
        int reverse_dir = 0;
        int out_tmp = 0;
        if (param_.bidirectional && layer % 2)
          reverse_dir = 1;
        if (layer / num_dir % 2 == 0)
          out_tmp = 1;
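        // Descriptive note: in the bidirectional case the layers are stored
        // direction-major, so odd layer indices run backward in time
        // (reverse_dir). out_tmp alternates per logical layer, so consecutive
        // layers ping-pong their outputs between y_tmp and y; the copy_tmp2y
        // flag passed below copies the final layer's result back into y.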
        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
            (layer < num_dir) ? in_channel : num_dir * h_channel);
        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
        int64_t start = layer < num_dir ?
            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
            (num_dir * (in_channel * fused_h_ch + h2h_w_size)
             + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
        Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ?
            (in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
            i2h_w_shape);

Review comment: Why slice? I think …
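The comment above is truncated in the source; assuming it questions the Slice-then-dptr_ pattern, one possible alternative (illustrative only, mirroring how the workspace views are built above) is to construct the view directly from an offset into w:

        // Hypothetical alternative: build the weight view straight from an
        // offset into the flat parameter vector.
        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
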
        start += layer < num_dir ?
            in_channel * fused_h_ch : h2h_w_size * num_dir;
        Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_,
            h2h_w_shape);
        start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
            + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
            + layer * fused_h_ch * 2;
        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
        start += fused_h_ch;
        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
        if (out_tmp) {
          linalg_gemm(layer < num_dir ? x_flatten : y_flatten, i2h_w,
              i2h_y_flatten, false, true, s);
        } else {
          linalg_gemm(layer < num_dir ? x_flatten : y_flatten_tmp, i2h_w,
              i2h_y_flatten, false, true, s);
        }
        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
        for (int64_t t = 0; t < seq_len; t++) {
          int64_t timestep = t;
          if (reverse_dir)
            timestep = seq_len - 1 - t;
          linalg_gemm(t == 0 ? hx[layer] : hy[layer], h2h_w, h2h_y,
              false, true, s);
          h2h_y += repmat(h2h_b, batch_size);
          // fused element-wise ops
          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
              y[timestep], out_tmp ? y_tmp[timestep] : y[timestep],
              hy[layer], cy[layer], batch_size, h_channel, t,
              reverse_dir, out_tmp && (layer == num_layers - 1));
        }
      }
    } else {
      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
    }
  }

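In equation form, the per-layer loop above first computes the input-to-hidden contribution for all timesteps with a single GEMM, then walks through time (forward or backward) adding the hidden-to-hidden term before the fused gate update (notation mine; x_t is the layer input at step t and h_{t-1} the previous hidden state):

\[
\tilde g = X W_{i2h}^{\top} + \mathbf{1}\, b_{i2h}^{\top}, \qquad
g_t = \tilde g_t + h_{t-1} W_{h2h}^{\top} + b_{h2h}.
\]
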
  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
  }

 private:
  RNNParam param_;

  virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,

Review comment: Why …

                                          const Tensor<cpu, 2, DType> &cx,
                                          const Tensor<cpu, 2, DType> &h2h_y,
                                          const Tensor<cpu, 2, DType> &y,
                                          // holding intermediate layer output
                                          const Tensor<cpu, 2, DType> &tmp,
                                          const Tensor<cpu, 2, DType> &hy,
                                          const Tensor<cpu, 2, DType> &cy,
                                          const int64_t batch_size,
                                          const int64_t h_channel,
                                          const int64_t t,
                                          const int reverse_dir,
                                          const int copy_tmp2y) {
    int64_t ji;
    #pragma omp parallel for private(ji)
    for (ji = 0; ji < batch_size * h_channel; ji++) {
      int64_t j = ji / h_channel;  // batch dim

Review comment: Is it ok to write …
Reply: You mean move it out of the condition expression?
Review comment: Don't need to set …

      int64_t i = ji % h_channel;
      int64_t f = i + h_channel;
      int64_t c = i + h_channel * 2;
      int64_t o = i + h_channel * 3;
      h2h_y[j][i] += i2h_y[j][i];

Review comment: Too many overloaded operator …
Reply: Tried, but I did not notice any difference in runtime. I think the tensor object probably does not generate a new tensor object here. I think multiple …
Review comment: When you use …
Reply: Okay, but it seems that inside mshadow, all the ops are implemented using multiple …
Review comment: I don't have a strong opinion on this. If you could use …, I think the rule of thumb here is to try to avoid temp tensor creation and destruction while keeping the code readable. So it's okay to use …

      h2h_y[j][f] += i2h_y[j][f];
      h2h_y[j][o] += i2h_y[j][o];
      h2h_y[j][c] += i2h_y[j][c];
      h2h_y[j][i] = 1.0f / (1.0f + math::exp(-h2h_y[j][i]));
      h2h_y[j][f] = 1.0f / (1.0f + math::exp(-h2h_y[j][f]));
      h2h_y[j][o] = 1.0f / (1.0f + math::exp(-h2h_y[j][o]));
      h2h_y[j][c] = tanh(h2h_y[j][c]);
      cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i] : cy[j][i])
          + h2h_y[j][i] * h2h_y[j][c];
      hy[j][i] = h2h_y[j][o] * tanh(cy[j][i]);
      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
      if (copy_tmp2y) {
        y[j][i] = tmp[j][i];
        if (reverse_dir)
          y[j][i + h_channel] = tmp[j][i + h_channel];
      }
    }
  }
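For reference, the fused loop above implements the standard LSTM cell update, with g denoting the summed pre-activations g = i2h_y[t] + h2h_y (a restatement of the code in math form):

\[
\begin{aligned}
i &= \sigma(g_i), \quad f = \sigma(g_f), \quad o = \sigma(g_o), \quad \tilde c = \tanh(g_c),\\
c_t &= f \odot c_{t-1} + i \odot \tilde c,\\
h_t &= o \odot \tanh(c_t).
\end{aligned}
\]
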
};  // class RNNOp

template<typename xpu>
Operator* CreateOp(RNNParam param, int dtype);

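Referring back to the operator[] discussion inside LSTMFusedElementWiseCPUOps: if the repeated h2h_y[j][...] indexing ever showed up in profiles, one hedged alternative (illustrative only; it relies on mshadow's Tensor dptr_ and stride_ members) would be to cache raw row pointers once per iteration:

      // Hypothetical variant of the loop body: index each row once via raw pointers.
      DType* gates = h2h_y.dptr_ + j * h2h_y.stride_;
      const DType* in_gates = i2h_y.dptr_ + j * i2h_y.stride_;
      gates[i] += in_gates[i];
      gates[f] += in_gates[f];
      gates[c] += in_gates[c];
      gates[o] += in_gates[o];
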
@@ -184,7 +398,8 @@ class RNNProp : public OperatorProperty {
    return num_outputs;
  }

- void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
+     override {
    param_.Init(kwargs);
  }

@@ -195,28 +410,33 @@ class RNNProp : public OperatorProperty {
  bool InferShape(std::vector<TShape> *in_shape,
                  std::vector<TShape> *out_shape,
                  std::vector<TShape> *aux_shape) const override {
    using namespace mshadow;
    if (param_.mode == rnn_enum::kLstm) {
-     CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
+     CHECK_EQ(in_shape->size(), 4U) <<
+       "Input:[data, parameters, state, cell_state]";
    } else {
      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
    }
    const TShape &dshape = (*in_shape)[rnn_enum::kData];
    if (dshape.ndim() == 0) return false;
-   CHECK_EQ(dshape.ndim(), 3U) \
-     << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
+   CHECK_EQ(dshape.ndim(), 3U)
+     << "Input data should be rank-3 tensor of dim [sequence length, "
+     << "batch size, input size]";
    // data: [sequence len, batch, input dimension]
    int batch_size = dshape[1];
    int input_size = dshape[2];
    int numDirections = param_.bidirectional ? 2 : 1;
-   int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+   // double for bidirectional
+   int total_layers = numDirections * param_.num_layers;

    SHAPE_ASSIGN_CHECK(*in_shape,
                       rnn_enum::kState,
-                      Shape3(total_layers, batch_size, param_.state_size));
+                      mshadow::Shape3(total_layers, batch_size,
+                                      param_.state_size));
    if (param_.mode == rnn_enum::kLstm)
      SHAPE_ASSIGN_CHECK(*in_shape,
                         rnn_enum::kStateCell,
-                        Shape3(total_layers, batch_size, param_.state_size));
+                        mshadow::Shape3(total_layers, batch_size,
+                                        param_.state_size));

    // calculate parameter vector length
    int param_size = rnn_param_size(param_.num_layers,

@@ -287,8 +507,9 @@ class RNNProp : public OperatorProperty {
                  const std::vector<int> &out_grad,
                  const std::vector<int> &in_data,
                  const std::vector<int> &out_data) const override {
-   std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
-       in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+   std::vector<int> dep = {in_data[rnn_enum::kData],
+       in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
+       out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};

Review comment: I'm not sure why you want to change the code in this function. It seems you just reorganized the code a little bit.
Reply: It exceeds the 80-characters-per-line limit.
Review comment: The coding style in mxnet allows up to 100 characters per line.

    if (param_.state_outputs) {
      dep.push_back(out_data[rnn_enum::kStateOut]);

@@ -67,6 +67,22 @@ def test_lstm_forget_bias():
                                   forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
    assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)


def test_lstm_cpu_inference():
    # should behave the same as lstm cell
    atol = 1e-6
    x = mx.nd.ones(shape=(2, 2, 2))
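    # x has layout (seq_len=2, batch=2, input_size=2); with 2 hidden units and
    # bidirectional=True, the LSTM output checked below has shape (2, 2, 4).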
    model = mx.gluon.nn.Sequential()
    with model.name_scope():
        model.add(mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True))
    model.initialize(mx.init.One())

Review comment: Could you also test the consistency between cpu and gpu, with the same random weights and random inputs?
Reply: Will it break the CPU tests? It might be too much of an effort.
Review comment: You can use this function to test consistency.
Review comment: Under the current input and weights, your test would still pass even if the weights were iterated backwards. Unfortunately it's not in an acceptable state.

    y = model(x).asnumpy()

    mx.test_utils.assert_almost_equal(y, np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],

Review comment: Why are there hardcoded numbers?
Reply: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/test_utils.py#L1207 is for symbols, it seems. Or does it support both?

                                                    [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
                                                   [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
                                                    [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]),
                                      rtol=1e-3, atol=1e-5)


def test_gru():
    cell = gluon.rnn.GRUCell(100, prefix='rnn_')

Review comment: You are just reformatting the code here?
Reply: Yes.