
feat: add experimental native RL stack and arithmetic validation benchmark#6

Open
PastaPastaPasta wants to merge 27 commits into ARahim3:main from PastaPastaPasta:codex/rl-reference-grpo

Conversation

@PastaPastaPasta

Summary

This branch is the full native RL bring-up compared to main.

It adds the internal RL runtime, RL model-role plumbing, checkpoint/resume support, public RL API surface, multiple RL trainers/configs, a TRL-compat patch layer, and an experimental arithmetic GRPO validation benchmark for Qwen 3.

This is still very experimental: absolutely prototype-quality, heavily vibe-coded (GPT 5.4), and not tested to the level this much surface area would normally require.

The goal was to keep API compatibility with unsloth; I have not manually verified that to be the case.

What changed

  • adds native RL runtime infrastructure in mlx_tune/_rl_runtime.py
  • adds RL model-role builders, reward/value/reference helpers, and checkpoint role handling
  • adds public RL API helpers in mlx_tune/rl_api.py
  • extends exports in mlx_tune/__init__.py
  • adds or significantly expands trainers/configs for:
    • Reward modeling
    • DPO
    • ORPO
    • GRPO
    • PPO
    • OnlineDPO
    • KTO
    • SimPO
  • adds TRL compatibility patching in mlx_tune/trl_compat.py
  • updates examples, including a new Qwen 3 arithmetic GRPO validation benchmark
  • adds targeted RL/runtime/model-role/integration tests

Why

The goal of this branch is to get a real native RL training stack into the repo and make it testable with a deterministic benchmark.

The arithmetic benchmark is there mainly as a validation harness:

  • easy to generate
  • easy to score
  • no judge model
  • deterministic reward
  • useful for checking whether GRPO is actually moving policy behavior

In a local proof run with the new arithmetic benchmark, held-out exact match and solution-tag adherence improved materially after GRPO. That is encouraging, but it should be treated as an experiment, not proof that the broader stack is production-ready.

Important caveats

This PR is large and risky.

  • It changes a lot of training/runtime surface area at once.
  • The implementation is experimental and was put together quickly.
  • A meaningful amount of it is vibe coded.
  • There are tests, but relative to the size of this change the testing is still minimal.
  • I would not treat the current APIs or behaviors as stable.
  • I would expect follow-up fixes, edge cases, and cleanup.

Suggested reviewer mindset

Please review this as:

  • a large experimental RL branch
  • an attempt to get end-to-end functionality working
  • not a polished or production-ready training subsystem

I would focus on:

  • correctness of trainer semantics
  • checkpoint/resume behavior
  • reward/data plumbing
  • rollout/runtime edge cases
  • API shape and maintenance risk
  • obvious regressions against existing training flows

Validation performed

Locally, I ran targeted RL tests plus a real arithmetic GRPO proof run:

  • baseline eval
  • GRPO training
  • post-RL eval and comparison

That gives some confidence the path is live, but not nearly enough confidence for the full scope of this branch.


How To Try It

If you want to sanity check that the RL path actually works, the easiest entrypoint is the arithmetic GRPO benchmark.

1. Generate a deterministic dataset

python examples/10_qwen3_arithmetic_grpo_validation.py generate \
  --output-dir /tmp/qwen3_arith_demo \
  --train-size 256 \
  --val-size 32 \
  --test-size 32 \
  --force-generate

2. Measure the zero-shot baseline

python examples/10_qwen3_arithmetic_grpo_validation.py baseline \
  --output-dir /tmp/qwen3_arith_demo \
  --model-name mlx-community/Qwen3-1.7B-4bit \
  --max-completion-length 128 \
  --max-seq-length 384

This writes:

  • /tmp/qwen3_arith_demo/baseline_outputs.jsonl
  • /tmp/qwen3_arith_demo/baseline_metrics.json

3. Run a small GRPO training pass

python examples/10_qwen3_arithmetic_grpo_validation.py train \
  --output-dir /tmp/qwen3_arith_demo \
  --model-name mlx-community/Qwen3-1.7B-4bit \
  --max-completion-length 128 \
  --max-seq-length 384 \
  --max-steps 30 \
  --per-device-train-batch-size 2 \
  --rollout-batch-size 2 \
  --num-generations 2 \
  --logging-steps 5 \
  --eval-steps 10 \
  --save-steps 10

This writes:

  • /tmp/qwen3_arith_demo/rl_training_summary.json
  • /tmp/qwen3_arith_demo/post_rl_outputs.jsonl
  • /tmp/qwen3_arith_demo/post_rl_metrics.json

4. Compare before vs after RL

python examples/10_qwen3_arithmetic_grpo_validation.py compare \
  --output-dir /tmp/qwen3_arith_demo

This writes:

  • /tmp/qwen3_arith_demo/comparison.json
  • /tmp/qwen3_arith_demo/comparison.md
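The comparison step boils down to a per-metric delta between the two eval runs. The sketch below illustrates that, using metric names mentioned elsewhere in this PR; the actual `comparison.json` schema is an assumption, and the numbers are placeholders, not results from the proof run:

```python
# Illustrative before/after delta, in the spirit of the `compare` step.
# The schema and numbers here are assumptions, not the real file contents.
def metric_deltas(baseline: dict, post_rl: dict) -> dict:
    """Post-RL minus baseline for every metric present in both runs."""
    return {k: post_rl[k] - baseline[k] for k in baseline if k in post_rl}

# Placeholder values for demonstration only.
baseline = {"exact_match": 0.05, "solution_tag_rate": 0.40, "avg_reward": 0.10}
post_rl = {"exact_match": 0.60, "solution_tag_rate": 0.95, "avg_reward": 0.65}
deltas = metric_deltas(baseline, post_rl)
```

Positive deltas across the board are the signal that the RL loop actually moved behavior.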

What to expect

The benchmark is intentionally simple:

  • the model can emit whatever it wants in <think>
  • only the integer inside <solution>...</solution> is scored
  • reward is deterministic

In my local proof run on this branch:

  • baseline exact match was very low
  • post-GRPO exact match and solution-tag adherence improved materially

So if the stack is functioning, you should usually see:

  • higher solution_tag_rate
  • higher avg_reward
  • often higher exact_match

Important caveats

This is only a smoke/proof path. It does not prove the whole RL stack is correct or stable; it is just the fastest way for a maintainer to watch an actual RL loop move model behavior on this branch.


ARahim3 commented Mar 7, 2026

Hi @PastaPastaPasta ,
Thanks a lot for the contribution - really appreciate the effort here. This is a substantial change, so I’m going to take some time to review it carefully before deciding on the scope and next steps.

@PastaPastaPasta
Author

Totally agree. The reason I built this is that I wanted to play with RL on my Mac but never could. I'm also going to be testing it in a not-so-toy project.

Figured it was better to open a PR than leave it sitting in my fork.

No pressure on review, but if you find things I'm happy to resolve them.

