This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-497] fix bugs in MKLDNN operators to handle the kAddTo request #11129
Merged
Merged
Changes from all commits
Commits
Show all changes
69 commits
Select commit
Hold shift + click to select a range
31fdc8b
fix lint
azai91 b644e02
requests added to opattr
azai91 612f64f
comment out addto
azai91 2c4b41e
can invalidate kAddTo request mkldarrays
azai91 2dc646a
revert adding kAddTo to invalidate
azai91 5278ef9
use copy of output instead of creating new array
azai91 3adcd8d
convert output to default if fallback
azai91 2489b86
do not make copy when init
azai91 c7e64f3
copyex fallback copies to old array with kAddTo
azai91 67001ce
change input mem desc to output mem desc if not equal
azai91 5a75f53
reorder memory in commitoutput
azai91 f5b63fc
allocate temp memory
azai91 4d52987
fix var names
azai91 6b62e97
create helper reorder function to handle diff format/shapes
azai91 9da3655
fix typos
azai91 c0c38ca
fix typos
azai91 2338046
remove unused code
azai91 f974c3c
fix param
azai91 918a864
fix header files
azai91 50fc6ca
force input memory to output
azai91 a9915be
reorder2default keeps pointer to mkldnn memory
azai91 630c091
pass reference
azai91 aa6c406
remove extra lines
azai91 75c5160
do not get raw mem from ptr
azai91 f65ea9c
remove isView check
azai91 3483f28
fallback writes back to output
azai91 0428e0f
remove redundant line
azai91 1cdd60c
remove commented out code
azai91 c9e8f85
use fallback in copy (refactor)
azai91 996d0ef
remove unused header
azai91 4532209
fix lint
azai91 410c491
reorder2default only if mkldnn flag
azai91 2efdc3b
only reorder if mkldnn
azai91 dc3cd8d
does not assume 1 output
azai91 ad66611
sum compares input and output shape
azai91 860fa21
compare address and pd in sum
azai91 a727eea
refactor mkldnnsum
azai91 c76aee3
fix const param
azai91 64422aa
fix header
azai91 ac2b3a1
Merge branch 'master' into test-kAddTo
azai91 bb10946
improve control flow when setting output blob
azai91 ad31578
fix merge
azai91 0e03c96
remove kaddto comment
azai91 6ef7b87
add reqests to operators
azai91 90c9acb
fix spacing
azai91 7d0f275
do sum in place
azai91 3edf492
fix conditionals
azai91 5c20e46
remove redundant reqs
azai91 cd70dac
use wait to read all
azai91 0972ffa
fix lint
azai91 637c76a
create multiple outputs
azai91 5718651
create multiple copies for kaddto
azai91 d91df93
retrigger
azai91 993c7aa
retriggrer
azai91 e7d18be
merge
azai91 e2a464d
retrigger
azai91 dc742c8
retrigger
azai91 92c50f0
another retrigger
azai91 eb97f3d
Merge branch 'master' into test-kAddTo
azai91 113903a
retrigger
azai91 ecbde64
retrigger
azai91 be84769
another another retrigger
azai91 5181420
Merge branch 'master' into test-kAddTo
azai91 0731a58
merge
azai91 ad3c70e
fix merge
azai91 2874d0a
retrigger
azai91 0e249f7
merge
azai91 581495f
add kAddto to relu op
azai91 9e7c22e
retrigger
azai91 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
|
||
#include <cmath> | ||
#include <climits> | ||
#include <set> | ||
#include "gtest/gtest.h" | ||
#include "mxnet/imperative.h" | ||
#include "../../src/operator/nn/mkldnn/mkldnn_base-inl.h" | ||
|
@@ -363,6 +364,7 @@ struct NDArrayAttrs { | |
struct OpAttrs { | ||
nnvm::NodeAttrs attrs; | ||
std::vector<DispatchMode> dispatches; | ||
std::set<OpReqType> requests; | ||
int num_inputs; | ||
int num_outputs; | ||
}; | ||
|
@@ -375,6 +377,9 @@ OpAttrs GetCopyOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -386,6 +391,9 @@ OpAttrs GetCopyBackwardsOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -399,6 +407,9 @@ OpAttrs GetReluOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -412,6 +423,9 @@ OpAttrs GetReluBackwardsOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why there isn't a kAdd test here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -423,6 +437,9 @@ OpAttrs GetSumOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -434,6 +451,9 @@ OpAttrs GetSumBackwardsOp() { | |
attrs.dispatches.resize(2); | ||
attrs.dispatches[0] = DispatchMode::kFCompute; | ||
attrs.dispatches[1] = DispatchMode::kFComputeEx; | ||
attrs.requests.insert(OpReqType::kWriteTo); | ||
attrs.requests.insert(OpReqType::kWriteInplace); | ||
attrs.requests.insert(OpReqType::kAddTo); | ||
return attrs; | ||
} | ||
|
||
|
@@ -821,6 +841,21 @@ void VerifyConcatResult(const std::vector<NDArray *> &in_arrs, | |
} | ||
} | ||
|
||
void VerifyAddRequest(const std::vector<NDArray*> &in_arrs, | ||
const std::vector<NDArray*> &original_outputs, | ||
const std::vector<NDArray*> &new_outputs, | ||
VerifyFunc verify_fn) { | ||
CHECK(original_outputs.size() == new_outputs.size()); | ||
std::vector<NDArray*> tmp_outputs; | ||
NDArray tmp; | ||
for (size_t i = 0; i < new_outputs.size(); i++) { | ||
tmp = new_outputs[i]->Reorder2Default() - original_outputs[i]->Reorder2Default(); | ||
tmp_outputs.push_back(&tmp); | ||
} | ||
Engine::Get()->WaitForAll(); | ||
verify_fn(in_arrs, tmp_outputs); | ||
} | ||
|
||
void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs, | ||
const std::vector<NDArray *> &out_arrs) { | ||
// in_arrs is larger array, out_arr is ammler | ||
|
@@ -846,15 +881,6 @@ void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs, | |
} | ||
} | ||
|
||
void VerifyAddRequest(const std::vector<NDArray*> &in_arrs, | ||
const std::vector<NDArray*> &original_outputs, | ||
const std::vector<NDArray*> &new_outputs, | ||
VerifyFunc verify_fn) { | ||
NDArray tmp = new_outputs[0]->Reorder2Default() - original_outputs[0]->Reorder2Default(); | ||
tmp.WaitToRead(); | ||
verify_fn(in_arrs, {&tmp}); | ||
} | ||
|
||
TEST(MKLDNN_NDArray, CopyFrom) { | ||
TestArrayShapes tas = GetTestArrayShapes(); | ||
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; | ||
|
@@ -879,54 +905,88 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) { | |
std::vector<NDArray*> inputs(attrs.num_inputs); | ||
std::vector<NDArray*> outputs(attrs.num_outputs); | ||
std::vector<OpReqType> req(attrs.num_outputs); | ||
std::vector<NDArrayAttrs> in_arrs; | ||
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs); | ||
std::vector<DispatchMode> dispatches = attrs.dispatches; | ||
|
||
TestArrayShapes tas = GetTestArrayShapes(); | ||
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds; | ||
|
||
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(); | ||
for (auto &in_arr : in_arrs) { | ||
if (attrs.requests.find(OpReqType::kWriteTo) != attrs.requests.end()) { | ||
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(); | ||
for (auto &in_arr : in_arrs) { | ||
for (auto &dispatch : dispatches) { | ||
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs); | ||
for (int i = 0; i < attrs.num_outputs; i++) | ||
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
inputs[i] = &in_arr.arr; | ||
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { | ||
for (int i = 0; i < attrs.num_outputs; i++) { | ||
req[i] = kWriteTo; | ||
outputs[i] = &out_arrs[i][output_i].arr; | ||
} | ||
PrintVerifyMsg(in_arr, out_arrs[0][output_i]); | ||
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, | ||
outputs, req, dispatch, mxnet::OpStatePtr()); | ||
Engine::Get()->WaitForAll(); | ||
verify_fn(inputs, outputs); | ||
} | ||
} | ||
} | ||
} | ||
|
||
if (attrs.requests.find(OpReqType::kWriteInplace) != attrs.requests.end()) { | ||
for (auto &dispatch : dispatches) { | ||
std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs); | ||
for (int i = 0; i < attrs.num_outputs; i++) | ||
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
inputs[i] = &in_arr.arr; | ||
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { | ||
in_arrs = GetTestInputArrays(); | ||
for (auto &arr : in_arrs) { | ||
// If the array is a view, we shouldn't write data to it. | ||
if (arr.arr.IsView()) | ||
continue; | ||
NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy"); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
inputs[i] = &arr.arr; | ||
for (int i = 0; i < attrs.num_outputs; i++) { | ||
req[i] = kWriteTo; | ||
outputs[i] = &out_arrs[i][output_i].arr; | ||
req[i] = kWriteInplace; | ||
outputs[i] = &arr.arr; | ||
} | ||
PrintVerifyMsg(in_arr, out_arrs[0][output_i]); | ||
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, | ||
outputs, req, dispatch, mxnet::OpStatePtr()); | ||
PrintVerifyMsg(orig, arr); | ||
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, | ||
dispatch, mxnet::OpStatePtr()); | ||
Engine::Get()->WaitForAll(); | ||
verify_fn(inputs, outputs); | ||
std::vector<NDArray *> orig_inputs(attrs.num_inputs); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
orig_inputs[i] = &orig.arr; | ||
verify_fn(orig_inputs, outputs); | ||
} | ||
} | ||
} | ||
|
||
for (auto &dispatch : dispatches) { | ||
if (attrs.requests.find(OpReqType::kAddTo) != attrs.requests.end()) { | ||
std::vector<NDArray*> original_outputs(attrs.num_outputs); | ||
in_arrs = GetTestInputArrays(); | ||
for (auto &arr : in_arrs) { | ||
// If the array is a view, we shouldn't write data to it. | ||
if (arr.arr.IsView()) | ||
continue; | ||
NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy"); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
inputs[i] = &arr.arr; | ||
for (int i = 0; i < attrs.num_outputs; i++) { | ||
req[i] = kWriteInplace; | ||
outputs[i] = &arr.arr; | ||
for (auto &in_arr : in_arrs) { | ||
for (auto &dispatch : dispatches) { | ||
for (int i = 0; i < attrs.num_outputs; i++) | ||
out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds); | ||
for (size_t i = 0; i < attrs.num_inputs; i++) | ||
inputs[i] = &in_arr.arr; | ||
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) { | ||
NDArray tmp; | ||
for (size_t i = 0; i < attrs.num_outputs; i++) { | ||
auto out_arr = out_arrs[i][output_i]; | ||
tmp = out_arr.arr.Copy(out_arr.arr.ctx()); | ||
original_outputs[i] = &tmp; | ||
outputs[i] = &out_arrs[i][output_i].arr; | ||
req[i] = kAddTo; | ||
} | ||
PrintVerifyMsg(in_arr, out_arrs[0][output_i]); | ||
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, | ||
outputs, req, dispatch, mxnet::OpStatePtr()); | ||
Engine::Get()->WaitForAll(); | ||
VerifyAddRequest(inputs, original_outputs, outputs, verify_fn); | ||
} | ||
} | ||
PrintVerifyMsg(orig, arr); | ||
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req, | ||
dispatch, mxnet::OpStatePtr()); | ||
Engine::Get()->WaitForAll(); | ||
std::vector<NDArray *> orig_inputs(attrs.num_inputs); | ||
for (int i = 0; i < attrs.num_inputs; i++) | ||
orig_inputs[i] = &orig.arr; | ||
verify_fn(orig_inputs, outputs); | ||
} | ||
} | ||
} | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indent.