乐于分享
好东西不私藏

TensorRT-LLM 0.5.0 源码之十

TensorRT-LLM 0.5.0 源码之十

linear.py

def _gemm_plugin(input: Tensor,
                 mat2: Tensor,
                 transa: bool = False,
                 transb: bool = False,
                 use_fp8: bool = False
) -> Tensor:
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Gemm'
, '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert
 plg_creator is not None

    transa = 1 if transa else 0
    transa = trt.PluginField("transa", np.array(transa, dtype=np.int32),
                             trt.PluginFieldType.INT32)
    transb = 1 if transb else 0
    transb = trt.PluginField("transb", np.array(transb, dtype=np.int32),
                             trt.PluginFieldType.INT32)
    use_fp8 = 1 if use_fp8 else 0
    use_fp8 = trt.PluginField("use_fp8", np.array(use_fp8, dtype=np.int32),
                              trt.PluginFieldType.INT32)

    p_dtype = default_net().plugin_config.gemm_plugin
    pf_type = trt.PluginField(
        "type_id"
, np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([transa, transb, pf_type, use_fp8])
    gemm_plug = plg_creator.create_plugin("gemm", pfc)
    plug_inputs = [input.trt_tensor, mat2.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
    return
 _create_tensor(layer.get_output(0), layer)
在这里插入图片描述

模型模块并行方案
Linear层作为切分主要的网络层,其核心是MatMul矩阵计算,因此矩阵切分计算也是模型并行最重要的一部分。

基础矩阵乘模块

matmul1
在这里插入图片描述

在大模型计算中,矩阵乘(MatMul)不管是在权重还是计算量上都占了相当大的比例。观察矩阵乘,其拥有列可切分性(Column-wise Parallelism)和行可切分性(Row-wise Parallelism)。

Column-wise Parallelism

在这里插入图片描述

Row-wise Parallelism

在这里插入图片描述

ColumnLinaer

class Linear(Module):

    def
 __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 dtype=None,
                 tp_group=None,
                 tp_size=1,
                 gather_output=True,
                 share_weight=None
):
        super
().__init__()
        self
.in_features = in_features
        self
.out_features = out_features // tp_size
        self
.dtype = dtype

        if
 not share_weight:
            self
.weight = Parameter(shape=(self.out_features, self.in_features),
                                    dtype=dtype)
        else
:
            self
.weight = share_weight

        self
.tp_size = tp_size
        self
.tp_group = tp_group
        self
.gather_output = gather_output

        if
 bias:
            self
.bias = Parameter(shape=(self.out_features, ), dtype=dtype)
        else
:
            self
.register_parameter('bias', None)

    def
 multiply_gather(self, x, weight, gemm_plugin, use_fp8=False):
        if
 gemm_plugin:
            x = _gemm_plugin(x, weight, transb=True, use_fp8=use_fp8)
        else
:
            x = matmul(x, weight, transb=True)

        if
 self.bias is not None:
            if
 x.dtype != self.bias.value.dtype:
                x = cast(x, self.bias.value.dtype)
            x = x + self.bias.value

        if
 self.gather_output and self.tp_size > 1 and self.tp_group is not None:
            # 1. [dim0, local_dim] -> [dim0 * tp_size, local_dim]

            x = allgather(x, self.tp_group)

            # 2. [dim0 * tp_size, local_dim] -> [dim0, local_dim * tp_size]

            # 2.1 split

            split_size = shape(x, dim=0) / self.tp_size
            ndim = x.ndim()
            starts = [constant(int32_array([0])) for _ in range(ndim)]
            sizes = [shape(x, dim=d) for d in range(ndim)]
            sizes[0] = split_size
            sections = []
            for
 i in range(self.tp_size):
                starts[0] = split_size * i
                sections.append(slice(x, concat(starts), concat(sizes)))
            # 2.2 concat

            x = concat(sections, dim=1)

        return
 x

    def
 forward(self, x):
        return
 self.multiply_gather(x, self.weight.value,
                                    default_net().plugin_config.gemm_plugin)


ColumnLinear = Linear

RowLinear

class RowLinear(Module):

    def
 __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 dtype=None,
                 tp_group=None,
                 tp_size=1,
                 instance_id: int = 0
):
        super
().__init__()
        self
.in_features = in_features // tp_size
        self
.out_features = out_features
        self
.dtype = dtype

        self
.weight = Parameter(shape=(self.out_features, self.in_features),
                                dtype=dtype)

        if
 bias:
            self
.bias = Parameter(shape=(self.out_features, ), dtype=dtype)
        else
:
            self
.register_parameter('bias', None)

        self
.tp_group = tp_group
        self
.tp_size = tp_size
        self
.instance_id = instance_id

    def
 multiply_reduce(self,
                        x,
                        weight,
                        gemm_plugin,
                        use_fp8=False,
                        workspace=None
):
        if
 gemm_plugin:
            x = _gemm_plugin(x, weight, transb=True, use_fp8=use_fp8)
        else
:
            x = matmul(x, weight, transb=True)

        if
 self.tp_size > 1 and self.tp_group is not None:
            x = allreduce(x, self.tp_group, workspace, self.instance_id)

        if
 self.bias is not None:
            if
 x.dtype != self.bias.value.dtype:
                x = cast(x, self.bias.value.dtype)

            x = x + self.bias.value

        return
 x

    def
 forward(self, x, workspace=None):
        return
 self.multiply_reduce(x,
                                    self
.weight.value,
                                    default_net().plugin_config.gemm_plugin,
                                    workspace=workspace)

