Skip to content

Commit

Permalink
[OpAttr]Refine Teller logic if encounter OpDesc with Variable type At…
Browse files Browse the repository at this point in the history
…tribute (#45795)

* [OpAttr]Refine Teller logic if encounter OpDesc with Variable type Attribute

* fix iterator

* fix typo

* fix lambda expr

* fix ptr
  • Loading branch information
Aurelius84 authored Sep 8, 2022
1 parent bd4ce23 commit a642365
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
6 changes: 5 additions & 1 deletion paddle/fluid/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,11 @@ Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
PADDLE_ENFORCE_EQ(
HasAttrVar(it->second),
false,
platform::errors::NotFound("Attribute %s is not found.", name));
platform::errors::NotFound(
"Attribute %s with constant value is not found, but found it with "
"Variable(s) type, which maybe not supported in some scenarios "
"currently, such as TensorRT et.al",
name));
}
return it->second;
}
Expand Down
32 changes: 32 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ bool OpTeller::Tell(const framework::ir::Node* node,
desc.HasAttr("skip_quant"))
return false;

// do not support Attribute with Variable(s) Type
if (HasUnsupportAttrVar(desc)) return false;

for (auto& teller : tellers_) {
std::unordered_set<std::string> act_op_list = {
"relu", "relu6", "sigmoid",
Expand Down Expand Up @@ -2261,6 +2264,35 @@ bool OpTeller::Tell(const framework::ir::Node* node,

return false;
}

bool OpTeller::HasUnsupportAttrVar(const framework::OpDesc& desc) const {
const std::string op_type = desc.Type();
auto has_attr_var = [&](const std::string& attr_name) -> bool {
// If Attribute is Variable(s), HasAttr() will return False
return !desc.HasAttr(attr_name, /*with_attr_var=*/false);
};
std::unordered_map<std::string, std::vector<std::string>> attrs_info = {
{"dropout", {"dropout_prob"}},
{"pool2d", {"ksize"}},
{"arg_max", {"axis"}},
{"reduce_mean", {"dim"}},
{"reduce_sum", {"dim"}},
{"squeeze2", {"axes"}},
};

bool flag = false;
auto iter = attrs_info.find(op_type);
if (iter != attrs_info.end()) {
for (auto& attr_name : iter->second) {
if (has_attr_var(attr_name)) {
flag = true;
break;
}
}
}
return flag;
}

OpTeller::OpTeller() { tellers_.emplace_back(new SimpleOpTypeSetTeller); }
} // namespace tensorrt
} // namespace inference
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/inference/tensorrt/op_teller.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class OpTeller {
private:
OpTeller();

/*
* Some OpDescs Attribute support both constant value and dynamic
* runtime value (which is a Variable(s) type). But TensorRT maybe
* only support constant value Attribute, so we shall distinguish
* this case in time and return False in OpTeller.Tell().
*/
bool HasUnsupportAttrVar(const framework::OpDesc& desc) const;

private:
std::vector<std::unique_ptr<Teller>> tellers_;
};
Expand Down

0 comments on commit a642365

Please sign in to comment.