This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
[MXNET-497] fix bugs in MKLDNN operators to handle the kAddTo request #11129
Merged · +140 −61 (changes shown from 54 of 69 commits)

Commits (all by azai91):
31fdc8b  fix lint
b644e02  requests added to opattr
612f64f  comment out addto
2c4b41e  can invalidate kAddTo request mkldarrays
2dc646a  revert adding kAddTo to invalidate
5278ef9  use copy of output instead of creating new array
3adcd8d  convert output to default if fallback
2489b86  do not make copy when init
c7e64f3  copyex fallback copies to old array with kAddTo
67001ce  change input mem desc to output mem desc if not equal
5a75f53  reorder memory in commitoutput
f5b63fc  allocate temp memory
4d52987  fix var names
6b62e97  create helper reorder function to handle diff format/shapes
9da3655  fix typos
c0c38ca  fix typos
2338046  remove unused code
f974c3c  fix param
918a864  fix header files
50fc6ca  force input memory to output
a9915be  reorder2default keeps pointer to mkldnn memory
630c091  pass reference
aa6c406  remove extra lines
75c5160  do not get raw mem from ptr
f65ea9c  remove isView check
3483f28  fallback writes back to output
0428e0f  remove redundant line
1cdd60c  remove commented out code
c9e8f85  use fallback in copy (refactor)
996d0ef  remove unused header
4532209  fix lint
410c491  reorder2default only if mkldnn flag
2efdc3b  only reorder if mkldnn
dc3cd8d  does not assume 1 output
ad66611  sum compares input and output shape
860fa21  compare address and pd in sum
a727eea  refactor mkldnnsum
c76aee3  fix const param
64422aa  fix header
ac2b3a1  Merge branch 'master' into test-kAddTo
bb10946  improve control flow when setting output blob
ad31578  fix merge
0e03c96  remove kaddto comment
6ef7b87  add reqests to operators
90c9acb  fix spacing
7d0f275  do sum in place
3edf492  fix conditionals
5c20e46  remove redundant reqs
cd70dac  use wait to read all
0972ffa  fix lint
637c76a  create multiple outputs
5718651  create multiple copies for kaddto
d91df93  retrigger
993c7aa  retriggrer
e7d18be  merge
e2a464d  retrigger
dc742c8  retrigger
92c50f0  another retrigger
eb97f3d  Merge branch 'master' into test-kAddTo
113903a  retrigger
ecbde64  retrigger
be84769  another another retrigger
5181420  Merge branch 'master' into test-kAddTo
0731a58  merge
ad3c70e  fix merge
2874d0a  retrigger
0e249f7  merge
581495f  add kAddto to relu op
9e7c22e  retrigger
@@ -98,10 +98,21 @@ inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
     is_default = nd.IsDefaultData();
 #endif
     if (!is_default) {
-      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
-                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      NDArray temp;
+      if (bufs != nullptr) {
+        temp = bufs->at(i);
+      } else if (kAddTo == req->at(i) && nd.IsMKLDNNData()) {
+        temp = nd.Reorder2Default();
+      } else if (kAddTo == req->at(i)) {
+        temp = nd;
+      } else {
+        temp = NDArray(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
+      CHECK(temp.IsDefaultData());
+#else
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#endif
       temp_src->emplace_back(nd);
       temp_dst->emplace_back(temp);

Review comment on the #else branch: indent.
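For context on why the kAddTo branches above reuse the output array itself (reordered to the default layout when it is MKLDNN-formatted) instead of freshly allocated scratch memory: kAddTo asks an operator to accumulate into the existing output, so a temporary that does not start from the output's current contents would silently drop the accumulated values. Below is a minimal standalone sketch of the request semantics; it is illustrative, not MXNet's actual dispatch code, and ApplyReq is a hypothetical name.

#include <cassert>

// Mirrors MXNet's OpReqType: how an operator must write its output.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Apply one computed element to the output slot according to the request.
void ApplyReq(OpReqType req, float computed, float *out) {
  switch (req) {
    case kNullOp: break;                         // output ignored entirely
    case kWriteTo:                               // plain overwrite
    case kWriteInplace: *out = computed; break;  // overwrite; output aliases an input
    case kAddTo: *out += computed; break;        // accumulate into existing contents
  }
}

int main() {
  float out = 1.0f;
  ApplyReq(kAddTo, 2.5f, &out);
  assert(out == 3.5f);  // the old value survives: kAddTo must read the output
  return 0;
}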
@@ -22,6 +22,8 @@
 #include <atomic>
 #include "./mkldnn_base-inl.h"
 #include "./mkldnn_ops-inl.h"
+#include "../../../common/exec_utils.h"
+
 
 namespace mxnet {
@@ -77,6 +79,75 @@ mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
   }
 }
 
+void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
+  MKLDNNStream *stream = MKLDNNStream::Get();
+
+  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
+  mkldnn::memory::desc from_desc = from_pd.desc();
+  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+  mkldnn::memory::desc this_desc = this_pd.desc();
+  mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
+  // It's possible that the memory and the NDArray don't have the same shape.
+  if (!same_shape(this_desc, from_desc)
+      // If the source memory uses the default layout, we can reshape directly.
+      && from_def_format == from_desc.data.format) {
+    // In this case, we can simply create a new MKLDNN memory for the required
+    // shape.
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (!same_shape(this_desc, from_desc)) {
+    // In this case, the source memory stores data in a customized layout. We
+    // need to reorganize the data in memory before we can reshape.
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
+    mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    // Now we can reshape it
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (from_pd == this_pd) {
+    // If the layout is the same, we can just copy data.
+    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+  } else {
+    // If both are not using the default layouts. There isn't much we can do,
+    // other than reorder data layout directly.
+    if (this_def_format != this_desc.data.format
+        && from_def_format != from_desc.data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    } else if (this_def_format == this_desc.data.format) {
+      // If the dest mem uses the default memory layout, we can simply use
+      // the default format of the source memory to improve perf of reorder.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
+                                                           from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+    } else {
+      // If the src mem uses the default memory layout, we can use
+      // the default format of the source memory to improve perf.
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
+                                                           this_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+    }
+  }
+}
+
 bool CanWriteTo(const NDArray &out_arr,
                 const NDArray &in_arr,
                 const mkldnn::memory::primitive_desc &desc) {
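The helper above dispatches on two questions: do the source and destination shapes match, and which side already uses the default (plain) layout? A dependency-free sketch of just that case analysis follows; CopyStrategy and ChooseStrategy are hypothetical names for illustration, not MXNet or MKL-DNN API.

#include <cstdio>

// Illustrative only: the four copy strategies MKLDNNCopy chooses between.
enum class CopyStrategy {
  ReshapeThenReorder,  // shapes differ, source already in the default layout
  ReorderThenReshape,  // shapes differ, source in a custom (blocked) layout
  DirectCopy,          // identical primitive descriptors: plain reorder-as-copy
  LayoutReorder        // same shape, different layouts: reorder via a view
};

// Mirrors the branch order of MKLDNNCopy above.
CopyStrategy ChooseStrategy(bool same_shape, bool src_is_default, bool same_pd) {
  if (!same_shape && src_is_default) return CopyStrategy::ReshapeThenReorder;
  if (!same_shape) return CopyStrategy::ReorderThenReshape;
  if (same_pd) return CopyStrategy::DirectCopy;
  return CopyStrategy::LayoutReorder;
}

int main() {
  // Same shape, same primitive descriptor: a direct copy suffices.
  bool same_shape = true, src_is_default = true, same_pd = true;
  std::printf("strategy = %d\n",
              static_cast<int>(ChooseStrategy(same_shape, src_is_default, same_pd)));
  return 0;
}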
@@ -141,13 +212,16 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
   if (res.first == CopyBack) {
     const_cast<NDArray &>(arr).CopyFrom(*res.second);
   } else if (res.first == AddBack) {
+    auto res_memory = res.second;
+    auto target_pd = arr.GetMKLDNNData()->get_primitive_desc();
     auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc());
-    CHECK(mem != nullptr);
-    // We have to allocate new memory for the sum result.
-    auto sum_res = TmpMemMgr::Get()->Alloc(
-        res.second->get_primitive_desc());
-    op::MKLDNNSum(*res.second, *mem, *sum_res);
-    const_cast<NDArray &>(arr).CopyFrom(*sum_res);
+    if (mem == nullptr) {
+      auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
+      MKLDNNCopy(*res_memory, tmp_memory);
+      res_memory = tmp_memory;
+      mem = arr.GetMKLDNNData();
+    }
+    op::MKLDNNSum(*mem, *res_memory, *mem);
   }
 }

Review comment: As my understanding, in line 224 we use res_memory and add it with mem.
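The rewritten AddBack path drops the old copy-back: previously the code summed into a scratch buffer (sum_res) and copied that into the array, while now it reorders the partial result into the output's layout only when the primitive descriptors differ and then sums in place, with the output acting as both a summand and the destination of MKLDNNSum. A toy sketch of that accumulation pattern using plain vectors rather than mkldnn::memory:

#include <cassert>
#include <vector>

// Toy stand-in for op::MKLDNNSum(*mem, *res_memory, *mem): the output is both
// a summand and the destination, so no scratch result buffer is needed.
void AddBack(std::vector<float> *out, const std::vector<float> &res) {
  assert(out->size() == res.size());
  for (size_t i = 0; i < out->size(); ++i)
    (*out)[i] += res[i];  // accumulate in place
}

int main() {
  std::vector<float> out = {1.f, 2.f, 3.f};     // existing output contents
  std::vector<float> res = {10.f, 10.f, 10.f};  // operator's partial result
  AddBack(&out, res);
  assert(out[0] == 11.f && out[1] == 12.f && out[2] == 13.f);
  return 0;
}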
@@ -317,18 +391,28 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
   MKLDNNStream::Get()->Submit();
 
   std::vector<TBlob> out_blobs(outputs.size());
+  std::vector<NDArray> temp_src, temp_dst;
   for (size_t i = 0; i < out_blobs.size(); i++) {
     NDArray output = outputs[i];
     // ensure output does not use mkldnn mem.
     // for inplace, we already converted & copied input above.
-    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace))
+    if ((req[i] == kWriteTo) || (req[i] == kWriteInplace)) {
       const_cast<NDArray &>(output).InvalidateMKLDNNData();
-    else if (req[i] == kAddTo)
-      output = outputs[i].Reorder2Default();
+    } else if (req[i] == kAddTo && output.IsMKLDNNData()) {
+      NDArray temp = outputs[i].Reorder2Default();
+      temp_src.emplace_back(temp);
+      temp_dst.emplace_back(outputs[i]);
+      output = temp;
+    }
     CHECK(output.IsDefaultData());
     out_blobs[i] = output.data();
   }
 
   fn(attrs, ctx, in_blobs, req, out_blobs);
+  for (size_t i = 0; i < out_blobs.size(); i++) {
+    if (req[i] == kAddTo && outputs[i].IsMKLDNNData())
+      mxnet::common::CastNonDefaultStorage(temp_src, temp_dst, ctx, false);
+  }
 }
 
 template<typename DType>
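For a kAddTo output held in an MKL-DNN layout, the fallback now round-trips: Reorder2Default produces a default-layout copy that still holds the output's current values, the generic FCompute accumulates into that copy, and CastNonDefaultStorage writes the copy back into the MKLDNN array. A simplified sketch of that round-trip, with ToDefaultLayout and WriteBack as hypothetical stand-ins for those two calls:

#include <cassert>
#include <vector>

// Hypothetical stand-ins for Reorder2Default and CastNonDefaultStorage: with
// plain vectors the "layout conversion" is just a copy.
std::vector<float> ToDefaultLayout(const std::vector<float> &opaque) { return opaque; }
void WriteBack(const std::vector<float> &plain, std::vector<float> *opaque) { *opaque = plain; }

// Fallback for a kAddTo output held in an opaque layout: convert to a plain
// copy, let the generic kernel accumulate into it, then write the copy back.
void FallbackAddTo(std::vector<float> *output, const std::vector<float> &result) {
  std::vector<float> temp = ToDefaultLayout(*output);  // the temp_src analogue
  for (size_t i = 0; i < temp.size(); ++i)
    temp[i] += result[i];                              // fn(...) honoring kAddTo
  WriteBack(temp, output);                             // the temp_dst write-back
}

int main() {
  std::vector<float> output = {1.f, 1.f};
  FallbackAddTo(&output, {4.f, 5.f});
  assert(output[0] == 5.f && output[1] == 6.f);
  return 0;
}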
Review comment: Sparse arrays don't have kAddTo? @eric-haibin-lin

Reply: Yes. The executor won't generate kAddTo for sparse outputs; sparse operators don't support it.