87 changes: 87 additions & 0 deletions roll/pipeline/agentic/agentic_config.py
@@ -200,6 +200,16 @@ class AgenticConfig(PPOConfig):
open_feedback_turn: bool = field(default=False, metadata={"help": "open feedback turn"})
use_token_reward: bool = field(default=False, metadata={"help": "use token reward"})

estimated_chunks_per_traj: int = field(
default=0,
metadata={
"help": "Estimated average number of training chunks (turns) per trajectory for step-level "
"agentic training (e.g., AgentNativeStepEnvManager). When > 0, the LR scheduler "
"budget is multiplied by this factor to prevent premature exhaustion. "
"When 0 (default), auto-detected from custom_envs max_steps (conservative: max_steps // 2)."
},
)
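    # Hypothetical example: an env rollout that averages ~5 agent turns per trajectory
    # emits ~5 training chunks per trajectory, so the LR scheduler budget computed in
    # set_max_steps() below is scaled by 5 rather than treating each trajectory as one sample.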

    batch_adjust_mode: Literal["copy", "delete", "auto", "random_sample"] = field(
        default="copy", metadata={"help": "batch adjust mode: copy, delete, auto, or random_sample"}
    )
@@ -357,6 +367,83 @@ def __post_init__(self):
# Apply OPD configuration at the end (handles student_train/student_infer/teacher mapping)
self._apply_opd_config()

def _get_chunks_per_traj_estimate(self) -> int:
"""Estimate the average number of training chunks per trajectory.

For step-level env managers (e.g., AgentNativeStepEnvManager), each trajectory
        produces multiple training samples (one per turn/step). This method returns the
        explicit override when set, otherwise estimates the multiplier from the
        custom_envs config.

Returns:
int: estimated chunks per trajectory (minimum 1)
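
        Illustrative outcomes (hypothetical config values):
            estimated_chunks_per_traj=4 set explicitly -> returns 4 (override wins)
            custom_envs entry with max_steps=10        -> returns 5 (max_steps // 2)
            no max_steps found in any env config       -> returns 1 (fallback)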
"""
if self.estimated_chunks_per_traj > 0:
return self.estimated_chunks_per_traj

# Auto-detect from custom_envs: look for max_steps in env configs
max_env_steps = 0
if self.custom_envs:
            for cfg in self.custom_envs.values():
                if hasattr(cfg, "max_steps") and cfg.max_steps is not None:
                    max_env_steps = max(max_env_steps, int(cfg.max_steps))
                elif isinstance(cfg, dict) and "max_steps" in cfg:
                    max_env_steps = max(max_env_steps, int(cfg["max_steps"]))

if max_env_steps > 1:
# Conservative estimate: half of max_steps as average turns per trajectory
estimate = max(1, max_env_steps // 2)
logger.info(
f"Auto-detected estimated_chunks_per_traj={estimate} "
f"(from custom_envs max_steps={max_env_steps}, using max_steps // 2)"
)
return estimate

return 1

def set_max_steps(self, max_steps: int):
"""Override to account for multi-turn chunking in agentic training.

For step-level env managers (AgentNativeStepEnvManager), each trajectory
produces multiple training samples (chunks), one per agent turn. The parent
        implementation budgets only for rollout_batch_size (trajectories), so the LR
        scheduler exhausts its step budget prematurely when chunks >> trajectories.

This override multiplies the optimizer step budget by an estimate of chunks
per trajectory to align the scheduler with actual training dynamics.
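
        Worked example (hypothetical values): with max_steps=100, rollout_batch_size=32,
        num_return_sequences=1, ppo_epochs=1, a chunks multiplier of 5, and an actor
        backward batch size of 32, the actor budget becomes
        max(1, 100 * 32 * 1 * 1 * 5 // 32) = 500 optimizer steps instead of 100.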
"""
chunks_multiplier = self._get_chunks_per_traj_estimate()

actor_backward_batch_size = (
self.actor_train.training_args.per_device_train_batch_size
* self.actor_train.training_args.gradient_accumulation_steps
)
critic_backward_batch_size = (
self.critic.training_args.per_device_train_batch_size
* self.critic.training_args.gradient_accumulation_steps
)

self.actor_train.training_args.max_steps = max(1, (
max_steps
* self.rollout_batch_size
* self.actor_infer.generating_args.num_return_sequences
* self.ppo_epochs
* chunks_multiplier
// actor_backward_batch_size
))
self.critic.training_args.max_steps = max(1, (
max_steps
* self.rollout_batch_size
* self.actor_infer.generating_args.num_return_sequences
* chunks_multiplier
// critic_backward_batch_size
))

logger.info(f"pipeline max_steps: {self.max_steps} to {max_steps}")
logger.info(f"chunks_per_traj_multiplier: {chunks_multiplier}")
logger.info(f"actor train max_steps without dp_size: {self.actor_train.training_args.max_steps}")
logger.info(f"critic train max_steps without dp_size: {self.critic.training_args.max_steps}")
self.max_steps = max_steps

def make_env_configs(self, env_manager_config: EnvManagerConfig):
# construct env configs
env_configs = defaultdict(defaultdict)