0


手写 Triton Softmax Kernel:程序实例、块大小、mask 与指针算术

GPU 编程看起来总像黑魔法,满眼是 warpsshared memorytensor cores,还有 kernel 里古怪的索引运算。但是这篇文章从一个具体例子入手帮你理解 Triton:从头实现一个 softmax kernel。

以官方 Triton 教程为基础,深入代码背后的原理并配上手绘图解。如果你觉得 GPU 编程教程总是太晦涩,这篇文章正好可以用来入门。

我们的目标不止是写一个 kernel而是理解现代 AI 工作负载在 GPU 上到底怎么跑。

最后会把 kernel 放到 RTX 5090 上跟 PyTorch 的原生 softmax 跑个 benchmark。结果不是简单的"Triton 赢了"——这里有个性能悬崖,教会你 GPU 编程里很重要的一件事。

Softmax:简单的数学,隐藏的内存问题

逐行 softmax 从数学上很简单:每行是一个独立 logit 向量,softmax 把它转成概率。

比如一个

2×3

矩阵,不是对六个值算一个大 softmax,而是算两个独立的 softmax——行 0 一个、行 1 一个。

难点不在数学而是在 GPU 上的执行方式:数据搬几次、中间值存在哪、GPU 是花时间算还是在等内存。

简单的 PyTorch 实现把 softmax 拆成几个独立的张量操作:max、减法、指数、求和、除法。每一步都可能从全局内存读数据再把中间值写回去。

而融合的 Triton kernel 改变了这个模式:一次加载一行,所有 softmax 步骤在数据留在片上时完成,最后一次性写回结果。

这里的片外指 GPU 全局内存/DRAM:大但慢。片上指 GPU 计算单元内部的内存(寄存器或共享内存/SRAM):快得多但小得多。

从概念上说一个 Triton 程序处理一行,但实际运行时是大量 Triton 程序并行跑。

一个简单的 Triton模型

在看 softmax kernel 之前,先搭个简单的模型。

一个

3072

长度的向量

X

,要给每个元素减 1。

CPU 思路是顺序循环:

 foriinrange(3072):  
     X[i] =X[i] -1

在 GPU 上就不是这样了,GPU 要把向量切成块,并行处理。

Triton 里,一个 kernel 描述一个程序实例的行为。启动 kernel 时,启动一个网格,里面很多程序实例并行跑。

 BLOCK_SIZE=1024

每个程序实例处理

1024

个元素。

 3072 / 1024 = 3 → 需要 3 个程序实例。  
 program 0 → elements 0-1023  
 program 1 → elements 1024-2047  
 program 2 → elements 2048-3071

每个程序实例拿到自己的

program_id

,用它定位数据切片,执行相同操作。

Softmax kernel 里也一样,只是每个程序实例处理矩阵的一行,不是向量的一块。

逐行拆解 Triton Softmax Kernel

一个 Triton 程序实例一次处理一行。启动的程序数少于行数时,每个程序以固定步长在矩阵中跳跃,处理多行。

 @triton.jit  
def softmax_kernel(  
    output_ptr, input_ptr,  
    input_row_stride, output_row_stride,  
    n_rows, n_cols,  
    BLOCK_SIZE: tl.constexpr,  
    num_stages: tl.constexpr,  
):  
    row_start = tl.program_id(0)   # 当前程序实例 ID  
    row_step = tl.num_programs(0)  # 轴 0 上的实例总数  
      
     for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):

tl.program_id(0)拿到当前实例的 id。

如果启了 4 个程序,program 0 从 row 0 开始,program 1 从 row 1 开始以此类推,每个程序按

row_step

跳跃处理后续行。

row_stride

告诉程序在内存里走多远才到下一行的开头。一个常见错误是认为下一行总在

n_cols

个元素之后开始——对紧凑连续张量是对的但不是所有布局都这样。

 # 指向当前行在内存中的起始位置  
 row_start_ptr = input_ptr + row_idx * input_row_stride  
 col_offsets = tl.arange(0, BLOCK_SIZE)  
 input_ptrs = row_start_ptr + col_offsets


