[Paddle-TRT] IPluginExt -> IPluginV2 #33680

Merged Jul 12, 2021 · 58 commits
Changes from 42 commits

Commits
f40cee7
add trt LT version helper
zlsh80826 Jun 18, 2021
75d20da
upgrade PluginTensorRT to IPluginV2Ext
zlsh80826 Jun 20, 2021
5d63c78
trt plugin factory is not usable in IPluginV2
zlsh80826 Jun 20, 2021
f7e94a0
upgrade add plugin api to use IPluginV2
zlsh80826 Jun 20, 2021
ac6bd52
remove IPlugin register and adapt getSerializeSize(), serialize()
zlsh80826 Jun 20, 2021
b311c3f
adapt IPluginV2Layer
zlsh80826 Jun 20, 2021
02b4120
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 22, 2021
70e8a1e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
21c9c35
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
b290aa9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
99e1461
downgrade to IPluginV2
zlsh80826 Jun 24, 2021
e7e8d92
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 24, 2021
686b368
implement elementwise clone
zlsh80826 Jun 24, 2021
fffb0b5
add gelu plugin creator and fix gelu serialization bug
zlsh80826 Jun 25, 2021
b55ac3a
add swish plugin creator and fix swish serialization bug
zlsh80826 Jun 25, 2021
a852793
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 25, 2021
c5b08bb
format
zlsh80826 Jun 25, 2021
326e093
fix typo
zlsh80826 Jun 25, 2021
d434252
add elementwise plugin creator and fix serialization
zlsh80826 Jun 25, 2021
3d36ece
add base creator class
zlsh80826 Jun 25, 2021
2020e58
add gelu plugin creator
zlsh80826 Jun 25, 2021
84e3675
add hard swish creator and fix serialization
zlsh80826 Jun 25, 2021
20c91fa
add instance norm creator and fix serialization
zlsh80826 Jun 25, 2021
8fcd8fd
add layer norm creator and fix serialization
zlsh80826 Jun 25, 2021
571ca99
add pool creator and fix serialization
zlsh80826 Jun 25, 2021
f60ec9c
add prelu creator and fix serialization
zlsh80826 Jun 25, 2021
4c1ab09
add slice creator and fix serialization
zlsh80826 Jun 25, 2021
c7053fa
add swish creator and fix serialization
zlsh80826 Jun 25, 2021
b0df6b1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 25, 2021
3b2e493
add instance norm op unittest
zlsh80826 Jun 25, 2021
3c0a212
remove redundant api
zlsh80826 Jun 25, 2021
3d59b63
fix wrong graph size to enable trt
zlsh80826 Jun 25, 2021
045e906
instance norm function move to cc
zlsh80826 Jun 26, 2021
3ad1eb6
add trt elementwise ut to trigger coverage
zlsh80826 Jun 26, 2021
92849b9
remove opt cache to hit serialization coverage
zlsh80826 Jun 26, 2021
3723bbb
remove opt cache to hit serialization coverage
zlsh80826 Jun 26, 2021
5a6b329
remove unused code
zlsh80826 Jun 26, 2021
b57de25
remove unused inputs_
zlsh80826 Jun 26, 2021
4a03a9c
add dbg info
zlsh80826 Jun 26, 2021
1d5c960
remove dbg info
zlsh80826 Jun 26, 2021
d29a365
add instance norm serialization
zlsh80826 Jun 26, 2021
6ac45e2
roll back
zlsh80826 Jun 26, 2021
ce98f21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 28, 2021
4397920
remove comment code
zlsh80826 Jun 28, 2021
faa9bf6
remove trt plugin registry
zlsh80826 Jun 28, 2021
8efb2c5
fix prelu dynamic serialization
zlsh80826 Jun 28, 2021
55fb335
add prelu ut and reduce the input size to reduce memory usage
zlsh80826 Jun 28, 2021
a340884
fix pool dynamic plugin serialization and add ut
zlsh80826 Jun 28, 2021
6097f42
refine pool ut with subtest
zlsh80826 Jun 28, 2021
eafd8c7
add env for avoiding oom
zlsh80826 Jun 28, 2021
d9a6505
reduce test input size & increase pool op ut to 45s
zlsh80826 Jun 28, 2021
80cef1e
add the contributor
zlsh80826 Jun 29, 2021
1c0f1f7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jun 29, 2021
8a10d21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jul 8, 2021
33c79db
Merge branch 'trt-IPluginV2Ext' of github.com:zlsh80826/Paddle into t…
zlsh80826 Jul 8, 2021
ca6b8b3
remove copyright (will add in contributor)
zlsh80826 Jul 8, 2021
df49070
remove copyright (will add in contributor)
zlsh80826 Jul 9, 2021
95f4d3d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zlsh80826 Jul 9, 2021
Files changed
9 changes: 5 additions & 4 deletions paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
@@ -251,10 +251,11 @@ class ElementwiseTensorOpConverter : public OpConverter {
} else {
plugin::ElementWisePlugin* plugin =
new plugin::ElementWisePlugin(op_type_, dims_x, dims_y, axis);
plugin->AddInput(X);
plugin->AddInput(Y);
nvinfer1::IPluginLayer* plugin_layer = engine_->AddPlugin(
plugin->GetInputs().data(), 2,
std::vector<nvinfer1::ITensor*> inputs{X, Y};
// plugin->AddInput(X);
// plugin->AddInput(Y);
auto* plugin_layer = engine_->AddPlugin(
inputs.data(), inputs.size(),
reinterpret_cast<plugin::PluginTensorRT*>(plugin));

layer = plugin_layer;
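
The commented-out AddInput calls survive at this commit and are dropped later by the "remove comment code" commit. The deeper reason for the change: an IPluginV2 plugin must be self-contained, because the build path actually exercises clone(), so the plugin can no longer cache ITensor pointers the way the IPluginExt version did. A minimal sketch of the clone contract, using this PR's ElementWisePlugin fields (the plugin header diff below replaces the old nullptr stub with exactly this):

// Sketch: the clone is rebuilt purely from the plugin's own fields,
// never from cached network tensors.
ElementWisePlugin* clone() const override {
  return new ElementWisePlugin(type_, dims_x_, dims_y_, axis_);
}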
paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
@@ -74,7 +74,7 @@ class InstanceNormOpConverter : public OpConverter {
plugin::InstanceNormPlugin* plugin =
new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
plugin->getPluginType();
nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
auto* layer = engine_->AddPlugin(&input, 1, plugin);

auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
paddle/fluid/inference/tensorrt/convert/shuffle_channel_op.cc
@@ -61,7 +61,8 @@ class ShuffleChannelOpConverter : public OpConverter {
reshape_layer->setReshapeDimensions(reshape_dim2);

auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(reshape_layer, "concat", {output_name}, test_mode);
RreplenishLayerAndOutput(reshape_layer, "shuffle_channel", {output_name},
test_mode);
}
};

4 changes: 2 additions & 2 deletions paddle/fluid/inference/tensorrt/engine.cc
@@ -330,11 +330,11 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,

int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }

nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
nvinfer1::ITensor *const *inputs, int num_inputs,
plugin::PluginTensorRT *plugin) {
owned_plugin_.emplace_back(plugin);
return network()->addPluginExt(inputs, num_inputs, *plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);
}

nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
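
For callers the migration is mostly transparent: IPluginLayer and IPluginV2Layer both derive from nvinfer1::ILayer, so converter code that touches only the ILayer surface keeps compiling, which is why several call sites above simply became auto*. A hedged usage sketch, where X, Y, and plugin stand in for a converter's locals:

// Only the common ILayer interface is used after AddPlugin returns.
std::vector<nvinfer1::ITensor*> inputs{X, Y};
auto* layer = engine_->AddPlugin(inputs.data(),
                                 static_cast<int>(inputs.size()), plugin);
nvinfer1::ITensor* out = layer->getOutput(0);  // unchanged ILayer call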
10 changes: 2 additions & 8 deletions paddle/fluid/inference/tensorrt/engine.h
@@ -280,14 +280,8 @@ class TensorRTEngine {
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
} else {
#if IS_TRT_VERSION_LT(8000)
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size(),
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
#else
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
#endif
}

PADDLE_ENFORCE_NOT_NULL(
@@ -311,8 +305,8 @@ class TensorRTEngine {

int GetDeviceId() { return device_id_; }

nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
nvinfer1::IPluginV2Layer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);

nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
int num_inputs,
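
The deleted #if block in engine.h is the payoff of the migration: an IPluginV2 engine resolves plugins through creators in TensorRT's global registry, so deserializeCudaEngine no longer needs Paddle's PluginFactoryTensorRT and the two-argument overload works on every supported TRT version. The IS_TRT_VERSION_LT helper added by the first commit is not shown in this diff; by analogy with Paddle's existing IS_TRT_VERSION_GE macro it plausibly reads:

// Assumed definition, mirroring IS_TRT_VERSION_GE; check the version
// helper header in this PR for the authoritative one.
#define IS_TRT_VERSION_LT(version)                        \
  ((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 +  \
    NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) < version)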
paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu
@@ -21,12 +21,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {

ElementWisePlugin *CreateElementWisePluginDeserialize(const void *buffer,
size_t length) {
return new ElementWisePlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("elementwise_plugin", CreateElementWisePluginDeserialize);

namespace details {
template <typename T>
struct Add {
47 changes: 35 additions & 12 deletions paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h
@@ -40,14 +40,16 @@ class ElementWisePlugin : public PluginTensorRT {
const char* elementwise_type;
DeserializeValue(&serial_data, &serial_length, &elementwise_type);
type_ = std::string(elementwise_type);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &dims_x_);
DeserializeValue(&serial_data, &serial_length, &dims_y_);
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &prev_size_);
DeserializeValue(&serial_data, &serial_length, &midd_size_);
DeserializeValue(&serial_data, &serial_length, &post_size_);
}

ElementWisePlugin* clone() const override {
// return new ElementWisePlugin(dims_x_, dims_y_, axis_);
return nullptr;
return new ElementWisePlugin(type_, dims_x_, dims_y_, axis_);
}

const char* getPluginType() const override { return "elementwise_plugin"; }
@@ -65,22 +67,25 @@ class ElementWisePlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream);

protected:
size_t getSerializationSize() override {
return SerializedSize(getPluginType()) + SerializedSize(axis_) +
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(type_.c_str()) +
SerializedSize(dims_x_) + SerializedSize(dims_y_) +
getBaseSerializationSize();
SerializedSize(axis_) + SerializedSize(prev_size_) +
SerializedSize(midd_size_) + SerializedSize(post_size_);
}

void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, type_.c_str());
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, dims_x_);
SerializeValue(&buffer, dims_y_);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, prev_size_);
SerializeValue(&buffer, midd_size_);
SerializeValue(&buffer, post_size_);
}

protected:
std::string type_;
nvinfer1::Dims dims_x_;
nvinfer1::Dims dims_y_;
@@ -90,6 +95,20 @@ class ElementWisePlugin : public PluginTensorRT {
int post_size_;
};

class ElementWisePluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "elementwise_plugin"; }

const char* getPluginVersion() const override { return "1"; }

nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new ElementWisePlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(ElementWisePluginCreator);
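
ElementWisePluginCreator is the first of the new creators; gelu, swish, hard swish, instance/layer norm, pool, prelu, and slice follow the same three-override pattern. The shared TensorRTPluginCreator base is defined elsewhere in this PR; judging by the boilerplate deleted from GeluPluginDynamicCreator further down, it plausibly centralizes the rest of the IPluginCreator surface roughly like this:

// Hedged sketch of the shared base; subclasses override only
// getPluginName/getPluginVersion/deserializePlugin.
class TensorRTPluginCreator : public nvinfer1::IPluginCreator {
 public:
  const nvinfer1::PluginFieldCollection* getFieldNames() override {
    return &field_collection_;
  }
  // Paddle builds plugins inside op converters rather than from
  // PluginFields, so createPlugin can be a stub.
  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
    return nullptr;
  }
  void setPluginNamespace(const char* lib_namespace) override {
    plugin_namespace_ = lib_namespace;
  }
  const char* getPluginNamespace() const override {
    return plugin_namespace_.c_str();
  }

 protected:
  std::string plugin_namespace_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};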

#if IS_TRT_VERSION_GE(6000)
class ElementwisePluginDynamic : public DynamicPluginTensorRT {
public:
@@ -105,7 +124,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
return new ElementwisePluginDynamic(type_, axis_);
}

const char* getPluginType() const override { return "elementwise_plugin"; }
const char* getPluginType() const override {
return "elementwise_plugin_dynamic";
}
int getNbOutputs() const override { return 1; }
int initialize() override;

@@ -150,7 +171,9 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
class ElementwisePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
ElementwisePluginDynamicCreator() {}
const char* getPluginName() const override { return "elementwise_plugin"; }
const char* getPluginName() const override {
return "elementwise_plugin_dynamic";
}

const char* getPluginVersion() const override { return "1"; }
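
A recurring bug class fixed across these commits (the "fix ... serialization" messages for gelu, swish, elementwise, hard swish, instance/layer norm, pool, prelu, slice) is a field-order mismatch between writer and reader: serialize(), getSerializationSize(), and the deserializing constructor must all walk the same fields in the same order. The invariant, trimmed to three ElementWisePlugin fields from the hunks above:

// Writer: base fields first, then type, dims_x, axis.
void serialize(void* buffer) const override {
  serializeBase(buffer);                    // base fields first
  SerializeValue(&buffer, type_.c_str());   // 1
  SerializeValue(&buffer, dims_x_);         // 2
  SerializeValue(&buffer, axis_);           // 3
}

// Reader: identical order; any reorder silently corrupts later fields.
ElementWisePlugin(const void* serial_data, size_t serial_length) {
  // (base fields are consumed before this point in the real ctor)
  const char* elementwise_type;
  DeserializeValue(&serial_data, &serial_length, &elementwise_type);  // 1
  type_ = std::string(elementwise_type);
  DeserializeValue(&serial_data, &serial_length, &dims_x_);           // 2
  DeserializeValue(&serial_data, &serial_length, &axis_);             // 3
}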

6 changes: 0 additions & 6 deletions paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu
@@ -31,12 +31,6 @@ static const float kAT = 0.5;
static const float kBT = 0.7978845608028654; // sqrt(2.0/M_PI)
static const float kCT = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)

GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) {
return new GeluPlugin(buffer, length);
}

REGISTER_TRT_PLUGIN("gelu_plugin", CreateGeluPluginDeserialize);

bool GeluPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
if (with_fp16_) {
53 changes: 19 additions & 34 deletions paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h
@@ -51,18 +51,28 @@ class GeluPlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;

protected:
size_t getSerializationSize() override {
return getBaseSerializationSize() + SerializedSize(getPluginType());
size_t getSerializationSize() const override {
return getBaseSerializationSize();
}

// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
serializeBase(buffer);
void serialize(void* buffer) const override { serializeBase(buffer); }
};

class GeluPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "gelu_plugin"; }

const char* getPluginVersion() const override { return "1"; }

nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new GeluPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(GeluPluginCreator);

#if IS_TRT_VERSION_GE(6000)
class GeluPluginDynamic : public DynamicPluginTensorRT {
@@ -77,7 +87,7 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
return new GeluPluginDynamic(with_fp16_);
}

const char* getPluginType() const override { return "gelu_plugin"; }
const char* getPluginType() const override { return "gelu_plugin_dynamic"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }

@@ -119,44 +129,19 @@ class GeluPluginDynamic : public DynamicPluginTensorRT {
void destroy() override { delete this; }
};

class GeluPluginDynamicCreator : public nvinfer1::IPluginCreator {
class GeluPluginDynamicCreator : public TensorRTPluginCreator {
public:
GeluPluginDynamicCreator() {}
const char* getPluginName() const override { return "gelu_plugin"; }
const char* getPluginName() const override { return "gelu_plugin_dynamic"; }

const char* getPluginVersion() const override { return "1"; }

const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}

nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}

nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new GeluPluginDynamic(serial_data, serial_length);
return plugin;
}

void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}

const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}

private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(GeluPluginDynamicCreator);
#endif
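
The renames to gelu_plugin_dynamic (and elementwise_plugin_dynamic above) matter because REGISTER_TRT_PLUGIN_V2 places every creator in TensorRT's global registry, keyed by (name, version); the static and dynamic variants previously collided under one key. A sketch of the lookup TensorRT performs when deserializing an engine, with serial_data/serial_length standing for the plugin blob inside the engine file:

#include <NvInfer.h>

// Resolve a creator by name + version from the global registry, then
// rebuild the plugin from its serialized bytes.
nvinfer1::IPluginV2* ResolveGeluDynamic(const void* serial_data,
                                        size_t serial_length) {
  auto* creator =
      getPluginRegistry()->getPluginCreator("gelu_plugin_dynamic", "1");
  return creator->deserializePlugin("gelu_plugin_dynamic", serial_data,
                                    serial_length);
}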

paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu
@@ -22,13 +22,6 @@ namespace inference {
namespace tensorrt {
namespace plugin {

HardSwishPlugin* CreateHardSwishPluginDeserialize(const void* buffer,
size_t length) {
return new HardSwishPlugin(buffer, length);
}

REGISTER_TRT_PLUGIN("hard_swish_plugin", CreateHardSwishPluginDeserialize);

nvinfer1::Dims HardSwishPlugin::getOutputDimensions(
int index, const nvinfer1::Dims* in_dims, int nb_inputs) {
assert(nb_inputs == 1);
32 changes: 22 additions & 10 deletions paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h
@@ -56,27 +56,39 @@ class HardSwishPlugin : public PluginTensorRT {
#endif
void* workspace, cudaStream_t stream) override;

protected:
float threshold_;
float scale_;
float offset_;

size_t getSerializationSize() override {
size_t getSerializationSize() const override {
return getBaseSerializationSize() + SerializedSize(threshold_) +
SerializedSize(scale_) + SerializedSize(offset_) +
SerializedSize(getPluginType());
SerializedSize(scale_) + SerializedSize(offset_);
}

// TRT will call this func to serialize the configuration of TRT
// It should not be called by users.
void serialize(void* buffer) override {
SerializeValue(&buffer, getPluginType());
void serialize(void* buffer) const override {
serializeBase(buffer);
SerializeValue(&buffer, threshold_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, offset_);
}

protected:
float threshold_;
float scale_;
float offset_;
};

class HardSwishPluginCreator : public TensorRTPluginCreator {
public:
const char* getPluginName() const override { return "hard_swish_plugin"; }

const char* getPluginVersion() const override { return "1"; }

nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new HardSwishPlugin(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator);

} // namespace plugin
} // namespace tensorrt
paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
@@ -40,13 +40,6 @@ cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
return CUDNN_STATUS_SUCCESS;
}

InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
size_t length) {
return new InstanceNormPlugin(buffer, length);
}
REGISTER_TRT_PLUGIN("instance_norm_plugin",
CreateInstanceNormPluginDeserialize);

int InstanceNormPlugin::initialize() { return 0; }

nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
@@ -58,6 +51,13 @@ nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
return output_dims;
}

bool InstanceNormPlugin::supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const {
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF) &&
(format == nvinfer1::PluginFormat::kNCHW));
}

int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000)
void **outputs, void *workspace,
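
The diff is truncated here, but the IS_TRT_VERSION_LT(8000) guard around enqueue reflects a real API break: TensorRT 8 changed the outputs parameter of IPluginV2::enqueue from void** to void* const*. The guarded declaration presumably continues:

// Hedged reconstruction of the version-guarded signature.
int enqueue(int batch_size, const void* const* inputs,
#if IS_TRT_VERSION_LT(8000)
            void** outputs, void* workspace,
#else
            void* const* outputs, void* workspace,
#endif
            cudaStream_t stream) override;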