TensorRT 插件使用
TensorRT 插件(Plugin) 的核心作用:把 TensorRT 原生不支持的算子,或者你想自己优化的算子,接进 TensorRT 引擎里。
比如这些场景最常见:
-
ONNX 里有 TensorRT 不认识的 op -
你有自己写的 CUDA kernel -
想把 NMS / Decode / GridSample / DeformableConv / ROIAlign / BEV 特殊算子做成高性能层 -
想减少 Host 侧后处理,把后处理塞进 engine
一、先建立正确认知:TensorRT 插件到底是什么
TensorRT 正常工作流程是:
ONNX / Network Definition↓TensorRT 解析网络↓每个 layer 变成 TensorRT 内部 layer↓Builder 生成 engine↓Runtime 执行
但如果某个层 TensorRT 不认识,比如:
MyDecodeLayer它就没法构图。
这时候就需要:Plugin = TensorRT 的“用户自定义层实现”
它要负责两件事:
1)构建期(Build Time)
告诉 TensorRT:
-
这个层叫什么 -
输入输出 shape 是什么 -
支持哪些 dtype / format -
需要多少 workspace -
参数怎么保存
2)运行期(Runtime)
真正执行你的计算:
-
拿输入指针 -
拿输出指针 -
拿 stream -
调你的 CUDA kernel
二、最常用的插件接口选型(你先别写错)
TensorRT 插件接口历史很多版本,最常见你会看到这些名字:
IPluginV2IPluginV2ExtIPluginV2DynamicExtIPluginV3...
(新版本更复杂)
如果你现在是工程开发,最实用的入口通常是:
nvinfer1::IPluginV2DynamicExt
因为它支持:
-
动态 shape -
FP32 / FP16 / INT8 扩展 -
batch 不固定 -
ONNX 导入更常见
IPluginV2DynamicExt
三、插件整体结构(脑子里先有框架)
一个完整 TensorRT 插件通常有 4 个核心部分:
1)Plugin 类
真正表示这个自定义层
比如:
class MyPlugin : public nvinfer1::IPluginV2DynamicExt2)PluginCreator 类
负责:
-
注册插件 -
解析插件参数 -
反序列化 plugin
class MyPluginCreator : public nvinfer1::IPluginCreator
3)插件注册
让 TensorRT 能找到它
REGISTER_TENSORRT_PLUGIN(MyPluginCreator);
4)CUDA / CPU 实现
真正的算子逻辑
比如:
my_kernel<<<...>>>(...)你可以把它理解成:
PluginCreator = 工厂 / 注册中心Plugin = 算子对象enqueue() = 真正执行入口serialize() = 保存参数deserialize = 读回参数四、实例1️⃣ Plugin 头文件:
AddScalarPlugin.h#pragma once#include"NvInfer.h"#include<vector>#include<string>#include<cassert>class AddScalarPlugin : public nvinfer1::IPluginV2DynamicExt {public:AddScalarPlugin(float value) : scalar_(value) {}// 从序列化数据恢复AddScalarPlugin(const void* data, size_t length) {assert(length == sizeof(float));scalar_ = *reinterpret_cast<const float*>(data);}// ----------------- IPluginV2DynamicExt 必须实现 -----------------nvinfer1::IPluginV2DynamicExt* clone()constnoexceptoverride{return new AddScalarPlugin(scalar_);}intgetNbOutputs()constnoexceptoverride{ return 1; }nvinfer1::DimsExprs getOutputDimensions(int outputIndex,const nvinfer1::DimsExprs* inputs,int nbInputs,nvinfer1::IExprBuilder& exprBuilder) noexcept override {return inputs[0]; // 输出和输入同形状}boolsupportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs)noexceptoverride{return inOut[pos].type == nvinfer1::DataType::kFLOAT &&inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;}voidconfigurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override {}size_tgetWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override {return 0;}intenqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)constnoexceptoverride{return inputTypes[0];}constchar* getPluginType()constnoexceptoverride{ return "AddScalarPlugin"; }constchar* getPluginVersion()constnoexceptoverride{ return "1"; }intinitialize()noexceptoverride{ return 0; }voidterminate()noexceptoverride{}size_tgetSerializationSize()constnoexceptoverride{ return sizeof(float); }voidserialize(void* buffer)constnoexceptoverride{ *reinterpret_cast<float*>(buffer) = scalar_; }voiddestroy()noexceptoverride{ delete this; }voidsetPluginNamespace(constchar* pluginNamespace)noexceptoverride{ namespace_ = pluginNamespace; }constchar* getPluginNamespace()constnoexceptoverride{ return namespace_.c_str(); }private:float scalar_;std::string namespace_;};2️⃣ Plugin 实现文件:
AddScalarPlugin.cu
#include"AddScalarPlugin.h"#include"NvInfer.h"class AddScalarPluginCreator : public nvinfer1::IPluginCreator {public:AddScalarPluginCreator() {mFC.nbFields = 0;mFC.fields = nullptr;}constchar* getPluginName()constnoexceptoverride{ return "AddScalarPlugin"; }constchar* getPluginVersion()constnoexceptoverride{ return "1"; }const nvinfer1::PluginFieldCollection* getFieldNames()noexceptoverride{ return &mFC; }nvinfer1::IPluginV2* createPlugin(constchar* name, const nvinfer1::PluginFieldCollection* fc)noexceptoverride{return new AddScalarPlugin(1.0f); // 默认加 1}nvinfer1::IPluginV2* deserializePlugin(constchar* name, constvoid* serialData, size_t serialLength)noexceptoverride{return new AddScalarPlugin(serialData, serialLength);}voidsetPluginNamespace(constchar* libNamespace)noexceptoverride{ mNamespace = libNamespace; }constchar* getPluginNamespace()constnoexceptoverride{ return mNamespace.c_str(); }private:nvinfer1::PluginFieldCollection mFC{};std::string mNamespace;};✅ 这个插件可以直接在 TensorRT 网络中使用:
auto plugin = AddScalarPlugin(2.5f); // 每个元素加 2.5nvinfer1::ITensor* input = network->getInput(0);auto layer = network->addPluginV2(&input, 1, plugin);
夜雨聆风