Skip to content

Commit d9836a6

Browse files
committed
unify nemo gym interaface
Signed-off-by: ruit <ruit@nvidia.com>
1 parent c0b8cde commit d9836a6

File tree

1 file changed

+28
-45
lines changed

1 file changed

+28
-45
lines changed

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 28 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313
# limitations under the License.
1414

1515
import argparse
16-
import json
1716
import os
1817
import pprint
19-
from itertools import chain, repeat
2018
from typing import Dict, Optional
2119

2220
# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB)
@@ -45,14 +43,13 @@
4543
from nemo_rl.algorithms.utils import get_tokenizer
4644
from nemo_rl.data.datasets import (
4745
AllTaskProcessedDataset,
46+
extract_necessary_env_names,
4847
load_response_dataset,
4948
update_single_dataset_config,
5049
)
51-
from nemo_rl.data.interfaces import DatumSpec
5250
from nemo_rl.distributed.virtual_cluster import init_ray
5351
from nemo_rl.environments.nemo_gym import (
5452
NemoGymConfig,
55-
nemo_gym_example_to_nemo_rl_datum_spec,
5653
setup_nemo_gym_config,
5754
)
5855
from nemo_rl.environments.utils import create_env
@@ -77,40 +74,6 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
7774
return args, overrides
7875

7976

80-
def setup_single_nemo_gym_dataset(
81-
jsonl_fpath: str, tokenizer, num_repeats: Optional[int] = None
82-
):
83-
with open(jsonl_fpath) as f:
84-
nemo_gym_examples = list(map(json.loads, f))
85-
86-
print(f"Loaded data at {jsonl_fpath}. Found {len(nemo_gym_examples)} examples")
87-
88-
if num_repeats:
89-
previous_length = len(nemo_gym_examples)
90-
nemo_gym_examples = list(
91-
chain.from_iterable(
92-
repeat(nemo_gym_example, num_repeats)
93-
for nemo_gym_example in nemo_gym_examples
94-
)
95-
)
96-
print(
97-
f"Repeating examples (in a pattern of abc to aabbcc) for {jsonl_fpath} from {previous_length} to {len(nemo_gym_examples)}!"
98-
)
99-
100-
nemo_rl_compatible_examples: list[DatumSpec] = [
101-
nemo_gym_example_to_nemo_rl_datum_spec(nemo_gym_example, idx)
102-
for idx, nemo_gym_example in enumerate(nemo_gym_examples)
103-
]
104-
105-
passthrough_task_processor = lambda datum_dict, *args, **kwargs: datum_dict
106-
return AllTaskProcessedDataset(
107-
nemo_rl_compatible_examples,
108-
tokenizer,
109-
None,
110-
passthrough_task_processor,
111-
)
112-
113-
11477
def setup_data(
11578
tokenizer: TokenizerType,
11679
data_config: Dict,
@@ -122,18 +85,31 @@ def setup_data(
12285
dict[str, EnvironmentInterface],
12386
dict[str, EnvironmentInterface],
12487
]:
88+
print("\n▶ Setting up envs...")
89+
env_name_list = extract_necessary_env_names(data_config)
90+
envs = {
91+
env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
92+
for env_name in env_name_list
93+
if env_name != "nemo_gym"
94+
}
12595
print("\n▶ Setting up data...")
12696
# setup train dataset
127-
data_list = []
12897
task_data_processors = {}
98+
task_to_env = {}
99+
data_list = []
129100

130101
if isinstance(data_config["train"], dict):
131102
data_config["train"] = [data_config["train"]]
132103
for cfg in data_config["train"]:
133104
update_single_dataset_config(cfg, data_config["default"])
134105
data = load_response_dataset(cfg, seed)
135106
data_list.append(data)
136-
task_data_processors[data.task_name] = (data.task_spec, data.processor)
107+
# bind task_name to task_data_processors and task_to_env
108+
task_name = data.task_name
109+
task_data_processors[task_name] = (data.task_spec, data.processor)
110+
# Skip binding nemo_gym env to task_to_env, nemo_gym env need to initialize policy first
111+
if cfg["env_name"] != "nemo_gym":
112+
task_to_env[task_name] = envs[cfg["env_name"]]
137113

138114
merged_data = concatenate_datasets([data.dataset for data in data_list])
139115
dataset = AllTaskProcessedDataset(
@@ -147,6 +123,7 @@ def setup_data(
147123

148124
# setup validation dataset
149125
val_task_data_processors = {}
126+
val_task_to_env = {}
150127
val_data_list = []
151128

152129
for data in data_list:
@@ -155,6 +132,8 @@ def setup_data(
155132
# bind task_name to task_data_processors
156133
task_name = data.task_name
157134
val_task_data_processors[task_name] = task_data_processors[task_name]
135+
if task_name in task_to_env:
136+
val_task_to_env[task_name] = task_to_env[task_name]
158137

159138
if data_config["validation"] is not None:
160139
if isinstance(data_config["validation"], dict):
@@ -165,10 +144,13 @@ def setup_data(
165144
val_data = load_response_dataset(cfg, seed)
166145
val_data_list.append(val_data.dataset)
167146
# bind task_name to task_data_processors
168-
val_task_data_processors[val_data.task_name] = (
147+
task_name = val_data.task_name
148+
val_task_data_processors[task_name] = (
169149
val_data.task_spec,
170150
val_data.processor,
171151
)
152+
if cfg["env_name"] != "nemo_gym":
153+
val_task_to_env[task_name] = envs[cfg["env_name"]]
172154

173155
val_dataset = None
174156
if len(val_data_list) > 0:
@@ -182,7 +164,7 @@ def setup_data(
182164
)
183165
print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")
184166

185-
return dataset, val_dataset
167+
return dataset, val_dataset, task_to_env, val_task_to_env
186168

187169

188170
# These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail.
@@ -278,7 +260,7 @@ def main() -> None:
278260
assert _should_use_nemo_gym(config)
279261

280262
print("\n▶ Setting up data...")
281-
train_dataset, val_dataset = setup_data(
263+
train_dataset, val_dataset, task_to_env, val_task_to_env = setup_data(
282264
tokenizer=tokenizer,
283265
data_config=config["data"],
284266
env_configs=config["env"],
@@ -328,11 +310,12 @@ def main() -> None:
328310
base_urls=policy_generation.dp_openai_server_base_urls,
329311
initial_global_config_dict=config["env"]["nemo_gym"],
330312
)
313+
# Default nemo_gym env is used for trajectory collection
331314
nemo_gym = create_env(env_name="nemo_gym", env_config=nemo_gym_config)
332315
# Blocking wait for NeMo-Gym to spin up
333316
ray.get(nemo_gym.health_check.remote())
334-
task_to_env = {"nemo_gym": nemo_gym}
335-
val_task_to_env = task_to_env
317+
task_to_env["nemo_gym"] = nemo_gym
318+
val_task_to_env["nemo_gym"] = nemo_gym
336319

337320
if is_trajectory_collection:
338321
collect_trajectories(

0 commit comments

Comments
 (0)