乐于分享
好东西不私藏

FlashInfer 源码级解读:大模型推理的"底层引擎"是怎么炼成的

FlashInfer 源码级解读:大模型推理的"底层引擎"是怎么炼成的

如果你关注过大模型推理加速,大概率听过 FlashAttention。

但比 FlashAttention 更底层的,是一个叫 FlashInfer 的开源项目。GitHub 5,400+ 星,今天依然在持续更新。

大多数讨论推理加速的文章都在说模型层面的优化:量化、蒸馏、MoE。很少有人聊到最底层的——GPU kernel 级别。

今天这篇,我想带你看看 FlashInfer 到底在做什么,以及为什么它对大模型推理生态如此重要。

Kernel 是什么

在深入之前,先解决一个概念问题。

GPU 上跑的每一段计算程序叫 kernel。大模型的注意力计算(Attention)就是由一系列 GPU kernel 组合完成的。

你平时用的 PyTorch、TensorFlow,它们背后调用的就是这些 kernel。但默认的 kernel 往往不是最优的——因为它们是通用实现,要兼容各种硬件和场景。

FlashInfer 做的事是:为特定的大模型推理场景,手写最优 GPU kernel。

FlashAttention 的遗产

要理解 FlashInfer,得先说 FlashAttention。

2022 年斯坦福的三篇论文提出了 FlashAttention,核心思路是:Attention 计算中的中间矩阵太大,显存放不下,导致频繁读写 HBM(GPU 高带宽内存),性能被内存带宽卡住了。

FlashAttention 的解法叫 IO-aware:把 Attention 分块计算,每块的数据刚好塞进 SRAM(GPU 片上缓存),算完再写回。

这个思路直接把 Attention 的内存访问开销砍了一半以上,推理速度大幅提升。

但 FlashAttention 有个问题:它只管前向传播的 Attention,而且主要面向训练场景。 推理阶段的很多特殊需求它没覆盖到。

FlashInfer 补上了什么

FlashInfer 团队看到了这个缺口。

大模型推理有几个跟训练完全不同的特点:

第一,自回归生成。 推理时是一个 token 一个 token 生成的,每一步的序列长度在增长。这意味着 KV Cache 的读取模式跟训练时完全不一样。

第二,Batch 大小不固定。 推理服务的并发量随时变化,从 1 到几百都有可能。Kernel 需要适应不同 batch size。

第三,PageAttention。 vLLM 引入的 PageAttention 把 KV Cache 分页管理,kernel 需要支持非连续的内存访问。

FlashInfer 针对这些场景,重写了整套推理用 Attention kernel。

它的核心模块包括:

单请求推理 kernel。专门为单条请求的自回归生成优化,延迟压到最低。这是线上服务最常用的场景。

Batch 推理 kernel。支持动态 batch size,用 continuous batching 策略管理并发请求。

PageAttention kernel。跟 vLLM 的分页管理配合,支持非连续 KV Cache 的高效读取。

Speculative decoding kernel。为投机采样加速提供底层 kernel 支持。这个后面会细说。

一个值得关注的细节:算子融合

FlashInfer 有一个很多人忽略的优化方向:算子融合

大模型的每一层,Attention 之后跟着一堆操作:RoPE、LayerNorm、激活函数。如果每个操作单独启动一个 kernel,kernel launch 的开销会累积起来。

FlashInfer 的做法是把几个小算子融合成一个大 kernel。比如 RoPE + Attention 融合,LayerNorm + 激活函数融合。

效果呢?根据他们的 benchmark,在高并发场景下,算子融合能再省 10-15% 的延迟。

这个数据听起来不大,但在推理延迟已经很低的场景下,10% 的优化就是”有感”和”无感”的差别。

跟 SGLang、vLLM 的关系

你可能会问:FlashInfer 跟 SGLang、vLLM 是什么关系?

简单来说:FlashInfer 是底层算子库,SGLang 和 vLLM 是上层推理框架。

vLLM 已经内置了 FlashInfer 作为可选的 kernel 后端。SGLang 的部分优化也参考了 FlashInfer 的思路。

它们不是竞争关系,而是上下游。FlashInfer 提供了更高效的底层算子,上层框架调用它来提升推理速度。

这种分工我觉得是健康的。底层的人专注把 kernel 写到极致,上层的人专注做好调度和用户体验。

开发者怎么用

如果你的项目在用 vLLM,FlashInfer 已经可以作为 backend 选项。在启动参数里指定就行。

如果你在做自己的推理服务,可以直接用 FlashInfer 的 Python API 调用它的 kernel。文档写得比较详细,入门门槛不高。

不过也要说实话:如果你只是想”跑通一个大模型”,不需要关心 kernel 级别的优化。FlashInfer 更适合那些对延迟和吞吐量有极致要求的场景。

我的看法

大模型推理优化有一条清晰的层次:

  • 模型层:量化、蒸馏、MoE
  • 框架层:vLLM、SGLang 的调度和 batching
  • 算子层:FlashAttention、FlashInfer 的 kernel 优化

每一层都有自己的价值。但我觉得最被低估的是算子层——因为大多数人看不到它的存在,但它默默支撑着上面的一切。

FlashInfer 团队在做的事情,有点像给赛车换发动机。你坐在车里感受不到发动机的变化,但速度确实快了。

在大模型推理这个赛道,kernel 级别还有没有优化空间?我觉得还有。毕竟 GPU 架构在演进,模型结构在变化,kernel 的优化永远没有终点。

下次有人跟你说”大模型推理已经优化到头了”,你可以告诉他:去看看 FlashInfer 的 commit log。