Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
create multiple copies for kaddto
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Jun 26, 2018
1 parent 637c76a commit 8b909d5
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,14 @@ void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
const std::vector<NDArray*> &original_outputs,
const std::vector<NDArray*> &new_outputs,
VerifyFunc verify_fn) {
NDArray tmp = new_outputs[0]->Reorder2Default() - original_outputs[0]->Reorder2Default();
tmp.WaitToRead();
verify_fn(in_arrs, {&tmp});
CHECK(original_outputs.size() == new_outputs.size());
std::vector<NDArray*> tmp_outputs;
for (size_t i = 0; i < new_outputs.size(); i++) {
NDArray tmp = new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default();
tmp_outputs.push_back(&tmp);
}
Engine::Get()->WaitForAll();
verify_fn(in_arrs, tmp_outputs);
}

void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
Expand Down Expand Up @@ -796,6 +801,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
}

if (attrs.requests.find(OpReqType::kAddTo) != attrs.requests.end()) {
std::vector<NDArray*> original_outputs(attrs.num_outputs);
in_arrs = GetTestInputArrays();
for (auto in_arr : in_arrs) {
for (auto dispatch : dispatches) {
Expand All @@ -804,17 +810,18 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
for (size_t i = 0; i < attrs.num_inputs; i++)
inputs[i] = &in_arr.arr;
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
auto out_arr = out_arrs[0][output_i];
NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
for (size_t i = 0; i < attrs.num_outputs; i++) {
auto out_arr = out_arrs[i][output_i];
NDArray tmp = out_arr.arr.Copy(out_arr.arr.ctx());
original_outputs[i] = &tmp;
outputs[i] = &out_arrs[i][output_i].arr;
req[i] = kAddTo;
}
PrintVerifyMsg(in_arr, out_arr);
PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
VerifyAddRequest(inputs, {&orig_output}, outputs, verify_fn);
VerifyAddRequest(inputs, original_outputs, outputs, verify_fn);
}
}
}
Expand Down

0 comments on commit 8b909d5

Please sign in to comment.