参考文献

  • • https://github.com/NVIDIA/TensorRT-LLM/blob/v0.5.0/tensorrt_llm/layers/linear.py
  • • https://www.mindspore.cn/tutorials/zh-CN/r2.7.1/model_infer/ms_infer/ms_infer_parallel_infer.html
点个「赞」+「在看」❤️
让我们知道这份文字有温暖到你,也是我们持续创作的最大动力!
推荐
Lock-Free 队列实现原理
Share Memory 的 Bank Conflict
告别高成本!TensorRT-LLM实战:如何将LLM推理速度提升数倍
使用LoRA对LLM进行微调的实用技巧
强化学习小白必看:PTX Loss 到底是个啥?
GPT-5 Prompt Migration and Improvement Using the New Optimizer
Task 异步流 coroutine 实现
C++ corotine 介绍
搭建 VSCode 离线开发环境
nlohmann/json 库简介
Intro to C++ Coroutines: Concept
Hugging Face BPE Tokenizer 的资源文件
移动语义 std::move 和完美转发 std::forward
ACEBench: Who Wins the Match Point in Tool Usage?
什么是 GN
RULER: Relative Universal LLM-Elicited Rewards
SFT和RFT的区别
CosyVoice 3: 面向真实场景的大规模零样本语音生成模型
CosyVoice 3: Towards In-the-wild Speech Generation
语音合成(TTS)中文自然度:问题、成因、解决方案
上下文工程如何实现
上下文工程(Context Engineering)
新手必看!LangGraph 101:手把手教你搭一个深度研究 Agent
LangGraph 简介
SFT 泛化新解读:强化学习 + 奖励修正,一文读懂
程序员狂喜!Self-Instruct 框架全解析:无限生成高质量指令集,从此告别标注噩梦!
Evol-Instruct 竟能精准生成领域专属数据?实操技巧速看!
指令微调数据-少即是多
LLM generate 参数怎么用?
语音合成(TTS)跳跃与重复问题的解析:成因、机制及解决方案
大模型训练新思路:GEPA 靠 “反思” 赢过 RL,看完秒懂
F5-TTS:用 Flow Matching 玩转语音,流畅度和真实感都 “拉满” 了
E2 TTS:令人尴尬地简单、完全非自回归、零样本的语音合成技术
Voicebox:大规模文本引导的多语言通用语音生成技术
为什么都在聊 Kimi K2?Open Agentic Intelligence 藏着哪些新惊喜
Step-Audio-AQAA 端到端音频模型
DPO、PPO、GRPO的原理,区别与联系
OPENCSG 中文语料库:一系列高质量的中文数据集,用于语言模型训练
什么是 Classifier-Free Guidance?
Conditional Flow Matching : 连续标准流 Continuous Normalizing Flow
CFM 与 OT-CFM:条件流匹配与最优传输的碰撞
DPO损失实现
Conditional Flow Matching : 常微分方程ODE、欧拉方法和Neural ODE
当 Normalizing flow 遇上语音生成:AI 说话变 “真人” 的秘密在这里!
深度剖析:Kimi – Audio 中 BigVGAN 的神奇作用
为什么说分布变换是 Normalizing flow 的「灵魂操作」?
MATCHA-TTS 来了!条件流匹配让文本转语音效率飙升
从知识增长的角度提升RAG上下文的质量
MiniMax-Speech,零样本语音合成新突破,32 种语言轻松拿捏!
手把手教你创建 evol-instruct 数据集!附完整流程~
社交类聊天的 Query 分析与应答策略
SFT 中指令选择和响应选择哪个更重要?
角色扮演大模型技术分享2-超拟人模型的困境
最新!SpeechLLM 综述:架构、能力、挑战与未来全揭秘
如何低成本生成高质量指令微调数据?
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
Kimi-Audio:开源音频基础模型全面解析
Kimi-Audio 的 TTS 效果如何?
Qwen 的训练数据是怎么做的?
GeForce RTX 3090, 4090, A10, A40, A100, A800, L20, L40 显卡性能对比
如何低成本生成高质量指令微调数据?
掌握RAG:投入生产前要评估的8个场景
掌握RAG:如何评估RAG的LLM
掌握RAG:如何在部署后观察您的RAG
掌握RAG:如何选择嵌入模型
基础模型中的新范式:为什么o1是不同的,以及它将如何改变LLM应用
Semantic token和连续特征在SLLM下的对比
从数量到质量:通过自引导数据选择来提升语言模型性能以实现指令调优
RLHF及其变体:进展和实际工程见解
Freeze-Omni: 低延迟语音对话模型
Fully Sharded Data Parallelism (FSDP)
什么是置信度?置信度模型怎么做?
晦涩难懂的 Flow matching!图形化理解
中文指令微调数据,质量就是一切!
基于 LLM 的文本泛化
CosyVoice 2:基于大型语言模型的可扩展流式语音合成技术
Mini-Omni2: with Vision, Speech and Duplex Capabilities
FSQ的原理与VQ-VAE的区别和联系
大模型并行训练的一些知识——极简版
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
要用 AI 裁员 50% 的千亿独角兽,公开认错,重启招聘!
single codebook和dual codebook在LLM中向量量化上有什么区别?
一些文档去重算法
最佳的指令数据应当是什么样的?
Prefill-Decode分离
亲测有效!如何用 Address Sanitizer 精准定位内存漏洞?附保姆级操作指南
Simhash-文档去重算法简介
RLHF 入门,高手勿进!
最佳的指令数据应当是什么样的?
CosyVoice:一种基于监督式语义标记的可扩展多语言 Zero-Shot 语音合成器
Model Context Protocol (MCP)
MCP(模型上下文协议)是什么以及它是如何运作的
压力测试LLMs——大海捞针实现