# Excerpt from roll/pipeline/agentic/agentic_config.py.
#
# Everything below is a member of ``class AgenticConfig(PPOConfig)``; it is
# shown dedented to module level so the excerpt reads (and checks) on its own.
# In the class, the two methods sit between ``__post_init__`` and
# ``make_env_configs``, neither of which is fully visible in this excerpt.

open_feedback_turn: bool = field(default=False, metadata={"help": "open feedback turn"})
use_token_reward: bool = field(default=False, metadata={"help": "use token reward"})

# Explicit override for the chunks-per-trajectory multiplier consumed by
# ``set_max_steps``; 0 means "derive it from custom_envs".
estimated_chunks_per_traj: int = field(
    default=0,
    metadata={
        "help": "Estimated average number of training chunks (turns) per trajectory for step-level "
        "agentic training (e.g., AgentNativeStepEnvManager). When > 0, the LR scheduler "
        "budget is multiplied by this factor to prevent premature exhaustion. "
        "When 0 (default), auto-detected from custom_envs max_steps (conservative: max_steps // 2)."
    },
)

batch_adjust_mode: Literal["copy", "delete", "auto", "random_sample"] = field(
    default="copy", metadata={"help": "batch adjust mode: copy or delete"}
)


def _get_chunks_per_traj_estimate(self) -> int:
    """Estimate the average number of training chunks per trajectory.

    For step-level env managers (e.g., AgentNativeStepEnvManager), each
    trajectory produces multiple training samples (one per turn/step). The
    estimate is taken from, in order of precedence:

    1. ``self.estimated_chunks_per_traj`` when it is greater than 0.
    2. Half of the largest ``max_steps`` found in ``self.custom_envs``
       (entries may expose it as an attribute or as a mapping key).
    3. ``1`` when nothing usable is configured.

    Entries whose ``max_steps`` is missing or ``None`` are ignored, whether
    the entry is attribute-style or a plain ``dict``.

    Returns:
        int: estimated chunks per trajectory (minimum 1).
    """
    if self.estimated_chunks_per_traj > 0:
        return self.estimated_chunks_per_traj

    # Auto-detect from custom_envs: take the largest max_steps of any env.
    max_env_steps = 0
    for env_cfg in (self.custom_envs or {}).values():
        step_limit = getattr(env_cfg, "max_steps", None)
        if step_limit is None and isinstance(env_cfg, dict):
            # Plain dicts expose max_steps as a key; .get() also tolerates
            # an explicit ``max_steps: None`` instead of crashing in int().
            step_limit = env_cfg.get("max_steps")
        if step_limit is None:
            continue
        max_env_steps = max(max_env_steps, int(step_limit))

    if max_env_steps <= 1:
        return 1

    # Conservative estimate: half of max_steps as average turns per
    # trajectory. max_env_steps >= 2 here, so the result is always >= 1.
    estimate = max_env_steps // 2
    logger.info(
        f"Auto-detected estimated_chunks_per_traj={estimate} "
        f"(from custom_envs max_steps={max_env_steps}, using max_steps // 2)"
    )
    return estimate


def set_max_steps(self, max_steps: int) -> None:
    """Set the pipeline step count and derive actor/critic optimizer budgets.

    For step-level env managers (AgentNativeStepEnvManager), each trajectory
    produces multiple training samples (chunks), one per agent turn. Counting
    only ``rollout_batch_size`` trajectories would make the LR scheduler run
    out early when chunks >> trajectories, so the sample count is scaled by
    ``self._get_chunks_per_traj_estimate()``.

    Writes ``actor_train.training_args.max_steps``,
    ``critic.training_args.max_steps`` and ``self.max_steps``. Both optimizer
    budgets are floored at 1; as the log messages note, they are not divided
    by the data-parallel size here.

    Args:
        max_steps: number of pipeline (rollout) steps to schedule for.
    """
    chunks_multiplier = self._get_chunks_per_traj_estimate()

    actor_args = self.actor_train.training_args
    critic_args = self.critic.training_args

    # Samples consumed by one optimizer step on a single device.
    actor_backward_batch_size = (
        actor_args.per_device_train_batch_size * actor_args.gradient_accumulation_steps
    )
    critic_backward_batch_size = (
        critic_args.per_device_train_batch_size * critic_args.gradient_accumulation_steps
    )

    # Estimated training samples over the whole run, for a single PPO epoch.
    total_samples = (
        max_steps
        * self.rollout_batch_size
        * self.actor_infer.generating_args.num_return_sequences
        * chunks_multiplier
    )

    # Only the actor budget is scaled by ppo_epochs.
    actor_args.max_steps = max(1, total_samples * self.ppo_epochs // actor_backward_batch_size)
    critic_args.max_steps = max(1, total_samples // critic_backward_batch_size)

    logger.info(f"pipeline max_steps: {self.max_steps} to {max_steps}")
    logger.info(f"chunks_per_traj_multiplier: {chunks_multiplier}")
    logger.info(f"actor train max_steps without dp_size: {self.actor_train.training_args.max_steps}")
    logger.info(f"critic train max_steps without dp_size: {self.critic.training_args.max_steps}")
    self.max_steps = max_steps