Peeking Under the Hood of prime-rl
I’d been following the INTELLECT-2 paper and other PrimeIntellect work, but what really piqued my curiosity was PrimeIntellect-ai/prime-rl. The promise was bold: fully asynchronous, file-based RL that scales across decentralized devices. I wanted to understand exactly how it worked – scheduler quirks, memory tricks, the rollout loop – so I asked o3 to be my copilot. What followed was a week-long conversation in which we spelunked through every Python file until a coherent picture emerged. (While at it, I started a fork and sprinkled a few small QoL commits of my own → kevinbdsouza/prime-rl.)
The README is great for a quick run, but I wanted to see the gears turning: where do rollouts live? Who shards what? When does the learner talk to vLLM? Instead of diving into 30-odd files by hand, I fired up o3 and asked it to annotate every module. Over a few prompt-and-refine cycles it produced a crisp map of the whole src/zeroband/ package (pasted in full in the Appendix). I then cross-checked the details and figured a write-up might save others an afternoon. Below is the distilled notebook-to-blog narrative.
Why PRIME-RL is worth dissecting
- Asynchronous, file-based RL loop – no parameter server, just Parquet shards in /step_k/ and a single-file safetensors checkpoint.
- Group Relative Policy Optimisation (GRPO) – the leaner PPO cousin with token-level control variates.
- FSDP + activation checkpointing – fits ≥7 B params on a single A100.
- Config-first ethos – everything (LR schedule, clipping regime, micro-batching) is declarative via Pydantic (a toy sketch follows this list).
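To make that last point concrete, here is a minimal sketch of the pattern: a toy Pydantic schema loaded from a TOML file, with the kind of cross-field validator the real config.py enforces. Field names here are illustrative, not prime-rl’s actual schema.

```python
# Toy sketch of the config-first pattern (illustrative field names,
# not prime-rl's actual schema). Requires Python 3.11+ for tomllib.
import tomllib
from pydantic import BaseModel, model_validator

class OptimConfig(BaseModel):
    batch_size: int = 256
    step_per_rollout: int = 8
    lr: float = 1e-6

class CkptConfig(BaseModel):
    interval: int = 8  # steps between full FSDP checkpoints

class Config(BaseModel):
    model_name: str = "Qwen/Qwen2-7B"
    optim: OptimConfig = OptimConfig()
    ckpt: CkptConfig = CkptConfig()

    @model_validator(mode="after")
    def check_ckpt_cadence(self) -> "Config":
        # Mirrors the kind of invariant the real config enforces.
        if self.ckpt.interval % self.optim.step_per_rollout != 0:
            raise ValueError("ckpt.interval must be a multiple of step_per_rollout")
        return self

with open("config/debug.toml", "rb") as f:  # one TOML file = one full run
    cfg = Config(**tomllib.load(f))
```

Every knob then lives in one validated object, which is exactly what makes the copy-a-TOML workflow in the takeaways possible.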
The annotated directory map
Zone | Key files | What they actually do |
---|---|---|
Top-level | train.py – launch script that • parses a Config model, • swaps Qwen-2 kernels for Liger when asked, • wraps every block in FSDP, • spins the async collect → learn loop, • checkpoints twice (full shards + safetensors). | The main orchestrator. |
 | infer.py – vLLM worker that streams prompts, scores rewards, dumps Parquet. | Generation side of the loop. |
training/ | config.py, envs.py, world_info.py | Pure Pydantic & env wrappers. |
 | data.py | Watches /step_k/, shards rows per rank, returns advantages & log π_old in one go. |
 | loss.py | Three GRPO flavours (clip, ratio, KL-cov). |
 | lr_scheduler.py | Cosine, linear, plus √-decay (“WSD-sqrt”). |
 | checkpoint.py | Full FSDP recovery + light safetensors for rollouts. |
inference/ | pipeline.py, rewards.py, toploc.py | Everything needed to turn raw completions into reward-rich Parquet. |
Shared utils/ | logger, models, http_monitor, metrics | Colour logs, model-zoo helpers, REST metrics. |
(The original o3 dump is in the appendix for completeness.)
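A quick aside on the “advantages” column that data.py hands back: this is where the “group relative” in GRPO lives. Instead of learning a value function, each completion’s reward is normalised against the other completions sampled for the same prompt. A minimal sketch of the standard formulation (prime-rl’s exact normalisation may differ in details):

```python
# Group-relative advantages: normalise each completion's reward against
# the N completions sampled for the same prompt. Standard formulation;
# prime-rl's exact normalisation may differ.
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: [num_prompts, N] – one row per prompt, N completions each."""
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)
```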
Untangling “rollout” vs “rollout step”
One potentially confusing term in prime-rl is “rollout”:
Concept | In prime-rl terms |
---|---|
Rollout | The file bundle under …/step_k/ – ~batch_size × step_per_rollout prompt→completion rows, each already tagged with token-level rewards. |
Optimizer step | One optim.step() over a mini-batch from that bundle. |
Rollout step (k) | The sequence “load step_k → do step_per_rollout optimizer steps → save ckpt_rollout_k.safetensors”. |
After step_per_rollout updates, the learner tosses the data, hands the fresh weights to the inference workers, and waits for /step_{k+1}/ to finish writing. That ratio is what keeps GPUs busy while bounding how stale the training data can get. If your config says:
optim.batch_size = 256
optim.step_per_rollout = 8
then each rollout must contain 256 × 8 = 2048
samples. The trainer reuses that table for 8 updates before it asks the inference side for new data—balancing data freshness with GPU utilisation.
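In code, the cadence looks roughly like this. wait_for_stable and next_minibatch are hypothetical stand-ins for the machinery in training/data.py; grpo_loss and save_ckpt_for_rollout echo real names from the repo:

```python
# Hypothetical sketch of the trainer's rollout cadence. wait_for_stable and
# next_minibatch stand in for training/data.py; grpo_loss and
# save_ckpt_for_rollout echo real names from the repo.
import itertools

def train_loop(model, optim, cfg, wait_for_stable, next_minibatch,
               grpo_loss, save_ckpt_for_rollout):
    # Each rollout holds batch_size * step_per_rollout samples (256 * 8 = 2048).
    for k in itertools.count():                        # rollout index
        table = wait_for_stable(f"rollouts/step_{k}")  # block on the sentinel
        for _ in range(cfg.optim.step_per_rollout):    # 8 optimizer steps
            batch = next_minibatch(table, cfg.optim.batch_size)
            loss = grpo_loss(model, batch)             # any of the 3 variants
            loss.backward()
            optim.step()
            optim.zero_grad()
        save_ckpt_for_rollout(model, k)                # one safetensors file
        # inference picks up the new weights and starts writing step_{k+1}
```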
The 60-second mental model
┌── inference workers (vLLM) ─────────────────┐
│ sample N completions per prompt             │
│ compute rewards ➜ write to step_k/*.parquet │
└─────────────────────────────────────────────┘
            │ (file + stable flag)
            ▼
┌──── train.py ─────────────────────────────────────┐
│ stream Parquet ➜ recompute log-probs ➜ GRPO loss  │
│ micro-batch, grad-accum, clip, AdamW step         │
│ every `step_per_rollout` steps:                   │
│   • save FSDP shards (recovery)                   │
│   • save safetensors weights (rollout checkpoint) │
└───────────────────────────────────────────────────┘
            │ (HTTP path broadcast by shardcast)
            ▼
inference workers reload weights … repeat
Training therefore looks like this (numbers are typical, not fixed):
rollout 0
├─ infer generates 2 048 (prompt, completion) pairs → .../step_0/
└─ trainer loads step_0 and does
step 0.0 minibatch 256 → weight update
step 0.1 minibatch 256 → weight update
… repeat until step_per_rollout-1
save ckpt_rollout_0
rollout 1
├─ infer sees new weights ➜ generates 2 048 fresh pairs → .../step_1/
└─ trainer repeats the routines on step_1
…
Once again, to summarise:
- Inference workers (any mix of TP / DP / PP) read the current model checkpoint, sample N completions, compute rewards, and dump them to Parquet.
- When the last worker writes the stable sentinel file, training unblocks, streams those rows, recomputes any missing log-probs or KL reference terms, and performs one rollout (= step_per_rollout gradient steps).
- After the rollout it reshards/offloads the fresh weights to a single safetensors file and broadcasts that path with shardcast; inference nodes pick it up and the loop continues.
- Periodically, checkpoint.py saves full FSDP shards so you can resume interrupted runs exactly.
Because both sides are fully decoupled (just file + HTTP handshakes) you can scale them independently: e.g. 1 trainer GPU with step_per_rollout=1 fed by 16 vLLM GPUs, or vice versa.
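That handshake is simple enough to sketch end to end. Helper names and the exact file layout here are illustrative; the real logic lives in infer.py (writer) and training/data.py (watcher):

```python
# Sketch of the file-based handshake. Helper names and the exact file
# layout are illustrative; the real logic lives in infer.py (writer)
# and training/data.py (watcher).
import time
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq

def write_rollout(step_dir: Path, rank: int, rows: dict) -> None:
    """Inference side: dump one worker's samples, then drop the sentinel."""
    step_dir.mkdir(parents=True, exist_ok=True)
    pq.write_table(pa.table(rows), str(step_dir / f"rank_{rank}.parquet"))
    # The last worker to finish marks the step complete (the real repo
    # coordinates which worker writes this flag).
    (step_dir / "stable").touch()

def wait_for_stable(step_dir: Path, poll_s: float = 1.0) -> list[Path]:
    """Trainer side: block until the sentinel exists, then list the shards."""
    while not (step_dir / "stable").exists():
        time.sleep(poll_s)
    return sorted(step_dir.glob("*.parquet"))
```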
Patterns worth stealing
Pattern | Why it matters |
---|---|
Pydantic for everything | One TOML file rebuilds the entire run; makes WandB sweeps trivial. |
Filesystem handshake | Parquet + .stable sentinel = dead-simple, stateless coordination – no Redis, no Ray. |
Two-tier checkpoints | Heavy FSDP shards for crash recovery, light safetensors for inference throughput (sketched below). |
GRPO variants side-by-side | You can benchmark clip vs ratio vs KL-cov just by flipping one enum. |
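The two-tier checkpoint idea is worth a sketch of its own. This is a simplification (the real implementations are save_checkpoint_fsdp_state and save_ckpt_for_rollout in training/checkpoint.py); in particular, gathering a full unsharded state dict out of FSDP is elided:

```python
# Simplified sketch of two-tier checkpointing. The real implementations
# (save_checkpoint_fsdp_state / save_ckpt_for_rollout) handle FSDP
# state-dict gathering and rank coordination, elided here.
import torch.distributed.checkpoint as dcp
from safetensors.torch import save_file

def save_recovery_ckpt(model, optim, ckpt_dir: str) -> None:
    # Tier 1: sharded state for exact crash recovery - heavy, one shard per rank.
    dcp.save({"model": model.state_dict(), "optim": optim.state_dict()},
             checkpoint_id=ckpt_dir)

def save_rollout_ckpt(model, path: str) -> None:
    # Tier 2: one consolidated, CPU-offloaded safetensors file that
    # inference workers can pull quickly and load into vLLM.
    state = {k: v.detach().cpu().contiguous() for k, v in model.state_dict().items()}
    save_file(state, path)
```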
Takeaways
- Declarative configs beat bash glue. With every arg living in Pydantic, spinning a 2-GPU debug run or a 64-GPU sweep is cp config/foo.toml config/bar.toml.
- File-based async RL is underrated. Parquet + safetensors + an HTTP heartbeat were easier to reason about (and debug) than Ray actors or a bespoke RPC framework.
- o3 is great at dissecting repos, especially when the repo already follows clean, config-driven design.
- PRIME-RL’s separation of collect and learn is refreshingly minimal: if you understand the folder naming convention, you understand the pipeline.
Appendix · Full o3 file-by-file dump
Top-level package files
file | role |
---|---|
__init__.py | Empty placeholder – just marks zeroband as a Python package so you can do import zeroband… |
train.py | Entry-point that launches the RL trainer. It • parses CLI via pydantic-config into a Config model; • builds/loads the model & tokenizer (get_model_and_tokenizer) and optionally patches Qwen-2 with Liger kernels; • wraps every transformer block with FSDP (fully_shard) for sharded training; • creates the optimiser/scheduler (AdamW + custom scheduler); • spins an async rollout loop (step_per_rollout) that – streams mini-batches from training.data.get_dataloader, – (optionally) recomputes token-wise log-probs for GRPO, – calls grpo_loss, entropy_loss, kl_penalty to get the final objective, – does grad-accum, clipping, optimiser step and LR step, – logs to W&B & an HTTP monitor, and – checkpoints every N rollouts both for recovery (save_checkpoint_fsdp_state) and for rollout workers (save_ckpt_for_rollout). ([github.com][1]) |
infer.py | Companion script that generates experiences with vLLM and stores them as Parquet rows that the trainer later consumes. Highlights: • spins up a vLLM engine with configurable DP/TP/PP parallelism; • loads the dataset, optionally filtering by prompt length & difficulty; • hands prompts to vLLM with the SamplingParams supplied in inference.config; • computes reward signals via inference.rewards.compute_rewards, plus optional TOPLOC activation commitments (see toploc.py below); • writes a Parquet file for every sample plus a stable flag file so the trainer knows the step is finished; • optionally hot-reloads model weights from the most recent rollout checkpoint (for fully asynchronous RL). ([github.com][2]) |
training sub-package
file | job |
---|---|
config.py | Declarative schema (Pydantic) for every trainer hyper-parameter: optimiser block, scheduler type, batch sizes, micro-batching, checkpoint cadence, GRPO variants, etc. The validators ensure things like “ckpt.interval must be a multiple of step_per_rollout ”. ([github.com][3]) |
envs.py | Typed wrappers around environment variables (RANK, WORLD_SIZE, TRAINING_ENABLE_ACCEPTED_CHECK, …). Provides get_env_value helpers so the rest of the code can use plain attribute reads. ([github.com][4]) |
world_info.py | Convenience singleton that captures distributed topology (rank, local-rank, #GPUs/node, etc.) and is reused everywhere so you never touch os.environ in user code. ([github.com][5]) |
data.py | The data loader: • watches a directory hierarchy …/step_k/*.parquet ; • blocks until the stable flag appears, then uses pyarrow.dataset to stream rows; • shards rows across ranks & workers ( _should_skip_index ); • can fall back to a synthetic FakeTokenizedDataset for debugging. The output dictionary already contains advantages, token-wise log-probs, rewards, task ids, etc., so the trainer can compute losses without touching the original dataset again. ([github.com][6]) |
loss.py | Implements three flavours of GRPO (clip, ratio, KL-covariance). All versions share helpers: selective_log_softmax, highest_entropy_mask, _apply_mask. Returns (loss, clip_ratio?) so the trainer can log PPO-style stats (a toy sketch of the clip variant follows below). ([github.com][7]) |
lr_scheduler.py | Thin wrapper that exposes cosine, linear, and a custom “WSD-sqrt” schedule (warm-up → stable → √-decay). Selects the right callable from SCHED_MAP based on the config. ([github.com][8]) |
checkpoint.py | Two levels of persistence: 1. Full FSDP shards (save_checkpoint_fsdp_state) for exact recovery; 2. Rollout safetensors (save_ckpt_for_rollout) – a single-file, CPU-offloaded copy of the model that inference workers can download via shardcast. ([github.com][9]) |
utils.py | Mixed bag of trainer helpers: • apply_ac_ckpt turns on PyTorch activation checkpointing every n layers; • GPU/TPU FLOP calculators → PerfCounter exposes MFU & tokens/s; • MetricsAverager that syncs per-GPU statistics; • a small random-port helper for spawning vLLM servers while avoiding conflicts. ([github.com][10]) |
(Other small modules in the folder – data_prefetch.py, mp.py, etc. – are thin wrappers around multiprocessing or GCP pre-fetching and don’t hold core logic.)
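For flavour, here is roughly what the clip variant boils down to: a token-level PPO-style objective where the advantages are the group-relative ones computed at rollout time. A minimal sketch that mirrors the idea, not loss.py’s exact code:

```python
# Minimal sketch of token-level GRPO with PPO-style clipping. Mirrors the
# idea behind loss.py's clip variant, not its exact code.
# Shapes: logits [B, T, V]; ids, old_logps, advantages, loss_mask [B, T],
# with loss_mask a 0/1 float over completion tokens.
import torch
import torch.nn.functional as F

def selective_log_softmax(logits, ids):
    # Log-prob of the realised token at each position.
    logps = F.log_softmax(logits, dim=-1)
    return logps.gather(-1, ids.unsqueeze(-1)).squeeze(-1)

def grpo_clip_loss(logits, ids, old_logps, advantages, loss_mask, eps=0.2):
    logps = selective_log_softmax(logits, ids)
    ratio = torch.exp(logps - old_logps)            # pi / pi_old per token
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    per_token = -torch.minimum(unclipped, clipped)  # PPO-style pessimistic bound
    loss = (per_token * loss_mask).sum() / loss_mask.sum()
    # Fraction of tokens where the clipped branch was active (for logging).
    clip_frac = ((unclipped > clipped).float() * loss_mask).sum() / loss_mask.sum()
    return loss, clip_frac
```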
inference sub-package
file |
---|
config.py – Pydantic schema for everything the inference node needs: model name, parallelism sizes, sampling params, reward toggles, etc. |
envs.py – Same pattern as the training side but with NODE_ADDRESS , PP.RANK , … |
pipeline.py – Registers PP nodes with prime-iroh so shards can stream tensors peer-to-peer. |
parquet.py – Turns lists of generated samples into a columnar Parquet table that matches the schema expected by training.data . |
toploc.py – Builds TOPLOC commitments over hidden-state activations, PrimeIntellect’s scheme for verifying that the claimed model actually produced the completions. |
rewards.py – Houses reward functions: correctness, length penalty, difficulty buckets, etc. (toy example below). |
utils.py – Token-based helpers (fake_chat_template , prompt length filtering, etc.). |
(These files follow the same design philosophy as the trainer: pure functions + Pydantic configs so you can swap any component out in your own fork.)
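As a toy example of what lives in rewards.py, a composite reward might combine a correctness check with a length penalty. Names and weighting here are assumptions, not the repo’s actual API:

```python
# Toy composite reward in the spirit of inference/rewards.py: correctness
# plus a linear length penalty. Function and parameter names are assumptions.
def compute_reward(completion: str, target: str, max_len: int = 1024) -> float:
    correctness = 1.0 if completion.strip() == target.strip() else 0.0
    overflow = max(0, len(completion) - max_len)  # amount over the length budget
    return correctness - 0.01 * overflow
```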
utils shared helpers
A few highlights that are imported by both trainer & inference:
- utils.logger – colourised, rank-aware logging.
- utils.models – centralises “model zoo” logic (Qwen-2 vs LLaMA-2, Flash-Attn vs Torch-Attn, parameter counting, etc.).
- utils.metrics – lightweight JSONL logger for PrimeIntellect’s internal dashboard.
- utils.http_monitor – pushes selected metrics to a REST endpoint so you can watch runs without WandB.