区分两个概念:

n_cols

是逻辑列数,

input_row_stride

是两行之间的物理内存距离。

 mask = col_offsets < n_cols  
 row = tl.load(input_ptrs, mask=mask, other=-float("inf"))

mask 告诉 Triton 只加载实际列,假列用

-inf

填充,因为

exp(-inf) = 0

不影响 softmax 分母。

 row_minus_max = row - tl.max(row, axis=0)    
 numerator = tl.exp(row_minus_max)    
 denominator = tl.sum(numerator, axis=0)    
 softmax_output = numerator / denominator

先减最大值保数值稳定,不改变 softmax 结果但防止指数溢出。这些操作都在同一个融合的 Triton 程序里——

row_minus_max

numerator

denominator

不会作为中间张量写回全局内存。

启动 Kernel:Python 包装器

Triton kernel 描述一个程序实例内部干什么,但实际问题需要 Python 代码来回答:块多大?多少 warp?启动几个程序?

 def softmax(x):  
     n_rows, n_cols = x.shape  
     BLOCK_SIZE = triton.next_power_of_2(n_cols)

选择 2 的幂的 BLOCK_SIZE——适合 Triton 的块编程模型和归约操作。一行 3000 列?BLOCK_SIZE 用 4096,多余的用 mask 屏蔽。

 num_warps = 8

Warp 是一组一起执行的 GPU 线程,

num_warps = 8

意味着每个 Triton 程序实例用 8 个 warp。

 num_stages = 4 if SIZE_SMEM > 200000 else 2

num_stages和程序、warp 是不同的,它帮助同一程序内的循环迭代重叠——比如一轮加载、一轮计算、一轮写入同时进行。不过更多阶段用更多片上资源并不一定更好。

 y = torch.empty_like(x)

为输出分配和输入同 shape、dtype、device 的张量。

 kernel = softmax_kernel.warmup(  
    y, x, x.stride(0), y.stride(0),  
    n_rows, n_cols,  
    BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps,  
    grid=(1,),  
)  
kernel._init_handles()  
n_regs = kernel.n_regs  
 size_smem = kernel.metadata.shared

先编译一次 kernel,看看一个程序实例消耗多少寄存器和共享内存。

GPU 流多处理器资源有限。每个 SM 有固定的寄存器和共享内存预算。一个程序用太多,同一 SM 能同时跑的程序就少。

 occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)  
 occupancy = min(occupancy, SIZE_SMEM // size_smem)  
 num_programs = NUM_SM * occupancy  
 num_programs = min(num_programs, n_rows)

占用率受限于最先耗尽的资源。这是持久化风格 kernel:不是每行启一个程序,而是启足够程序占满 GPU,每个程序循环处理多行。

基准测试

RTX 5090 上逐行 softmax benchmark,

_M = 4096_

*,

_N_

变化。*

中小行大小下 PyTorch 更快,意料之中。但

N ≈ 8700

附近两边都撞上性能悬崖。之后 Triton kernel 反超。

这不意味着 Triton 万能更快,因为GPU 性能高度依赖张量形状、块大小、资源使用。y 轴是有效带宽,从输入输出张量大小算出,不是每次内部内存事务。

Triton 实现中,

N

超过

8192

BLOCK_SIZE

跳到

16384

,每个程序实例内部操作更大的块,资源压力上升,性能出现突变。

总结

Triton 可以让你在接近 Python 的层面写 GPU kernel 的方式。这个例子也告诉我们不是 Triton 总比 PyTorch 快,因为PyTorch 已经高度优化了。

本文代码

https://github.com/lounishamroun/optimization_sandbox/blob/main/triton_kernels/basics/softmax_kernel.py

by Lounis Hamroun

“手写 Triton Softmax Kernel:程序实例、块大小、mask 与指针算术”的评论:

还没有评论