Hey,
I wonder how to approach debugging RL training (rewards, etc.). It's nice when things are already set up, as in the example notebooks:
```python
import functools
from datetime import datetime

import matplotlib.pyplot as plt
from IPython.display import clear_output, display
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo
from mujoco_playground.config import dm_control_suite_params

env_name = "CartpoleBalance"
ppo_params = dm_control_suite_params.brax_ppo_config(env_name)
sac_params = dm_control_suite_params.brax_sac_config(env_name)

# PPO:
x_data, y_data, y_dataerr = [], [], []
times = [datetime.now()]


def progress(num_steps, metrics):
  # Live-plot the evaluation reward each time the trainer invokes the callback.
  clear_output(wait=True)
  times.append(datetime.now())
  x_data.append(num_steps)
  y_data.append(metrics["eval/episode_reward"])
  y_dataerr.append(metrics["eval/episode_reward_std"])
  plt.xlim([0, ppo_params["num_timesteps"] * 1.25])
  plt.ylim([0, 1100])
  plt.xlabel("# environment steps")
  plt.ylabel("reward per episode")
  plt.title(f"y={y_data[-1]:.3f}")
  plt.errorbar(x_data, y_data, yerr=y_dataerr, color="blue")
  display(plt.gcf())


ppo_training_params = dict(ppo_params)
network_factory = ppo_networks.make_ppo_networks
if "network_factory" in ppo_params:
  del ppo_training_params["network_factory"]
  network_factory = functools.partial(
      ppo_networks.make_ppo_networks,
      **ppo_params.network_factory
  )

train_fn = functools.partial(
    ppo.train, **ppo_training_params,
    network_factory=network_factory,
    progress_fn=progress
)
```
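For completeness, I then launch training roughly like the notebook does (from memory, so the exact call may differ slightly):

```python
from mujoco_playground import registry

env = registry.load(env_name)  # load the playground environment
make_inference_fn, params, metrics = train_fn(environment=env)
```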
However, no synchronous visualization is available out of the box. There are ways to store the training data and visualize it via rscope, but I am running into blocking issues (it seems to not be actively maintained). Plotting all of the other training metrics is also cumbersome.
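As a stopgap I am appending every metric to a CSV from the progress callback and plotting it from a separate process, roughly like this sketch (the path is a placeholder, and I simply write whatever keys the metrics dict happens to contain):

```python
import csv
import os

LOG_PATH = "ppo_metrics.csv"  # placeholder path


def logging_progress(num_steps, metrics):
  # Append one row per callback so a separate process can plot the file
  # while training keeps running (nothing blocks the training loop).
  write_header = not os.path.exists(LOG_PATH)
  with open(LOG_PATH, "a", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["num_steps", *sorted(metrics)])
    if write_header:
      writer.writeheader()
    writer.writerow(
        {"num_steps": num_steps, **{k: float(v) for k, v in metrics.items()}})
```

This works, but it feels like reinventing the wheel.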
Do you have recommendations on how to approach debugging RL for new environments? In particular: how to get synchronous visualization and data plotting, and how to set the training parameters so that progress_fn is called more often (my current guess is below). More generally, what aspects do you look at when debugging new learning environments?
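On the progress_fn frequency: my guess from the brax ppo.train signature is that num_evals controls how often the callback fires, so something like the following (untested, value picked arbitrarily) should give denser learning curves:

```python
# Assumption: ppo.train invokes progress_fn once per evaluation,
# so raising num_evals should make the callback fire more often.
ppo_training_params["num_evals"] = 50  # arbitrary; the config default is lower

train_fn = functools.partial(
    ppo.train, **ppo_training_params,
    network_factory=network_factory,
    progress_fn=progress,
)
```

Is that the intended knob, or is there a better mechanism?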