通过对比的方式来说明triton的编程模式
接下来的例子是在cuda的gpu的场景下,计算两个长度都是10240的向量a和向量b,计算它们的点积。下面我管这个事情叫做“计算任务”。
为了方便讲解,需要对比 tvm实现(@T.prim_func) 、triton实现(@triton.jit) 和 cuda c++ 实现(核函数)这三种实现方式的不同。
通过对比才能更清晰triton的编程模式。我的代码附在最下方。所有的讲解都是照着下面的代码来讲解。
这三种方式中tvm的编程体验做好,在“@T.prim_func”函数中,使用python的语法串行的完成整个计算任务即可,完全采用最朴素的计算方式,
不需要考虑如何优化,优化的事情完全交给tvm即可。
而triton的@triton.jit函数只是是完成一个分块的计算,也就是说需要在更外层启动@triton.jit函数之前将整体的计算任务按照数据分成若干块,
然后每个@triton.jit函数负责其中某一个分块的数据的计算。通常来说是按照@triton.jit函数的输出张量,或者说是承载计算结果的张量来分块的。
这个输出张量如果是一维的,每个分块也是一维的,每个分块就是其中的一个局部的小段列表。
如果这个输出张量是二维的,那每个分块也是二维的,这个小分块的行数和列数可以是不同的,这个没关系,这个分块逻辑完全交给实现@triton.jit函数的程序员。
这里有个小问题,我们的计算任务中,计算的结果是点积,是个具体的数字,这个情况怎么分块呢。这种情况就按照输入向量来分块,分成若干小段,
@triton.jit函数就计算每个小段的点积,这个点积是局部的。在调用@triton.jit函数的外层需要再对这些@triton.jit函数的返回值进行一次求和。
最后再说cuda c++ 实现(核函数)这个场景,这是最原汁原味的编程方式,这里的一个核函数代表一个cuda的thread。所以需要在函数内部根据
当前的blockIdx.x和threadIdx.x等信息计算出当前thread的线程id,还要在外层调用核函数的时候,把每个线程负责总的计算任务中的哪个局部的计算定义好,
当前这个核函数只是负责其中一个具体的thread的计算任务,它计算的数据的粒度比@triton.jit函数就更小。
一个@triton.jit函数内的逻辑对应cuda中的一个线程block,即一个具体的blockIdx.x。通过函数内部的pid = tl.program_id(axis=0)得到具体的blockIdx.x的block的id值。
@triton.jit函数中,加载数据都用 a = tl.load(a_ptr + offsets, mask=mask) 类似这种方式,tl.load的返回值是个张量,
这个张量可以是一维的、二维的 或者 其他维度的,这完全取决于offsets是什么维度的,tl.load的返回值的张量形状和offsets的形状是一样的。
a = tl.load(a_ptr + offsets, mask=mask)后续的操作都是张量的操作,比如针对a+b是张量的操作,
在底层是多个cuda的thread在并行做计算,每个thread负责张量中的一个元素或者几个元素的加法计算。至于一个thread负责几个元素,是受到这些参数决定的:
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> 这个triton的ir在pass过程中的属性,
其中sizePerThread = [1] 表示每个线程只处理 1 个元素。若为 4,则每个线程处理连续的 4 个元素。
threadsPerWarp = [32]:一个线程束(warp)有32个线程。
warpsPerCTA = [4]:一个线程块(CTA)有4个 warp。其中CTA就是Cooperative Thread Array,即 CUDA 中的线程块(thread block),即就是blockIdx.x语境中的线程块block的意思。
order = [0]:维度遍历顺序,按行主序或列主序。
根据这些参数做计算,如果tl.load返回的张量中的元素数量超过一个block中的thread的数量,那就会一个thread负责多个元素的计算。
接下来再说一说reduce操作,这个在计算逻辑中必不可少,常见的就是sum、max、argmax等。在triton中,
reduce操作要调用tl.开头个各种原语,比如tl.sum、tl.max等,强烈不建议自己采用python的代码来实现,那样的话无法在后续triton的各种pass中,被优化成并行的高效的计算逻辑。
而对tvm,就没有这个限制,可以采取任意python的代码来朴素的实现各种reduce逻辑,就可以被tvm后续的各种调度规则识别和优化。
而针对cuda c++ 实现核函数中,要实现reduce的逻辑,就会比较有挑战,见下面例子中的代码。
再说一下矩阵乘法,即矩阵A乘以矩阵B得到矩阵C这个操作,如果在triton的编程模式中,需要合适的选取分块逻辑,通常来讲,就会针对结果矩阵C来分块。
这就需要按照线性代数中分块矩阵乘法(Block Matrix Multiplication)的数学原理来操作。
即假设A、B、C三个矩阵都分完块了,接下来把它们的每个块想象成一个元素,
C矩阵中的第m行,第n列这个块,就等于= A中的第m行所有的块 与 B中第n列中所有的块,按照对应的位置相乘,然后再求和。
如果把它们的每个块想象成普通矩阵中一个元素,那这个过程就和矩阵相乘的原理一样。
例子如下:
1. TVM 实现(@T.prim_func):(只需编写数学上正确的朴素串行逻辑)
import tvmfrom tvm import tirimport numpy as npVEC_LEN = 10240@tir.prim_funcdef vec_dot_tvm(a: tir.handle,b: tir.handle,dot_out: tir.handle, # 输出标量(TVM中用长度为1的张量表示)length: tir.int32) -> None:# 声明张量内存布局A = tir.match_buffer(a, (length,), "float32")B = tir.match_buffer(b, (length,), "float32")DotOut = tir.match_buffer(dot_out, (1,), "float32")# 最朴素的串行reduce逻辑:初始化累加器+遍历所有元素with tir.block([]): # 无迭代域的block表示标量计算tir.bind(tir.thread_axis("root"), 0) # 根线程acc = tir.allocate((1,), "float32", "local") # 局部累加器acc[0] = 0.0# 遍历所有元素,对位相乘后累加(完全串行的逻辑)with tir.block([length], "reduce_loop") as [i]:acc[0] = acc[0] + A[i] * B[i]# 将累加结果写入输出DotOut[0] = acc[0]# 编译为CUDA可执行模块(opt_level=3开启全量优化)target = tvm.target.Target("cuda")with tvm.transform.PassContext(opt_level=3):mod = tvm.build(vec_dot_tvm, target=target)# 测试验证dev = tvm.cuda(0)a_np = np.random.randn(VEC_LEN).astype(np.float32)b_np = np.random.randn(VEC_LEN).astype(np.float32)dot_np = np.array([np.dot(a_np, b_np)], dtype=np.float32) # 标准答案# 准备TVM张量a_tvm = tvm.nd.array(a_np, dev)b_tvm = tvm.nd.array(b_np, dev)dot_tvm = tvm.nd.array(np.zeros(1, dtype=np.float32), dev)# 执行TVM kernelmod(a_tvm, b_tvm, dot_tvm, VEC_LEN)# 验证结果assert np.allclose(dot_tvm.numpy(), dot_np, rtol=1e-4)print(f"TVM点积结果验证通过:{dot_tvm.numpy()[0]:.4f}")
2. Triton 实现(@triton.jit):
import tritonimport triton.language as tlimport torchVEC_LEN = 10240@triton.jitdef vec_dot_triton(a_ptr, # 输入a的设备指针b_ptr, # 输入b的设备指针partial_ptr, # 存储分块局部和的输出指针length: tl.constexpr, # 向量总长度(编译期常量)BLOCK_SIZE: tl.constexpr # 每个分块处理的元素数):# 1. 计算当前分块的ID(对应CUDA的blockIdx.x)pid = tl.program_id(axis=0)# 2. 计算当前分块的元素偏移start = pid * BLOCK_SIZEoffsets = start + tl.arange(0, BLOCK_SIZE)# 3. 边界掩码(防止越界访问)mask = offsets < length# 4. 加载分块内的a、b元素(tl.load为每个元素绑定一个lane/thread)a = tl.load(a_ptr + offsets, mask=mask)b = tl.load(b_ptr + offsets, mask=mask)# 5. 分块内reduce:对位相乘后累加(必须用tl.sum,否则无法高效并行)partial_sum = tl.sum(a * b, axis=0)# 6. 存储当前分块的局部和(每个分块输出一个标量)tl.store(partial_ptr + pid, partial_sum)# 封装启动逻辑(外层汇总分块局部和)def vec_dot_triton_wrapper(a: torch.Tensor, b: torch.Tensor) -> float:assert a.is_cuda and b.is_cuda and a.dtype == torch.float32assert a.shape[0] == b.shape[0] == VEC_LEN# 定义分块大小(需为32的倍数,适配CUDA warp)BLOCK_SIZE = 256# 计算分块数(向上取整)num_blocks = (VEC_LEN + BLOCK_SIZE - 1) // BLOCK_SIZE# 分配存储分块局部和的张量(num_blocks个元素)partial_sums = torch.empty(num_blocks, device=a.device, dtype=torch.float32)# 启动Triton kernel(grid维度为(num_blocks,),对应CUDA的block数)vec_dot_triton[(num_blocks,)](a, b, partial_sums, VEC_LEN, BLOCK_SIZE)# 外层汇总所有分块的局部和,得到最终标量点积total_dot = partial_sums.sum().item()return total_dot# 测试验证a_torch = torch.randn(VEC_LEN, device="cuda", dtype=torch.float32)b_torch = torch.randn(VEC_LEN, device="cuda", dtype=torch.float32)triton_dot = vec_dot_triton_wrapper(a_torch, b_torch)torch_dot = torch.dot(a_torch, b_torch).item()assert abs(triton_dot - torch_dot) < 1e-4print(f"Triton点积结果验证通过:{triton_dot:.4f}")
3. CUDA C++ 实现核函数:
// vec_dot_cuda.cu#include<cuda_runtime.h>#include<stdio.h>#include<stdlib.h>#include<math.h>const int VEC_LEN = 10240;// CUDA核函数:计算分块局部和(每个block输出一个局部和)__global__ voidvec_dot_cuda(constfloat* a, constfloat* b, float* partial_sums, int length){// 共享内存:存储block内的中间累加结果(大小=blockDim.x)__shared__ float s_partial[256];// 1. 计算当前thread的全局索引int global_idx = blockIdx.x * blockDim.x + threadIdx.x;// 2. 每个thread累加多个元素(避免thread数不足)float thread_sum = 0.0f;while (global_idx < length) {thread_sum += a[global_idx] * b[global_idx];global_idx += blockDim.x * gridDim.x; // 步长=总thread数}// 3. 将thread累加结果写入共享内存s_partial[threadIdx.x] = thread_sum;__syncthreads(); // 同步block内所有thread,确保共享内存数据完整// 4. Block内reduce(二分法,高效累加共享内存中的值)for (int s = blockDim.x / 2; s > 0; s >>= 1) {if (threadIdx.x < s) {s_partial[threadIdx.x] += s_partial[threadIdx.x + s];}__syncthreads(); // 每轮累加后同步}// 5. Block内第一个thread将局部和写入输出if (threadIdx.x == 0) {partial_sums[blockIdx.x] = s_partial[0];}}// 主机端封装函数:分配内存+启动核函数+汇总局部和floatvec_dot_wrapper(constfloat* a_host, constfloat* b_host){// 1. 分配设备内存float *a_dev, *b_dev, *partial_sums_dev;int block_size = 256; // 每个block的thread数int grid_size = (VEC_LEN + block_size - 1) / block_size; // 总block数cudaMalloc(&a_dev, VEC_LEN * sizeof(float));cudaMalloc(&b_dev, VEC_LEN * sizeof(float));cudaMalloc(&partial_sums_dev, grid_size * sizeof(float));if (!a_dev || !b_dev || !partial_sums_dev) return NAN;// 2. 主机→设备数据拷贝cudaMemcpy(a_dev, a_host, VEC_LEN * sizeof(float), cudaMemcpyHostToDevice);cudaMemcpy(b_dev, b_host, VEC_LEN * sizeof(float), cudaMemcpyHostToDevice);// 3. 启动CUDA核函数(<<<grid_size, block_size>>> 配置线程)vec_dot_cuda<<<grid_size, block_size>>>(a_dev, b_dev, partial_sums_dev, VEC_LEN);if (cudaGetLastError() != cudaSuccess) return NAN;// 4. 设备→主机拷贝局部和float* partial_sums_host = (float*)malloc(grid_size * sizeof(float));cudaMemcpy(partial_sums_host, partial_sums_dev, grid_size * sizeof(float), cudaMemcpyDeviceToHost);// 5. 主机端汇总所有block的局部和,得到最终点积float total_dot = 0.0f;for (int i = 0; i < grid_size; ++i) {total_dot += partial_sums_host[i];}// 6. 释放内存free(partial_sums_host);cudaFree(a_dev);cudaFree(b_dev);cudaFree(partial_sums_dev);return total_dot;}// 测试主函数intmain(){// 生成随机输入float a_host[VEC_LEN], b_host[VEC_LEN];float ref_dot = 0.0f; // 标准答案(主机端计算)for (int i = 0; i < VEC_LEN; ++i) {a_host[i] = (float)rand() / RAND_MAX;b_host[i] = (float)rand() / RAND_MAX;ref_dot += a_host[i] * b_host[i];}// 调用CUDA实现float cuda_dot = vec_dot_wrapper(a_host, b_host);// 验证结果if (fabs(cuda_dot - ref_dot) < 1e-4) {printf("CUDA点积结果验证通过:%.4f\n", cuda_dot);} else {printf("CUDA点积结果验证失败!参考值:%.4f,CUDA值:%.4f\n", ref_dot, cuda_dot);}return 0;}
夜雨聆风
