FlashGRPO
==================
this is probably a nothing post, but as i’m trying to solve a bigger problem - making a robot lawyer go brrr - i often pick up sidequests.
this sidequest just happened to be making my grpo implementation faster. to do that, i needed to improve upon a base grpo implementation and basically iterate until it’s faster (until i lose momentum, get bored, or both). pretty simple really.
the Second Brain Algorithm^tm is as follows:
- read code/paper
- ask gpt5-pro for any clarifying questions if something doesn’t make sense in my brain
- think about ways it could be faster
- ask gpt5-pro to implement each one so i can test them
- run through bugs until each iteration works
- go back and do some reading on things/tricks that make kernels faster
- ask gpt5-pro to implement the tricks
- repeat until i have a branch of the “faster tree” that shows promise (eg, whichever of the 20+ ideas hasn’t been pruned yet after applying some effort).
- it’s really just mental breadth-first search
i went down a few rabbitholes trying my best to beat cublas matmuls and said forget it - i’m sure if i wanted to spend a long time on this i could have, but i don’t want to lose momentum, so i decided to use cublas, then write a softmax kernel and go from there.
basically, i got heavily inspired by liger-kernel’s chunking/streaming and flash attention’s accumulator tricks and i wanted to see if i could combine them in some way to make things faster. things ended up working out pretty well in my experimentation and i wanted to write down what i did before i move onto the next sidequest. here is the gist:
“regular” grpo (a rough pytorch sketch follows the list):
- generate a few answers, score them (that score is your advantage).
- compare the new policy to the old one with a ratio (think “did we get more confident or less?”).
- clip that ratio so we don’t overreact to noisy scores.
- optionally add a KL penalty to stay close to a reference model.
- average the per‑token loss and backprop.
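in rough pytorch terms, one step looks something like this - a minimal sketch, not any particular library’s implementation; it materializes the full logits tensor, and names like `old_logps` / `advantages` are just illustrative:

```python
import torch

def vanilla_grpo_loss(logits, targets, old_logps, advantages, mask,
                      eps_low=0.2, eps_high=0.2, beta=0.0, ref_logps=None):
    # full [B, T, V] log-softmax -- this is the expensive part
    logps = torch.log_softmax(logits.float(), dim=-1)
    logp = logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)     # [B, T] target log-probs
    ratio = torch.exp(logp - old_logps)                            # "more or less confident?"
    adv = advantages.unsqueeze(-1)                                 # broadcast [B] -> [B, T]
    unclipped = ratio * adv
    clipped = ratio.clamp(1 - eps_low, 1 + eps_high) * adv         # don't overreact to noise
    loss = -torch.minimum(unclipped, clipped)
    if beta > 0 and ref_logps is not None:                         # optional KL penalty vs reference
        loss = loss + beta * (torch.exp(ref_logps - logp) - (ref_logps - logp) - 1)
    return (loss * mask).sum() / mask.sum().clamp_min(1)           # masked per-token mean
```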
two big problems:
- the big matmul: the logits projection materializes an $[N,V]$ (or $[B,T,V]$) tensor every step, and $V$ is huge
- lots of dead tokens: padding and masked prompt tokens get the same full-vocab treatment even though they never contribute to the loss
so how do we fix/speed that up? FlashGRPO located here: https://github.com/KayneWest/flashgrpo - the basic idea is this:
1) pack only real tokens (no full V matrix materialization): pack the non‑masked tokens into a tight matrix so we don’t waste compute on padding. vanilla grpo typically forms full logits $[N,V]$ (or full log-softmax) per step and then gathers target log‑probs, which costs $O(N \cdot V)$ memory traffic. FlashGRPO instead:
- packs tokens to keep only valid rows $N \le B \cdot T$ (mask + ignore index).
- streams the vocabulary in tiles of width $C$ (where $C$ is auto‑tuned) and runs a cublas gemm $X[N,K] \cdot W_{v_0:v_1}[C,K]^{\top} \to [N,C]$.
- consumes each tile immediately with a triton reducer that performs online log-sum-exp (tracks per-row $(m,s)$) and gathers the target logit $z$ if it falls in the current tile. no global $[N,V]$ tensor is ever written.
this keeps temporary memory to $O(N \cdot C)$ instead of $O(N \cdot V)$ while preserving full‑vocab accuracy.
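to make that concrete, here’s a rough, unfused pytorch version of the same idea (illustrative only - the real path is a cublas gemm feeding a triton reducer, and the function/variable names here are made up):

```python
import torch

def packed_streamed_logprobs(H, W, targets, mask, chunk=4096):
    """Sketch of step 1: pack valid tokens, then stream the vocab in tiles,
    keeping only one [N, C] tile alive at a time (not the actual FlashGRPO kernels)."""
    B, T, K = H.shape
    V = W.shape[0]
    keep = mask.reshape(-1).bool()                 # which of the B*T rows are real tokens
    X = H.reshape(-1, K)[keep].float()             # packed hidden states, [N, K]
    y = targets.reshape(-1)[keep]                  # packed target ids, [N]
    N = X.shape[0]
    m = torch.full((N,), float("-inf"), device=X.device)   # running max
    s = torch.zeros(N, device=X.device)                    # running sum of exp
    z_y = torch.zeros(N, device=X.device)                  # target logit, filled when its tile arrives
    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        z = X @ W[v0:v1].float().T                 # gemm on one [N, C] tile
        m_new = torch.maximum(m, z.max(dim=1).values)
        s = s * torch.exp(m - m_new) + torch.exp(z - m_new[:, None]).sum(dim=1)
        m = m_new
        hit = (y >= v0) & (y < v1)                 # rows whose target falls in this tile
        z_y[hit] = z[hit, y[hit] - v0]
    return z_y - (m + torch.log(s))                # per-token log p(target), [N]
```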
2) online LSE + softcap for numerics. the triton reducer maintains $(m, s)$ so that $\mathrm{LSE} = m + \log s$ is exact over all tiles. optional softcap clamps logits to $\pm \mathrm{softcap}$ during both forward and backward passes. temperature is applied in‑place during reduction to avoid extra gemms.
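written out, with $\tau$ the temperature and $c$ the softcap (using the clamp form described above; the exact order of temperature and cap inside the kernel is an assumption here), the per-tile update the reducer performs is:

$$ z' = \operatorname{clamp}\!\left(z/\tau,\ -c,\ c\right), \qquad m' = \max\!\left(m,\ \max_j z'_j\right), \qquad s' = s\, e^{m - m'} + \sum_j e^{z'_j - m'} $$

after the last tile, $\mathrm{LSE} = m' + \log s'$ is exact and independent of the order in which the tiles were visited.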
3) overlapped execution (eg dual streams). cublas gemm for tile $t+1$ runs concurrently with triton reduction for tile $t$ using ping‑pong buffers (concurrent read/write operations in two memory regions that alternate between being read from and written to; improving throughput/overlapping operations) and cuda events. this hides reduction latency and keeps the gemm saturated.
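a stripped-down pytorch version of that scheduling pattern looks like this (illustrative: the real implementation launches the triton reducer instead of `reduce_tile`, and handles a ragged final tile; here $V$ is assumed to be a multiple of the chunk width):

```python
import torch

def overlapped_tiles(X, W, chunk, reduce_tile):
    """Ping-pong buffering sketch: the gemm for tile t+1 overlaps with the reduction of tile t."""
    N, K = X.shape
    V = W.shape[0]
    assert V % chunk == 0                                    # keep both buffers full-width
    gemm_stream, red_stream = torch.cuda.Stream(), torch.cuda.Stream()
    gemm_stream.wait_stream(torch.cuda.current_stream())     # X / W must be ready first
    bufs = [torch.empty(N, chunk, device=X.device, dtype=X.dtype) for _ in range(2)]
    ready = [torch.cuda.Event() for _ in range(2)]           # "gemm wrote buf i"
    freed = [torch.cuda.Event() for _ in range(2)]           # "reducer is done reading buf i"
    for e in freed:
        e.record()                                           # both buffers start out free
    for t, v0 in enumerate(range(0, V, chunk)):
        i = t % 2                                            # alternate between the two buffers
        with torch.cuda.stream(gemm_stream):
            gemm_stream.wait_event(freed[i])                 # don't clobber a tile still in flight
            torch.matmul(X, W[v0:v0 + chunk].T, out=bufs[i])
            ready[i].record()
        with torch.cuda.stream(red_stream):
            red_stream.wait_event(ready[i])                  # tile must be fully written
            reduce_tile(bufs[i], v0)                         # e.g. online-LSE update + target gather
            freed[i].record()
    torch.cuda.current_stream().wait_stream(red_stream)      # join before using the results
    torch.cuda.current_stream().wait_stream(gemm_stream)
```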
4) custom backward that never forms $p(\cdot)$ explicitly. instead of backpropagating through a full log-softmax, FlashGRPO reconstructs per‑tile gradients:
$$ \frac{\partial \ell}{\partial \mathrm{logits}_{i,v}} = \underbrace{\frac{\partial \ell}{\partial \log p_i}}_{\mathrm{grad\_row}_i} \cdot \left( \mathbb{1}[v = y_i] - p_{i,v} \right) $$
then contracts it locally; these per‑tile contractions are streamed across $V$, so no $[N,V]$ probability matrix is ever realized. gradients are accumulated in fp32 and cast back to the parameter dtype at the end.
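as an unfused illustration of what those tile-local contractions compute (not the actual bwd_dx_from_tile / bwd_dwdb_from_tile kernels; `lse` is the saved forward log-sum-exp per row and `grad_row` the upstream per-row gradient):

```python
import torch

def per_tile_backward(X, W, y, lse, grad_row, chunk=4096):
    """Illustrative per-tile backward: recompute each logit tile, form d(logits) locally,
    and contract it immediately so no [N, V] probability matrix is ever materialized."""
    N, K = X.shape
    V = W.shape[0]
    dX = torch.zeros(N, K, device=X.device)          # fp32 accumulators
    dW = torch.zeros(V, K, device=X.device)
    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        Wt = W[v0:v1].float()
        z = X.float() @ Wt.T                          # recompute the [N, C] logit tile
        p = torch.exp(z - lse[:, None])               # softmax columns for this tile only
        dlogits = -grad_row[:, None] * p              # the  -p_{i,v}  part of the formula above
        hit = (y >= v0) & (y < v1)                    # rows whose target column is in this tile
        dlogits[hit, y[hit] - v0] += grad_row[hit]    # the  1[v = y_i]  part
        dX += dlogits @ Wt                            # stream the contractions across V
        dW[v0:v1] += dlogits.T @ X.float()
    return dX, dW                                     # cast back to the param dtype afterwards
```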
5) ratio/clip gradient is selective and stable. FlashGRPO computes token‑wise $\log p - \log p_{\mathrm{old}}$ in fp32 and clamps it to $[-10,10]$ before exponentiating. the gradient flows only through the chosen branch.
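in plain pytorch terms, that piece is roughly (illustrative names, not the library’s internals):

```python
import torch

def clipped_ratio_loss(logp, old_logp, adv, eps_low=0.2, eps_high=0.2):
    # fp32 delta, clamped before exponentiating so a single bad token can't blow up the ratio
    delta = (logp - old_logp).float().clamp(-10.0, 10.0)
    ratio = torch.exp(delta)
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high)
    # torch.minimum routes the gradient only through whichever branch is selected per token
    return -torch.minimum(ratio * adv, clipped * adv)
```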
6) bf16/fp16 friendly with fp32 critical paths. inputs can be bf16/fp16; the reduction/LSE and accumulators are fp32. chunk size is auto‑selected given $N$, $V$, dtype‑bytes, and a MB budget.
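the flavor of that auto-selection is something like this (a guess at the shape of the logic; the real choose_chunk_size may weigh things differently):

```python
def choose_chunk_sketch(N, V, dtype_bytes, max_temp_mb=768):
    """Illustrative only: widest vocab tile whose [N, C] temporary fits the memory budget."""
    budget_bytes = max_temp_mb * 1024 * 1024
    c = budget_bytes // max(N * dtype_bytes, 1)   # widest tile that fits one [N, C] buffer
    c = (c // 256) * 256                          # round down to a gemm-friendly multiple
    return int(max(256, min(V, c)))               # clamp to [256, V]
```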
TL;DR:
triton kernels reduce_chunk_online_lse (forward LSE + target gather), finalize_grpo_1d (ratio/clip + advantage), and the backward pair bwd_dx_from_tile / bwd_dwdb_from_tile handle all row‑packing, softcap gating, and fp32 accumulation; scheduling uses ping‑pong buffers with CUDA events; chunk sizing is chosen via choose_chunk_size given a memory budget.
Installation
------------

```bash
git clone https://github.com/KayneWest/flashgrpo.git
cd flashgrpo
pip install -e .
```
Basic Usage
-----------

```python
import torch
from flashgrpo import GRPOCuBLASxTritonPacked

# Initialize GRPO operator
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,     # Temperature scaling
    epsilon_low=0.2,     # Lower clipping bound
    epsilon_high=0.2,    # Upper clipping bound
    delta=None,          # Optional delta clipping
    softcap=30.0,        # Logit clamping
    chunk_size=4096,     # Fixed chunk size (auto if None)
    max_temp_mb=768,     # Memory budget for auto chunking
    beta=0.04,           # KL regularization coefficient
)

# Your model forward pass
with torch.cuda.amp.autocast(False):
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        output_hidden_states=True,
    )

# Extract components
H = outputs.hidden_states[-1][:, :-1, :].to(torch.bfloat16)  # [B,T,K]
W = model.get_output_embeddings().weight                     # [V,K]
targets = batch["input_ids"][:, 1:]                          # [B,T]
old_logps = batch["old_logps"][:, :-1]                       # [B,T]
completion_mask = batch["completion_mask"][:, :-1]           # [B,T]
advantages = batch["advantages"]                              # [B]
ref_logps = get_reference_logps(batch)                        # [B,T]

# Compute GRPO loss
loss = grpo_op(
    H,
    W,
    targets,
    old_logps,
    completion_mask,
    advantages,
    ref_per_token_logps=ref_logps,
)
loss.backward()
optimizer.step()
```
tada – sidequest completed and back to creating a robot thomas jefferson