1313# limitations under the License.
1414
1515import argparse
16- import json
1716import os
1817import pprint
19- from itertools import chain , repeat
2018from typing import Dict , Optional
2119
2220# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB)
4543from nemo_rl .algorithms .utils import get_tokenizer
4644from nemo_rl .data .datasets import (
4745 AllTaskProcessedDataset ,
46+ extract_necessary_env_names ,
4847 load_response_dataset ,
4948 update_single_dataset_config ,
5049)
51- from nemo_rl .data .interfaces import DatumSpec
5250from nemo_rl .distributed .virtual_cluster import init_ray
5351from nemo_rl .environments .nemo_gym import (
5452 NemoGymConfig ,
55- nemo_gym_example_to_nemo_rl_datum_spec ,
5653 setup_nemo_gym_config ,
5754)
5855from nemo_rl .environments .utils import create_env
@@ -77,40 +74,6 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
7774 return args , overrides
7875
7976
80- def setup_single_nemo_gym_dataset (
81- jsonl_fpath : str , tokenizer , num_repeats : Optional [int ] = None
82- ):
83- with open (jsonl_fpath ) as f :
84- nemo_gym_examples = list (map (json .loads , f ))
85-
86- print (f"Loaded data at { jsonl_fpath } . Found { len (nemo_gym_examples )} examples" )
87-
88- if num_repeats :
89- previous_length = len (nemo_gym_examples )
90- nemo_gym_examples = list (
91- chain .from_iterable (
92- repeat (nemo_gym_example , num_repeats )
93- for nemo_gym_example in nemo_gym_examples
94- )
95- )
96- print (
97- f"Repeating examples (in a pattern of abc to aabbcc) for { jsonl_fpath } from { previous_length } to { len (nemo_gym_examples )} !"
98- )
99-
100- nemo_rl_compatible_examples : list [DatumSpec ] = [
101- nemo_gym_example_to_nemo_rl_datum_spec (nemo_gym_example , idx )
102- for idx , nemo_gym_example in enumerate (nemo_gym_examples )
103- ]
104-
105- passthrough_task_processor = lambda datum_dict , * args , ** kwargs : datum_dict
106- return AllTaskProcessedDataset (
107- nemo_rl_compatible_examples ,
108- tokenizer ,
109- None ,
110- passthrough_task_processor ,
111- )
112-
113-
11477def setup_data (
11578 tokenizer : TokenizerType ,
11679 data_config : Dict ,
@@ -122,18 +85,31 @@ def setup_data(
12285 dict [str , EnvironmentInterface ],
12386 dict [str , EnvironmentInterface ],
12487]:
88+ print ("\n ▶ Setting up envs..." )
89+ env_name_list = extract_necessary_env_names (data_config )
90+ envs = {
91+ env_name : create_env (env_name = env_name , env_config = env_configs [env_name ])
92+ for env_name in env_name_list
93+ if env_name != "nemo_gym"
94+ }
12595 print ("\n ▶ Setting up data..." )
12696 # setup train dataset
127- data_list = []
12897 task_data_processors = {}
98+ task_to_env = {}
99+ data_list = []
129100
130101 if isinstance (data_config ["train" ], dict ):
131102 data_config ["train" ] = [data_config ["train" ]]
132103 for cfg in data_config ["train" ]:
133104 update_single_dataset_config (cfg , data_config ["default" ])
134105 data = load_response_dataset (cfg , seed )
135106 data_list .append (data )
136- task_data_processors [data .task_name ] = (data .task_spec , data .processor )
107+ # bind task_name to task_data_processors and task_to_env
108+ task_name = data .task_name
109+ task_data_processors [task_name ] = (data .task_spec , data .processor )
110+ # Skip binding nemo_gym env to task_to_env, nemo_gym env need to initialize policy first
111+ if cfg ["env_name" ] != "nemo_gym" :
112+ task_to_env [task_name ] = envs [cfg ["env_name" ]]
137113
138114 merged_data = concatenate_datasets ([data .dataset for data in data_list ])
139115 dataset = AllTaskProcessedDataset (
@@ -147,6 +123,7 @@ def setup_data(
147123
148124 # setup validation dataset
149125 val_task_data_processors = {}
126+ val_task_to_env = {}
150127 val_data_list = []
151128
152129 for data in data_list :
@@ -155,6 +132,8 @@ def setup_data(
155132 # bind task_name to task_data_processors
156133 task_name = data .task_name
157134 val_task_data_processors [task_name ] = task_data_processors [task_name ]
135+ if task_name in task_to_env :
136+ val_task_to_env [task_name ] = task_to_env [task_name ]
158137
159138 if data_config ["validation" ] is not None :
160139 if isinstance (data_config ["validation" ], dict ):
@@ -165,10 +144,13 @@ def setup_data(
165144 val_data = load_response_dataset (cfg , seed )
166145 val_data_list .append (val_data .dataset )
167146 # bind task_name to task_data_processors
168- val_task_data_processors [val_data .task_name ] = (
147+ task_name = val_data .task_name
148+ val_task_data_processors [task_name ] = (
169149 val_data .task_spec ,
170150 val_data .processor ,
171151 )
152+ if cfg ["env_name" ] != "nemo_gym" :
153+ val_task_to_env [task_name ] = envs [cfg ["env_name" ]]
172154
173155 val_dataset = None
174156 if len (val_data_list ) > 0 :
@@ -182,7 +164,7 @@ def setup_data(
182164 )
183165 print (f" ✓ Validation dataset loaded with { len (val_dataset )} samples." )
184166
185- return dataset , val_dataset
167+ return dataset , val_dataset , task_to_env , val_task_to_env
186168
187169
188170# These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail.
@@ -278,7 +260,7 @@ def main() -> None:
278260 assert _should_use_nemo_gym (config )
279261
280262 print ("\n ▶ Setting up data..." )
281- train_dataset , val_dataset = setup_data (
263+ train_dataset , val_dataset , task_to_env , val_task_to_env = setup_data (
282264 tokenizer = tokenizer ,
283265 data_config = config ["data" ],
284266 env_configs = config ["env" ],
@@ -328,11 +310,12 @@ def main() -> None:
328310 base_urls = policy_generation .dp_openai_server_base_urls ,
329311 initial_global_config_dict = config ["env" ]["nemo_gym" ],
330312 )
313+ # Default nemo_gym env is used for trajectory collection
331314 nemo_gym = create_env (env_name = "nemo_gym" , env_config = nemo_gym_config )
332315 # Blocking wait for NeMo-Gym to spin up
333316 ray .get (nemo_gym .health_check .remote ())
334- task_to_env = { "nemo_gym" : nemo_gym }
335- val_task_to_env = task_to_env
317+ task_to_env [ "nemo_gym" ] = nemo_gym
318+ val_task_to_env [ "nemo_gym" ] = nemo_gym
336319
337320 if is_trajectory_collection :
338321 collect_trajectories (
0 commit comments