Peeking Under the Hood of prime-rl

10 minute read


I’d been following the INTELLECT-2 paper and other PrimeIntellect work, but what really piqued my curiosity was PrimeIntellect-ai/prime-rl. The promise was bold: fully asynchronous, file-based RL that scales across decentralized devices. I wanted to understand exactly how it worked (scheduler quirks, memory tricks, the rollout loop), so I asked o3 to be my copilot. What followed was a week-long conversation in which we spelunked through every Python file until a coherent picture emerged. (While I was at it, I started a fork and sprinkled a few small QoL commits of my own → kevinbdsouza/prime-rl.)

The README is great for a quick run, but I wanted to see the gears turning: where do rollouts live? who shards what? when does the learner talk to vLLM? Instead of diving into 30-odd files by hand, I fired up o3 and asked it to annotate every module. Over a few prompt-and-refine cycles it produced a crisp map of the whole src/zeroband/ package (pasted in full in the appendix). I then cross-checked the details and figured a write-up might save others an afternoon. Below is the distilled notebook-to-blog narrative.

Why PRIME-RL is worth dissecting

  • Asynchronous, file-based RL loop – no parameter server, just Parquet shards in /step_k/ and a single-file safetensors checkpoint.
  • Group Relative Policy Optimisation (GRPO) – the leaner PPO cousin with token-level control-variates.
  • FSDP + activation ckpt – fits ≥7 B params on a single A100.
  • Config-first ethos – everything (LR schedule, clipping regime, micro-batching) is declarative via Pydantic; a minimal sketch follows below.
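
To make that config-first point concrete, here is a minimal Pydantic sketch; the field names and the validator are illustrative, not the actual schema in training/config.py.

# Illustrative config-first sketch (hypothetical field names, not the real schema).
from pydantic import BaseModel, model_validator


class OptimConfig(BaseModel):
    batch_size: int = 256
    step_per_rollout: int = 8
    lr: float = 1e-6


class TrainConfig(BaseModel):
    model_name: str = "Qwen/Qwen2.5-1.5B-Instruct"
    ckpt_interval: int = 8
    optim: OptimConfig = OptimConfig()

    @model_validator(mode="after")
    def _check_ckpt_interval(self):
        # The repo enforces similar cross-field invariants at parse time.
        if self.ckpt_interval % self.optim.step_per_rollout != 0:
            raise ValueError("ckpt_interval must be a multiple of step_per_rollout")
        return self


# One TOML file rebuilds the whole run, e.g.:
#   import tomllib
#   cfg = TrainConfig(**tomllib.load(open("config/debug.toml", "rb")))

Because every knob is a typed, defaulted field, a sweep is just a grid of TOML overrides.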

The annotated directory map

Top-level
  • train.py – launch script that
      • parses a Config model,
      • swaps Qwen-2 kernels for Liger when asked,
      • wraps every block in FSDP,
      • spins the async collect → learn loop,
      • checkpoints twice (full shards + safetensors).
    The main orchestrator.
  • infer.py – vLLM worker that streams prompts, scores rewards, dumps Parquet. The generation side of the loop.
training/
  • config.py, envs.py, world_info.py – pure Pydantic & env wrappers.
  • data.py – watches /step_k/, shards rows per rank, returns advantages & log π_old in one go.
  • loss.py – three GRPO flavours (clip, ratio, KL-cov).
  • lr_scheduler.py – cosine, linear, plus √-decay (“WSD-sqrt”); sketched right after this map.
  • checkpoint.py – full FSDP recovery + light safetensors for rollouts.
inference/
  • pipeline.py, rewards.py, toploc.py – everything needed to turn raw completions into reward-rich Parquet.
Shared utils/
  • logger, models, http_monitor, metrics – colour logs, model-zoo helpers, REST metrics.

(The original o3 dump is in the appendix for completeness.)
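
The “WSD-sqrt” schedule above is described later (in the appendix) as warm-up → stable → √-decay; here is a hedged sketch of what such a schedule can look like. The constants and exact decay shape are illustrative and may differ from lr_scheduler.py.

# Hedged sketch of a warm-up → stable → 1/sqrt decay ("WSD-sqrt"-style) schedule.
import math


def wsd_sqrt_lr(step: int, base_lr: float, warmup_steps: int, stable_steps: int) -> float:
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps      # linear warm-up
    if step < warmup_steps + stable_steps:
        return base_lr                                  # flat plateau
    decay_step = step - (warmup_steps + stable_steps)
    return base_lr / math.sqrt(decay_step + 1)          # square-root decay


# e.g. wsd_sqrt_lr(0, 3e-6, 50, 500) ramps up, holds at 3e-6, then decays ~1/sqrt(t).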

Untangling “rollout” vs “rollout step”

One potentially confusing term in prime-rl is “rollout”:

  • Rollout – the file bundle under …/step_k/: ~batch_size × step_per_rollout prompt→completion rows, each already tagged with token-level rewards.
  • Optimizer step – one optim.step() over a mini-batch from that bundle.
  • Rollout step (k) – the sequence “load step_k → do step_per_rollout optimizer steps → save ckpt_rollout_k.safetensors”.

After step_per_rollout updates, the learner tosses the data, hands the fresh weights to the inference workers, and waits for /step_{k+1}/ to finish writing. This ratio is what keeps GPUs busy while still guaranteeing reasonably fresh data. If your config says:

optim.batch_size        = 256
optim.step_per_rollout  = 8

then each rollout must contain 256 × 8 = 2048 samples. The trainer reuses that table for 8 updates before it asks the inference side for new data—balancing data freshness with GPU utilisation.
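
In loop form, the relationship looks like this (the names below are purely illustrative, not the repo’s actual functions):

# Illustrative skeleton of how one rollout maps to optimizer steps.
batch_size = 256
step_per_rollout = 8
samples_per_rollout = batch_size * step_per_rollout   # 2048 rows under .../step_k/

for rollout_k in range(4):                             # a few rollouts
    # infer.py has already written samples_per_rollout rows to .../step_{rollout_k}/
    for update in range(step_per_rollout):
        # pull one mini-batch of `batch_size` rows from that bundle,
        # compute the GRPO loss, backprop, optimizer.step()
        pass
    # save ckpt_rollout_{rollout_k}.safetensors so inference reloads fresh weights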

The 60-second mental model

┌── inference workers (vLLM) ──────────────────┐
│ sample N completions per prompt              │
│ compute rewards ➜ write to step_k/*.parquet  │
└──────────────────────────────────────────────┘
                   │  (file + stable flag)
                   ▼
┌──── train.py ─────────────────────────────────────────────┐
│ stream Parquet ➜ recompute log-probs ➜ GRPO loss          │
│ micro-batch, grad-accum, clip, AdamW step                 │
│ every `step_per_rollout` steps:                           │
│   • save FSDP shards (recovery)                           │
│   • save safetensors weights (rollout checkpoint)         │
└───────────────────────────────────────────────────────────┘
                   │  (HTTP path broadcast by shardcast)
                   ▼
           inference workers reload weights … repeat

Training therefore looks like this (numbers are typical, not fixed):

rollout 0
  ├─ infer generates 2 048 (prompt, completion) pairs →  .../step_0/
  └─ trainer loads step_0 and does
       step 0.0  minibatch 256  → weight update
       step 0.1  minibatch 256  → weight update
       … repeat until step_per_rollout-1
       save ckpt_rollout_0
rollout 1
  ├─ infer sees new weights ➜ generates 2 048 fresh pairs → .../step_1/
  └─ trainer repeats the same routine on step_1
…

Once again, to summarise:

  1. Inference workers (any mix of TP / DP / PP) read the current model checkpoint, sample N completions, compute rewards and dump them to Parquet.
  2. When the last worker writes the stable sentinel file, training unblocks, streams those rows, recomputes any missing log-probs or KL reference terms, and performs one rollout (= step_per_rollout gradient steps).
  3. After the rollout it reshards / off-loads the fresh weights to a single safetensors file and broadcasts that path with shardcast; inference nodes pick it up and the loop continues.
  4. Periodically, checkpoint.py saves full FSDP shards so you can resume interrupted runs exactly.

Because both sides are fully decoupled (just file + HTTP hand-shakes), you can scale them independently: e.g. 1 trainer GPU with step_per_rollout=1 fed by 16 vLLM GPUs, or vice versa.
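
Here is a minimal sketch of the trainer side of that hand-shake, assuming the sentinel is a file named stable inside the step directory (the exact flag name and layout in data.py may differ):

# Trainer-side handshake sketch: block on the sentinel, then stream Parquet rows.
import time
from pathlib import Path

import pyarrow.dataset as ds


def wait_for_step(root: str, step: int, poll_s: float = 5.0) -> ds.Dataset:
    step_dir = Path(root) / f"step_{step}"
    while not (step_dir / "stable").exists():
        time.sleep(poll_s)                              # inference still writing
    files = sorted(str(p) for p in step_dir.glob("*.parquet"))
    return ds.dataset(files, format="parquet")          # lazy, streaming reader


# Usage: iterate record batches instead of materialising the whole table.
# for batch in wait_for_step("/rollouts", 0).to_batches(batch_size=256):
#     ...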

Patterns worth stealing

  • Pydantic for everything – one TOML file rebuilds the entire run; makes WandB sweeps trivial.
  • Filesystem hand-shake – Parquet + .stable sentinel = dead-simple, stateless coordination: no Redis, no Ray.
  • Two-tier checkpoints – heavy FSDP shards for crash-recovery, light safetensors for inference throughput.
  • GRPO variants side-by-side – you can benchmark clip vs ratio vs KL-cov just by flipping one enum.
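
For reference, the “clip” flavour boils down to a token-level, PPO-style clipped surrogate. The snippet below is my hedged paraphrase, not the exact code in loss.py (which has more masking and aggregation options):

# Hedged paraphrase of a token-level clipped GRPO surrogate.
import torch


def grpo_clip_loss(logp_new, logp_old, advantages, mask, eps: float = 0.2):
    """All tensors are [batch, seq_len]; mask is 1.0 on completion tokens."""
    ratio = torch.exp(logp_new - logp_old)              # pi_theta / pi_old per token
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    per_token = -torch.minimum(unclipped, clipped)      # pessimistic surrogate
    loss = (per_token * mask).sum() / mask.sum().clamp(min=1.0)
    clip_frac = ((ratio - 1).abs() > eps).float()[mask.bool()].mean()
    return loss, clip_frac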

Takeaways

  1. Declarative configs beat bash glue. With every arg living in Pydantic, spinning up a 2-GPU debug run or a 64-GPU sweep is little more than cp config/foo.toml config/bar.toml.
  2. File-based async RL is underrated. Parquet + safetensors + an HTTP heartbeat were easier to reason about (and debug) than Ray actors or a bespoke RPC framework.
  3. o3 is great at dissecting repos. Especially when the repo already follows clean, config-driven design.
  4. PRIME-RL’s separation of collect and learn is refreshingly minimal: if you understand the folder naming convention, you understand the pipeline.

Appendix · Full o3 file-by-file dump

Top-level package files

  • __init__.py – Empty placeholder – just marks zeroband as a Python package so you can do import zeroband…
  • train.py – Entry-point that launches the RL trainer. It
      • parses CLI via pydantic-config into a Config model;
      • builds/loads the model & tokenizer (get_model_and_tokenizer) and optionally patches Qwen-2 with Liger kernels;
      • wraps every transformer block with FSDP (fully_shard) for sharded training;
      • creates the optimiser/scheduler (AdamW + custom scheduler);
      • spins an async rollout loop (step_per_rollout) that
        – streams mini-batches from training.data.get_dataloader,
        – (optionally) recomputes token-wise log-probs for GRPO,
        – calls grpo_loss, entropy_loss, kl_penalty to get the final objective,
        – does grad-accum, clipping, optimiser step and LR step,
        – logs to W&B & an HTTP monitor, and
        – checkpoints every N rollouts, both for recovery (save_checkpoint_fsdp_state) and for rollout workers (save_ckpt_for_rollout). ([github.com][1])
  • infer.py – Companion script that generates experiences with vLLM and stores them as Parquet rows that the trainer later consumes. Highlights:
      • spins up a vLLM engine with configurable DP/TP/PP parallelism;
      • loads the dataset, optionally filters by prompt length & difficulty;
      • hands prompts to vLLM with the SamplingParams supplied in inference.config;
      • computes reward signals via inference.rewards.compute_rewards and, optionally, TOPLOC commitments over hidden-state activations (for verifiable inference);
      • writes a Parquet file for every sample plus a stable flag file so the trainer knows the step is finished;
      • optionally hot-reloads model weights from the most recent rollout checkpoint (for fully asynchronous RL). ([github.com][2])
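
To make that Parquet + stable-flag contract concrete, here is a writer-side sketch; the column names and the flag file name are assumptions for illustration, not the schema infer.py actually emits.

# Writer-side sketch of the file handshake (illustrative column/flag names).
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq


def dump_rollout_shard(step_dir: str, rank: int, rows: dict) -> None:
    # rows could look like {"input_ids": [...], "rewards": [...], "advantages": [...]}
    out = Path(step_dir)
    out.mkdir(parents=True, exist_ok=True)
    pq.write_table(pa.table(rows), out / f"rank_{rank}.parquet")


def mark_step_stable(step_dir: str) -> None:
    # The last worker drops the sentinel; the trainer's dataloader unblocks on it.
    (Path(step_dir) / "stable").touch()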

training sub-package

  • config.py – Declarative schema (Pydantic) for every trainer hyper-parameter: optimiser block, scheduler type, batch sizes, micro-batching, checkpoint cadence, GRPO variants, etc. The validators ensure things like “ckpt.interval must be a multiple of step_per_rollout”. ([github.com][3])
  • envs.py – Typed wrappers around environment variables (RANK, WORLD_SIZE, TRAINING_ENABLE_ACCEPTED_CHECK, …). Provides get_env_value helpers so the rest of the code can use plain attribute reads. ([github.com][4])
  • world_info.py – Convenience singleton that captures distributed topology (rank, local-rank, #GPUs/node, etc.) and is reused everywhere so you never touch os.environ in user code. ([github.com][5])
  • data.py – The data loader:
      • watches a directory hierarchy …/step_k/*.parquet;
      • blocks until the stable flag appears, then uses pyarrow.dataset to stream rows;
      • shards rows across ranks & workers (_should_skip_index);
      • can fall back to a synthetic FakeTokenizedDataset for debugging.
    The output dictionary already contains advantages, token-wise log-probs, rewards, task ids, etc., so the trainer can compute losses without touching the original dataset again. ([github.com][6])
  • loss.py – Implements three flavours of GRPO (clip, ratio, KL-covariance). All versions share helpers: selective_log_softmax, highest_entropy_mask, _apply_mask. Returns (loss, clip_ratio?) so the trainer can log PPO-style stats. ([github.com][7])
  • lr_scheduler.py – Thin wrapper that exposes cosine, linear, and a custom “WSD-sqrt” schedule (warm-up → stable → √-decay). Selects the right callable from SCHED_MAP based on the config. ([github.com][8])
  • checkpoint.py – Two levels of persistence (see the sketch below):
      1. Full FSDP shards (save_checkpoint_fsdp_state) for exact recovery;
      2. Rollout safetensors (save_ckpt_for_rollout) – a single-file, CPU-off-loaded copy of the model that inference workers can download via shardcast. ([github.com][9])
  • utils.py – Mixed bag of trainer helpers:
      • apply_ac_ckpt turns on PyTorch activation checkpointing every n layers;
      • GPU/TPU FLOP calculators → PerfCounter exposes MFU & tokens/s;
      • MetricsAverager that syncs per-GPU statistics;
      • a small random-port helper for spawning vLLM servers while avoiding conflicts. ([github.com][10])

(Other small modules in the folder – data_prefetch.py, mp.py, etc. – are thin wrappers around multiprocessing or GCP pre-fetching and don’t hold core logic.)
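
Here is a hedged sketch of the two checkpoint tiers, assuming a recent PyTorch (torch.distributed.checkpoint) and the safetensors package; the helper names and details below are mine, not checkpoint.py’s.

# Two-tier checkpoint sketch (my helper names, not checkpoint.py's).
import torch.distributed.checkpoint as dcp
from safetensors.torch import save_file


def save_recovery_shards(model, optimizer, path: str) -> None:
    # Tier 1: sharded distributed checkpoint for exact resume (model + optimizer).
    # Real FSDP code typically goes through torch.distributed.checkpoint.state_dict helpers.
    dcp.save({"model": model.state_dict(), "optim": optimizer.state_dict()},
             checkpoint_id=path)


def save_rollout_weights(model, path: str) -> None:
    # Tier 2: one CPU-offloaded safetensors file for inference workers.
    # With FSDP you would first gather a full (unsharded) state dict.
    state = {k: v.detach().to("cpu").contiguous() for k, v in model.state_dict().items()}
    save_file(state, path)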

inference sub-package

  • config.py – Pydantic schema for everything the inference node needs: model name, parallelism sizes, sampling params, reward toggles, etc.
  • envs.py – Same pattern as the training side but with NODE_ADDRESS, PP.RANK, …
  • pipeline.py – Registers PP nodes with prime-iroh so shards can stream tensors peer-to-peer.
  • parquet.py – Turns lists of generated samples into a columnar Parquet table that matches the schema expected by training.data.
  • toploc.py – Builds TOPLOC commitments from hidden-state activations so generated completions can be verified (verifiable inference).
  • rewards.py – Houses reward functions: correctness, length penalty, difficulty buckets, etc. (toy example below).
  • utils.py – Token-based helpers (fake_chat_template, prompt length filtering, etc.).

(These files follow the same design philosophy as the trainer: pure functions + Pydantic configs so you can swap any component out in your own fork.)
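
As a flavour of what rewards.py-style functions compute, here is a toy correctness + length-penalty reward; the signature and weights are made up, purely to illustrate the shape of such a function.

# Toy reward combining correctness with a length penalty (made-up weights/signature).
def toy_reward(completion: str, target: str, max_len: int = 1024,
               length_weight: float = 0.1) -> float:
    correct = float(completion.strip() == target.strip())
    overflow = max(0, len(completion) - max_len) / max_len
    return correct - length_weight * overflow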

utils – shared helpers

A few highlights that are imported by both trainer & inference:

  • utils.logger – colourised, rank-aware logging.
  • utils.models – centralises “model zoo” logic (Qwen-2 vs LLaMA-2, Flash-Attn vs Torch-Attn, parameter counting, etc.).
  • utils.metrics – lightweight JSONL logger for PrimeIntellect’s internal dashboard.
  • utils.http_monitor – pushes selected metrics to a REST endpoint so you can watch runs without WandB.