
fix: account for multi-turn chunks in agentic LR scheduler budget (#407) #439

Open

dashitongzhi wants to merge 1 commit into alibaba:main from dashitongzhi:fix/agentic-lr-scheduler-exhaustion

Conversation

@dashitongzhi

Problem

When using AgentNativeStepEnvManager for step-level agentic training, the LR scheduler exhausts its step budget long before all pipeline steps complete, causing the learning rate to drop to zero mid-training.

Root cause: PPOConfig.set_max_steps() computes total optimizer steps based on rollout_batch_size (number of trajectories), but AgentNativeStepEnvManager.formulate_rollouts() creates one training sample per turn — so the actual number of optimizer steps per pipeline step is much higher than budgeted.

Example from the field: 4 trajectories × ~10 turns each ≈ 40 training samples per pipeline step. With backward_batch_size=4, that's ~10 optimizer steps per pipeline step, not the 1 the scheduler was budgeted for. In a 200-step run, the LR hit zero at step 123, leaving 38.5% of training with zero LR.
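
To make the mismatch concrete, here is a minimal standalone sketch of the arithmetic (the variable names are illustrative, not the pipeline's actual config fields):

# Illustrative numbers from the example above; names are hypothetical.
rollout_batch_size = 4   # trajectories per pipeline step
turns_per_traj = 10      # training samples per trajectory (one per turn)
backward_batch_size = 4

# What the base PPOConfig budgets: trajectories only.
budgeted = rollout_batch_size // backward_batch_size                  # 1 optimizer step

# What actually happens when each turn becomes a training sample.
actual = rollout_batch_size * turns_per_traj // backward_batch_size   # 10 optimizer steps

print(f"budgeted={budgeted} vs actual={actual} optimizer steps per pipeline step")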

Fixes #407.

Solution

This PR overrides set_max_steps in AgenticConfig to account for multi-turn chunking:

  1. New config field: estimated_chunks_per_traj (default 0 = auto-detect)
  2. Auto-detection: When not explicitly set, estimates from custom_envs.max_steps using max_steps // 2 as a conservative midpoint
  3. Explicit override: Users can set estimated_chunks_per_traj for precise control
  4. Backward compatible: When max_steps <= 1 in env configs (single-turn), falls back to a multiplier of 1, identical to the parent behavior

How it works

# In AgenticConfig.set_max_steps():
# Estimate how many training samples (chunks) each trajectory yields.
chunks_multiplier = self._get_chunks_per_traj_estimate()

# Scale the optimizer-step budget by the chunk multiplier (floor at 1).
self.actor_train.training_args.max_steps = max(1, (
    max_steps * rollout_batch_size * num_return_sequences
    * ppo_epochs * chunks_multiplier // actor_backward_batch_size
))
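
With chunks_multiplier = 1 (the single-turn fallback), this expression reduces to the parent PPOConfig computation, which is what makes the change backward compatible.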

User-facing config

# Auto-detect (default) — reads max_steps from custom_envs
estimated_chunks_per_traj: 0

# Explicit override for precise control
estimated_chunks_per_traj: 10
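
As a rough guide: if your env configs cap trajectories at max_steps: 20 but episodes typically finish in about 10 turns, the auto-detected value (20 // 2 = 10) will be close; if your average turn count differs substantially from half the cap, set the field explicitly.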

Files Changed

  • roll/pipeline/agentic/agentic_config.py — Added estimated_chunks_per_traj field, _get_chunks_per_traj_estimate() method, and set_max_steps() override

Testing

The fix is backward-compatible: for single-turn environments (max_steps=1), the auto-detect returns 1 and behavior is identical to the parent class.
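
One way to exercise both paths is to check the auto-detect heuristic directly. The sketch below re-implements that logic as a standalone function mirroring _get_chunks_per_traj_estimate() (see the diff excerpt in the review below); EnvCfg and the sample values are hypothetical:

from dataclasses import dataclass
from typing import Optional

@dataclass
class EnvCfg:
    max_steps: Optional[int] = None

def estimate_chunks(custom_envs: dict) -> int:
    # Mirrors _get_chunks_per_traj_estimate(): scan env configs for max_steps.
    max_env_steps = 0
    for _tag, cfg in custom_envs.items():
        if hasattr(cfg, "max_steps") and cfg.max_steps is not None:
            max_env_steps = max(max_env_steps, int(cfg.max_steps))
        elif isinstance(cfg, dict) and "max_steps" in cfg:
            max_env_steps = max(max_env_steps, int(cfg["max_steps"]))
    if max_env_steps > 1:
        return max(1, max_env_steps // 2)  # conservative midpoint
    return 1

assert estimate_chunks({"single": EnvCfg(max_steps=1)}) == 1    # parent behavior
assert estimate_chunks({"multi": EnvCfg(max_steps=20)}) == 10   # 20 // 2
assert estimate_chunks({"multi": {"max_steps": 8}}) == 4        # dict-style config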

When using AgentNativeStepEnvManager for step-level agentic training,
each trajectory produces multiple training samples (one per agent turn).
The base PPOConfig.set_max_steps() only accounts for rollout_batch_size
(number of trajectories), causing the LR scheduler to exhaust its step
budget long before all pipeline steps complete.

For example, 4 trajectories × ~10 turns each ≈ 40 training samples per
pipeline step. With backward_batch_size=4, that's ~10 optimizer steps
per pipeline step instead of the budgeted 1.

This fix:
- Adds estimated_chunks_per_traj config field (default 0 = auto-detect)
- Overrides set_max_steps in AgenticConfig to multiply optimizer step
  budget by the chunks-per-trajectory estimate
- Auto-detects the estimate from custom_envs max_steps (conservative:
  max_steps // 2) when not explicitly configured
- Users can override via estimated_chunks_per_traj for precise control

Fixes alibaba#407
Copilot AI review requested due to automatic review settings May 8, 2026 15:29

Copilot AI left a comment


Pull request overview

This PR fixes premature LR-scheduler exhaustion during step-level agentic training by making AgenticConfig.set_max_steps() account for the fact that some env managers generate one training sample per turn (chunk) rather than per trajectory.

Changes:

  • Added estimated_chunks_per_traj to let users explicitly scale the optimizer-step budget for multi-turn (chunked) rollouts.
  • Implemented _get_chunks_per_traj_estimate() to auto-estimate the chunk multiplier from custom_envs.
  • Overrode set_max_steps() in AgenticConfig to multiply actor/critic training_args.max_steps by the chunk multiplier.


"When 0 (default), auto-detected from custom_envs max_steps (conservative: max_steps // 2)."
},
)

Comment on lines +383 to +401

# Auto-detect from custom_envs: look for max_steps in env configs
max_env_steps = 0
if self.custom_envs:
    for tag, cfg in self.custom_envs.items():
        if hasattr(cfg, 'max_steps') and cfg.max_steps is not None:
            max_env_steps = max(max_env_steps, int(cfg.max_steps))
        elif isinstance(cfg, dict) and 'max_steps' in cfg:
            max_env_steps = max(max_env_steps, int(cfg['max_steps']))

if max_env_steps > 1:
    # Conservative estimate: half of max_steps as average turns per trajectory
    estimate = max(1, max_env_steps // 2)
    logger.info(
        f"Auto-detected estimated_chunks_per_traj={estimate} "
        f"(from custom_envs max_steps={max_env_steps}, using max_steps // 2)"
    )
    return estimate

return 1
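
Halving the cap reflects that trajectories typically terminate before reaching the env's max_steps limit, so using the cap itself would tend to overestimate the average number of chunks; users who know their typical turn count can bypass the heuristic entirely via estimated_chunks_per_traj.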
@CLAassistant

CLAassistant commented May 8, 2026

CLA assistant check
All committers have signed the CLA.


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

LR scheduler exhausts early in agentic training with AgentNativeStepEnvManager

3 participants