FlashInfer 源码级解读:大模型推理的"底层引擎"是怎么炼成的
如果你关注过大模型推理加速,大概率听过 FlashAttention。
但比 FlashAttention 更底层的,是一个叫 FlashInfer 的开源项目。GitHub 5,400+ 星,今天依然在持续更新。
大多数讨论推理加速的文章都在说模型层面的优化:量化、蒸馏、MoE。很少有人聊到最底层的——GPU kernel 级别。
今天这篇,我想带你看看 FlashInfer 到底在做什么,以及为什么它对大模型推理生态如此重要。
Kernel 是什么
在深入之前,先解决一个概念问题。
GPU 上跑的每一段计算程序叫 kernel。大模型的注意力计算(Attention)就是由一系列 GPU kernel 组合完成的。
你平时用的 PyTorch、TensorFlow,它们背后调用的就是这些 kernel。但默认的 kernel 往往不是最优的——因为它们是通用实现,要兼容各种硬件和场景。
FlashInfer 做的事是:为特定的大模型推理场景,手写最优 GPU kernel。
FlashAttention 的遗产
要理解 FlashInfer,得先说 FlashAttention。
2022 年斯坦福的三篇论文提出了 FlashAttention,核心思路是:Attention 计算中的中间矩阵太大,显存放不下,导致频繁读写 HBM(GPU 高带宽内存),性能被内存带宽卡住了。
FlashAttention 的解法叫 IO-aware:把 Attention 分块计算,每块的数据刚好塞进 SRAM(GPU 片上缓存),算完再写回。
这个思路直接把 Attention 的内存访问开销砍了一半以上,推理速度大幅提升。
但 FlashAttention 有个问题:它只管前向传播的 Attention,而且主要面向训练场景。 推理阶段的很多特殊需求它没覆盖到。
FlashInfer 补上了什么
FlashInfer 团队看到了这个缺口。
大模型推理有几个跟训练完全不同的特点:
第一,自回归生成。 推理时是一个 token 一个 token 生成的,每一步的序列长度在增长。这意味着 KV Cache 的读取模式跟训练时完全不一样。
第二,Batch 大小不固定。 推理服务的并发量随时变化,从 1 到几百都有可能。Kernel 需要适应不同 batch size。
第三,PageAttention。 vLLM 引入的 PageAttention 把 KV Cache 分页管理,kernel 需要支持非连续的内存访问。
FlashInfer 针对这些场景,重写了整套推理用 Attention kernel。
它的核心模块包括:
单请求推理 kernel。专门为单条请求的自回归生成优化,延迟压到最低。这是线上服务最常用的场景。
Batch 推理 kernel。支持动态 batch size,用 continuous batching 策略管理并发请求。
PageAttention kernel。跟 vLLM 的分页管理配合,支持非连续 KV Cache 的高效读取。
Speculative decoding kernel。为投机采样加速提供底层 kernel 支持。这个后面会细说。
一个值得关注的细节:算子融合
FlashInfer 有一个很多人忽略的优化方向:算子融合。
大模型的每一层,Attention 之后跟着一堆操作:RoPE、LayerNorm、激活函数。如果每个操作单独启动一个 kernel,kernel launch 的开销会累积起来。
FlashInfer 的做法是把几个小算子融合成一个大 kernel。比如 RoPE + Attention 融合,LayerNorm + 激活函数融合。
效果呢?根据他们的 benchmark,在高并发场景下,算子融合能再省 10-15% 的延迟。
这个数据听起来不大,但在推理延迟已经很低的场景下,10% 的优化就是”有感”和”无感”的差别。
跟 SGLang、vLLM 的关系
你可能会问:FlashInfer 跟 SGLang、vLLM 是什么关系?
简单来说:FlashInfer 是底层算子库,SGLang 和 vLLM 是上层推理框架。
vLLM 已经内置了 FlashInfer 作为可选的 kernel 后端。SGLang 的部分优化也参考了 FlashInfer 的思路。
它们不是竞争关系,而是上下游。FlashInfer 提供了更高效的底层算子,上层框架调用它来提升推理速度。
这种分工我觉得是健康的。底层的人专注把 kernel 写到极致,上层的人专注做好调度和用户体验。
开发者怎么用
如果你的项目在用 vLLM,FlashInfer 已经可以作为 backend 选项。在启动参数里指定就行。
如果你在做自己的推理服务,可以直接用 FlashInfer 的 Python API 调用它的 kernel。文档写得比较详细,入门门槛不高。
不过也要说实话:如果你只是想”跑通一个大模型”,不需要关心 kernel 级别的优化。FlashInfer 更适合那些对延迟和吞吐量有极致要求的场景。
我的看法
大模型推理优化有一条清晰的层次:
-
模型层:量化、蒸馏、MoE -
框架层:vLLM、SGLang 的调度和 batching -
算子层:FlashAttention、FlashInfer 的 kernel 优化
每一层都有自己的价值。但我觉得最被低估的是算子层——因为大多数人看不到它的存在,但它默默支撑着上面的一切。
FlashInfer 团队在做的事情,有点像给赛车换发动机。你坐在车里感受不到发动机的变化,但速度确实快了。
在大模型推理这个赛道,kernel 级别还有没有优化空间?我觉得还有。毕竟 GPU 架构在演进,模型结构在变化,kernel 的优化永远没有终点。
下次有人跟你说”大模型推理已经优化到头了”,你可以告诉他:去看看 FlashInfer 的 commit log。
夜雨聆风