==================

this is probably a nothing post, but as i’m trying to solve a problem - making a robot lawyer go brrr - i often end up on sidequests.

this sidequest just happened to be making my grpo implementation faster. to do that, i needed to take a base grpo implementation and basically iterate until it’s faster (or until i lose momentum, get bored, or both). pretty simple really.

the Second Brain Algorithm^tm is as follows:

  • read code/paper
  • ask gpt5-pro clarifying questions when something doesn’t make sense in my brain
  • think about ways it could be faster
  • ask gpt5-pro to implement each one so i can test them
  • run through bugs until each iteration works
  • go back and do some reading on things/tricks that make kernels faster
  • ask gpt5-pro to implement the tricks
  • repeat until i have a branch of the “faster tree” that shows promise (eg, the one of the 20+ ideas that hasn’t been pruned yet after applying some effort to it).
  • it’s really just mental breadth-first search

i went down a few rabbitholes trying my best to beat cublas matmuls and said forget it - i’m sure i could have if i wanted to spend a long time on it, but i didn’t want to lose momentum, so i decided to use cublas, then write a softmax kernel and go from there.

basically, i was heavily inspired by liger-kernel’s chunking/streaming and flash attention’s accumulator tricks, and i wanted to see if i could combine them in some way to make things faster. things ended up working out pretty well in my experimentation, and i wanted to write down what i did before i move onto the next sidequest. here is the gist:

“regular” grpo:

  • generate a few answers, score them (that score is your advantage).
  • compare the new policy to the old one with a ratio (think “did we get more confident or less?”).
  • clip that ratio so we don’t overreact to noisy scores.
  • optionally add a KL penalty to stay close to a reference model.
  • average the per‑token loss and backprop.
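
to make that concrete, here’s a rough plain-pytorch sketch of the steps above (my own function/variable names, nothing from any particular library; the clipping/KL form follows the common grpo recipe):

import torch

def vanilla_grpo_loss(logps, old_logps, ref_logps, advantages, mask,
                      eps_low=0.2, eps_high=0.2, beta=0.04):
    # logps / old_logps / ref_logps: [B, T] per-token log-probs of the sampled tokens
    # advantages: [B] scalar score per completion, mask: [B, T] with 1 on completion tokens
    ratio = torch.exp(logps - old_logps)            # "did we get more confident or less?"
    adv = advantages.unsqueeze(1)                   # broadcast per-sequence advantage over tokens
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * adv
    per_token = -torch.min(unclipped, clipped)      # clip so we don't overreact to noisy scores
    if beta > 0.0:                                  # optional KL penalty vs. a reference model
        kl = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1.0
        per_token = per_token + beta * kl
    return (per_token * mask).sum() / mask.sum().clamp(min=1)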

two big problems:

  • big matmul: the lm-head projection forms full $[N,V]$ logits every step, and $V$ is huge
  • lots of dead tokens: padded/masked positions still get computed and shuffled around for nothing

so how do we fix/speed that up? FlashGRPO located here: https://github.com/KayneWest/flashgrpo - the basic idea is this:

1) pack only real tokens (no full $[N,V]$ logits materialization): pack the non‑masked tokens into a tight matrix so we don’t waste compute on padding. vanilla grpo typically forms full logits $[N,V]$ (or a full log-softmax) per step and then gathers the target log‑probs, which costs $O(N \cdot V)$ memory traffic. FlashGRPO instead:

  • packs tokens to keep only valid rows $N \le B \cdot T$ (mask + ignore index).
  • streams the vocabulary in tiles of width $C$ (where $C$ is auto‑tuned) and runs a cublas gemm $X[N,K] \times W_{v_0:v_1}^T[C,K] \to [N,C]$.
  • consumes each tile immediately with a triton reducer that performs online log-sum-exp (tracks per-row $(m,s)$) and gathers the target logit $z$ if it falls in the current tile. no global $[N,V]$ tensor is ever written.

this keeps temporary memory to $O(N \cdot C)$ instead of $O(N \cdot V)$ while preserving full‑vocab accuracy.
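
here’s a much-simplified pure-pytorch sketch of that packing + vocab-streaming idea (no triton, no overlapping, names are mine) just to show the shape of the computation:

import torch

def packed_streaming_logprobs(H, W, targets, mask, chunk=4096, temperature=1.0):
    # H: [B, T, K] hidden states, W: [V, K] lm-head weight, targets/mask: [B, T]
    B, T, K = H.shape
    V = W.shape[0]
    rows = mask.reshape(-1).bool()                       # keep only valid (non-masked) rows
    X = H.reshape(-1, K)[rows]                           # [N, K] packed tokens
    y = targets.reshape(-1)[rows]                        # [N]
    N = X.shape[0]

    m = torch.full((N,), float("-inf"), device=X.device) # running max for online LSE
    s = torch.zeros(N, device=X.device)                  # running sum of exp(logit - m)
    z = torch.zeros(N, device=X.device)                  # target logit, filled when its tile arrives

    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        tile = (X @ W[v0:v1].T).float() / temperature    # [N, C] logits for this vocab tile only
        m_new = torch.maximum(m, tile.max(dim=1).values)
        s = s * torch.exp(m - m_new) + torch.exp(tile - m_new.unsqueeze(1)).sum(dim=1)
        m = m_new
        in_tile = (y >= v0) & (y < v1)                   # gather the target logit if it lives here
        z[in_tile] = tile[in_tile, y[in_tile] - v0]

    return z - (m + torch.log(s))                        # [N] log p(target) = z - LSE, exact over tiles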

2) online LSE + softcap for numerics. the triton reducer maintains $(m, s)$ so that $\mathrm{LSE} = m + \log s$ is exact over all tiles. optional softcap clamps logits to $\pm \mathrm{softcap}$ during both forward and backward passes. temperature is applied in‑place during reduction to avoid extra gemms.
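
written out, the per-tile merge the reducer maintains (with $\tilde z_v$ meaning the logit after temperature scaling and optional softcap clamping; i’m glossing over the exact order the kernel applies them in) is:

$$ m' = \max\!\left(m,\ \max_{v \in \mathrm{tile}} \tilde z_v\right), \qquad s' = s \cdot e^{m - m'} + \sum_{v \in \mathrm{tile}} e^{\tilde z_v - m'}, \qquad \mathrm{LSE} = m' + \log s' $$

which stays exact no matter how the vocabulary gets chopped into tiles.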

3) overlapped execution (eg dual streams). the cublas gemm for tile $t+1$ runs concurrently with the triton reduction for tile $t$, using ping‑pong buffers (two buffers that alternate roles: the gemm writes one while the reducer reads the other) and cuda events. this hides reduction latency and keeps the gemm saturated.
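
the scheduling pattern, sketched with plain pytorch streams/events (the real thing uses the triton reducer and pre-allocated buffers; this only shows the shape of the overlap):

import torch

def overlapped_tiles(X, W, reduce_tile, chunk=4096):
    # X: [N, K] packed rows, W: [V, K]; reduce_tile(tile, v0) consumes one logits tile
    V = W.shape[0]
    gemm_stream = torch.cuda.Stream()
    reduce_stream = torch.cuda.Stream()
    bufs = [None, None]                                  # ping-pong buffers for logits tiles
    done = [torch.cuda.Event(), torch.cuda.Event()]      # "reducer finished with buffer i"
    ready = [torch.cuda.Event(), torch.cuda.Event()]     # "gemm finished writing buffer i"
    for e in done:
        e.record()                                       # both buffers start out free

    for t, v0 in enumerate(range(0, V, chunk)):
        i = t % 2
        v1 = min(v0 + chunk, V)
        with torch.cuda.stream(gemm_stream):
            gemm_stream.wait_event(done[i])              # wait until the reducer released buffer i
            bufs[i] = X @ W[v0:v1].T                     # gemm for tile t (a real impl pre-allocates)
            ready[i].record(gemm_stream)
        with torch.cuda.stream(reduce_stream):
            reduce_stream.wait_event(ready[i])           # wait for tile t's logits
            reduce_tile(bufs[i], v0)                     # the triton reducer stands in here
            done[i].record(reduce_stream)
    torch.cuda.synchronize()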

4) custom backward that never forms $p(\cdot)$ explicitly. instead of backpropagating through a full log-softmax, FlashGRPO reconstructs per‑tile gradients:

$$ \frac{\partial \ell}{\partial \mathrm{logits}_{i,v}} = \underbrace{\frac{\partial \ell}{\partial \log p_i}}_{\mathrm{grad\_row}} \cdot \left( \mathbb{1}[v = y_i] \cdot g_t - p_{i,v} \cdot g_{i,v} \right) $$

then contracts locally: these per‑tile contractions are streamed across $V$, so no $[N,V]$ probability matrix is ever realized. gradients are accumulated in fp32 and cast back to the parameter dtype at the end.
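
a pure-pytorch sketch of that tiled backward (ignoring the softcap gates and the bias gradient for brevity; grad_row is $\partial \ell / \partial \log p_i$ and lse comes from the forward pass; names are mine):

import torch

def backward_tiles(X, W, y, grad_row, lse, chunk=4096):
    # X: [N, K], W: [V, K], y: [N] target ids, grad_row: [N], lse: [N]
    dX = torch.zeros_like(X, dtype=torch.float32)
    dW = torch.zeros_like(W, dtype=torch.float32)
    V = W.shape[0]
    for v0 in range(0, V, chunk):
        v1 = min(v0 + chunk, V)
        tile = (X @ W[v0:v1].T).float()                   # recompute logits for this tile only
        p = torch.exp(tile - lse.unsqueeze(1))            # tile-local softmax probs, never [N, V]
        g = -grad_row.unsqueeze(1) * p                    # d loss / d logits, non-target part
        in_tile = (y >= v0) & (y < v1)
        g[in_tile, y[in_tile] - v0] += grad_row[in_tile]  # + grad_row on the target column
        dX += g @ W[v0:v1].float()                        # contract locally, stream across V
        dW[v0:v1] += g.T @ X.float()
    return dX.to(X.dtype), dW.to(W.dtype)                 # accumulate fp32, cast back at the end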

5) ratio/clip gradient is selective and stable. FlashGRPO computes token‑wise $\log p - \log p_{\mathrm{old}}$ in fp32 and clamps it to $[-10,10]$ before exponentiating. the gradient flows only through the chosen branch.
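
in pseudo-pytorch, the token-wise piece is roughly the same clipped objective as before, just with the fp32 + clamp guard made explicit:

import torch

def stable_clipped_ratio(logp, old_logp, adv, eps_low=0.2, eps_high=0.2):
    log_ratio = (logp - old_logp).float().clamp(-10.0, 10.0)   # fp32, clamped before exp
    ratio = log_ratio.exp()
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high) * adv
    return -torch.min(ratio * adv, clipped)                     # min() routes the gradient through the chosen branch only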

6) bf16/fp16 friendly with fp32 critical paths. inputs can be bf16/fp16; the reduction/LSE and the accumulators are fp32. chunk size is auto‑selected given $N$, $V$, dtype bytes, and a memory budget in MB.
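
the actual heuristic lives in choose_chunk_size in the repo; a hypothetical sketch of the kind of thing it has to do:

def choose_chunk_size_sketch(N, V, dtype_bytes, max_temp_mb=768):
    # hypothetical: grow the tile width C (power of two) until an [N, C] logits tile
    # plus an fp32 working copy would blow past the temporary-memory budget
    budget_bytes = max_temp_mb * 1024 * 1024
    C = 512
    while C * 2 <= V and N * (C * 2) * (dtype_bytes + 4) <= budget_bytes:
        C *= 2
    return min(C, V)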

TL;DR:

triton kernels reduce_chunk_online_lse (forward LSE + target gather), finalize_grpo_1d (ratio/clip + advantage), and the backward pair bwd_dx_from_tile / bwd_dwdb_from_tile handle all row‑packing, softcap gating, and fp32 accumulation; scheduling uses ping‑pong buffers with CUDA events; chunk sizing is chosen via choose_chunk_size given a memory budget.

Installation

git clone https://github.com/KayneWest/flashgrpo.git
cd flashgrpo
pip install -e .

Basic Usage

import torch
from flashgrpo import GRPOCuBLASxTritonPacked

# Initialize GRPO operator
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,        # Temperature scaling
    epsilon_low=0.2,        # Lower clipping bound  
    epsilon_high=0.2,       # Upper clipping bound
    delta=None,             # Optional delta clipping
    softcap=30.0,           # Logit clamping
    chunk_size=4096,        # Fixed chunk size (auto if None)
    max_temp_mb=768,        # Memory budget for auto chunking
    beta=0.04,              # KL regularization coefficient
)

# Your model forward pass
with torch.cuda.amp.autocast(False):
    outputs = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"],
                   output_hidden_states=True)
    
    # Extract components
    H = outputs.hidden_states[-1][:, :-1, :].to(torch.bfloat16)  # [B,T,K]
    W = model.get_output_embeddings().weight                       # [V,K] 
    targets = batch["input_ids"][:, 1:]                           # [B,T]
    old_logps = batch["old_logps"][:, :-1]                        # [B,T]
    completion_mask = batch["completion_mask"][:, :-1]            # [B,T]
    advantages = batch["advantages"]                              # [B]
    
    ref_logps = get_reference_logps(batch)  # [B,T]

    # Compute GRPO loss
    loss = grpo_op(
        H,
        W,
        targets,
        old_logps, 
        completion_mask, 
        advantages,
        ref_per_token_logps=ref_logps
    )

loss.backward()
optimizer.step()

tada – sidequest completed and back to creating a robot thomas jefferson