乐于分享
好东西不私藏

TensorRT 插件使用

TensorRT 插件使用

TensorRT 插件(Plugin) 的核心作用:把 TensorRT 原生不支持的算子,或者你想自己优化的算子,接进 TensorRT 引擎里。

比如这些场景最常见:

  • ONNX 里有 TensorRT 不认识的 op
  • 你有自己写的 CUDA kernel
  • 想把 NMS / Decode / GridSample / DeformableConv / ROIAlign / BEV 特殊算子 做成高性能层
  • 想减少 Host 侧后处理,把后处理塞进 engine
这次不只讲“怎么写”,而是把工程里真正会遇到的 插件生命周期、接口、序列化、enqueue、动态 shape、注册、部署排查 一次讲透。看完之后基本就能自己写一个 TensorRT 插件。

一、先建立正确认知:TensorRT 插件到底是什么

TensorRT 正常工作流程是:

ONNX / Network Definition        ↓TensorRT 解析网络        ↓每个 layer 变成 TensorRT 内部 layer        ↓Builder 生成 engine        ↓Runtime 执行

但如果某个层 TensorRT 不认识,比如:

MyDecodeLayer

它就没法构图。

这时候就需要:Plugin = TensorRT 的“用户自定义层实现”

它要负责两件事:

1)构建期(Build Time)

告诉 TensorRT:

  • 这个层叫什么
  • 输入输出 shape 是什么
  • 支持哪些 dtype / format
  • 需要多少 workspace
  • 参数怎么保存

2)运行期(Runtime)

真正执行你的计算:

  • 拿输入指针
  • 拿输出指针
  • 拿 stream
  • 调你的 CUDA kernel
所以,Plugin 不是“普通 C++ 类”,它本质是 TensorRT 的“自定义算子协议实现”。

二、最常用的插件接口选型(你先别写错)

TensorRT 插件接口历史很多版本,最常见你会看到这些名字:

  • IPluginV2
  • IPluginV2Ext
  • IPluginV2DynamicExt
  • IPluginV3...
    (新版本更复杂)

如果你现在是工程开发,最实用的入口通常是:

nvinfer1::IPluginV2DynamicExt

因为它支持:

  • 动态 shape
  • FP32 / FP16 / INT8 扩展
  • batch 不固定
  • ONNX 导入更常见
现在自己写插件,优先学这个:
IPluginV2DynamicExt

三、插件整体结构(脑子里先有框架)

一个完整 TensorRT 插件通常有 4 个核心部分

1)Plugin 类

真正表示这个自定义层

比如:

class MyPlugin : public nvinfer1::IPluginV2DynamicExt

2)PluginCreator 类

负责:

  • 注册插件
  • 解析插件参数
  • 反序列化 plugin
class MyPluginCreator : public nvinfer1::IPluginCreator

3)插件注册

让 TensorRT 能找到它

REGISTER_TENSORRT_PLUGIN(MyPluginCreator);

4)CUDA / CPU 实现

真正的算子逻辑

比如:

my_kernel<<<...>>>(...)

你可以把它理解成:

PluginCreator   = 工厂 / 注册中心Plugin          = 算子对象enqueue()       = 真正执行入口serialize()     = 保存参数deserialize     = 读回参数
四、实例

1️⃣ Plugin 头文件:AddScalarPlugin.h

#pragma once#include"NvInfer.h"#include<vector>#include<string>#include<cassert>class AddScalarPlugin : public nvinfer1::IPluginV2DynamicExt {public:    AddScalarPlugin(float value) : scalar_(value) {}    // 从序列化数据恢复    AddScalarPlugin(const void* data, size_t length) {        assert(length == sizeof(float));        scalar_ = *reinterpret_cast<const float*>(data);    }    // ----------------- IPluginV2DynamicExt 必须实现 -----------------    nvinfer1::IPluginV2DynamicExt* clone()constnoexceptoverride{        return new AddScalarPlugin(scalar_);    }    intgetNbOutputs()constnoexceptoverridereturn 1; }    nvinfer1::DimsExprs getOutputDimensions(int outputIndex,                                             const nvinfer1::DimsExprs* inputs,                                             int nbInputs,                                             nvinfer1::IExprBuilder& exprBuilder) noexcept override {        return inputs[0]; // 输出和输入同形状    }    boolsupportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs)noexceptoverride{        return inOut[pos].type == nvinfer1::DataType::kFLOAT &&               inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;    }    voidconfigurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,                          const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override {}    size_tgetWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,                             const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override {         return 0    }    intenqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,                const voidconst* inputs, voidconst* outputs, void* workspace, cudaStream_t stream) noexcept override;    nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)constnoexceptoverride{        return inputTypes[0];    }    constchargetPluginType()constnoexceptoverridereturn "AddScalarPlugin"; }    constchargetPluginVersion()constnoexceptoverridereturn "1"; }    intinitialize()noexceptoverridereturn 0; }    voidterminate()noexceptoverride{}    size_tgetSerializationSize()constnoexceptoverridereturn sizeof(float); }    voidserialize(void* buffer)constnoexceptoverride{ *reinterpret_cast<float*>(buffer) = scalar_; }    voiddestroy()noexceptoverridedelete this; }    voidsetPluginNamespace(constchar* pluginNamespace)noexceptoverride{ namespace_ = pluginNamespace; }    constchargetPluginNamespace()constnoexceptoverridereturn namespace_.c_str(); }private:    float scalar_;    std::string namespace_;};

2️⃣ Plugin 实现文件:AddScalarPlugin.cu

#include"AddScalarPlugin.h"#include"NvInfer.h"class AddScalarPluginCreator : public nvinfer1::IPluginCreator {public:    AddScalarPluginCreator() {        mFC.nbFields = 0;        mFC.fields = nullptr;    }    constchargetPluginName()constnoexceptoverridereturn "AddScalarPlugin"; }    constchargetPluginVersion()constnoexceptoverridereturn "1"; }    const nvinfer1::PluginFieldCollection* getFieldNames()noexceptoverridereturn &mFC; }    nvinfer1::IPluginV2* createPlugin(constchar* name, const nvinfer1::PluginFieldCollection* fc)noexceptoverride{        return new AddScalarPlugin(1.0f); // 默认加 1    }    nvinfer1::IPluginV2* deserializePlugin(constchar* name, constvoid* serialData, size_t serialLength)noexceptoverride{        return new AddScalarPlugin(serialData, serialLength);    }    voidsetPluginNamespace(constchar* libNamespace)noexceptoverride{ mNamespace = libNamespace; }    constchargetPluginNamespace()constnoexceptoverridereturn mNamespace.c_str(); }private:    nvinfer1::PluginFieldCollection mFC{};    std::string mNamespace;};

✅ 这个插件可以直接在 TensorRT 网络中使用:

auto plugin = AddScalarPlugin(2.5f);  // 每个元素加 2.5nvinfer1::ITensor* input = network->getInput(0);auto layer = network->addPluginV2(&input, 1, plugin);