Skip to content

Commit c0b8cde

Browse files
committed
Support running NeMo-Gym GRPO
Signed-off-by: ruit <ruit@nvidia.com>
1 parent dac1fe0 commit c0b8cde

File tree

6 files changed

+188
-22
lines changed

6 files changed

+188
-22
lines changed

examples/nemo_gym/grpo_dapo17k_bytedtsinghua_qwen3_4binstruct_nf.yaml

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ policy:
211211
num_first_layers_in_bf16: 0
212212
expose_http_server: true
213213
skip_tokenizer_init: false
214-
kv_cache_dtype: null
214+
kv_cache_dtype: ${policy.precision}
215215
http_server_serving_chat_kwargs:
216216
# This is the tool parser for Qwen 3 4B Instruct. This needs to be changed for other models.
217217
enable_auto_tools: true
@@ -234,10 +234,21 @@ policy:
234234
num_nodes: null # Decides number of nodes to be dedicated to generation
235235

236236
data:
237-
train_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/train.jsonl
238-
validation_jsonl_fpath: 3rdparty/Gym-workspace/Gym/data/validation.jsonl
237+
max_input_seq_length: ${policy.max_total_sequence_length}
239238
shuffle: true
240239
num_workers: 0
240+
train:
241+
dataset_name: NemoGymDataset
242+
data_path: 3rdparty/Gym-workspace/Gym/data/train.jsonl
243+
repeat: 1
244+
validation:
245+
dataset_name: NemoGymDataset
246+
data_path: 3rdparty/Gym-workspace/Gym/data/validation.jsonl
247+
default:
248+
env_name: "nemo_gym"
249+
prompt_file: null
250+
system_prompt_file: null
251+
processor: "nemo_gym_data_processor"
241252

242253
env:
243254
should_use_nemo_gym: true

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
import os
1818
import pprint
1919
from itertools import chain, repeat
20-
from typing import Optional
20+
from typing import Dict, Optional
2121

2222
# Increase the W&B single object size warning threshold. Initially 100_000 (100 KB) -> 10_000_000 (10 MB)
2323
import wandb.util
2424

2525
wandb.util.VALUE_BYTES_LIMIT = 10_000_000
2626

2727
import ray
28+
from datasets import concatenate_datasets
2829
from omegaconf import OmegaConf
2930
from wandb import Table
3031

@@ -42,18 +43,19 @@
4243
setup,
4344
)
4445
from nemo_rl.algorithms.utils import get_tokenizer
45-
from nemo_rl.data.datasets import AllTaskProcessedDataset
46-
from nemo_rl.data.interfaces import DatumSpec
47-
from nemo_rl.distributed.ray_actor_environment_registry import (
48-
get_actor_python_env,
46+
from nemo_rl.data.datasets import (
47+
AllTaskProcessedDataset,
48+
load_response_dataset,
49+
update_single_dataset_config,
4950
)
51+
from nemo_rl.data.interfaces import DatumSpec
5052
from nemo_rl.distributed.virtual_cluster import init_ray
5153
from nemo_rl.environments.nemo_gym import (
52-
NemoGym,
5354
NemoGymConfig,
5455
nemo_gym_example_to_nemo_rl_datum_spec,
5556
setup_nemo_gym_config,
5657
)
58+
from nemo_rl.environments.utils import create_env
5759
from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout
5860
from nemo_rl.models.generation import configure_generation_config
5961
from nemo_rl.utils.config import load_config, parse_hydra_overrides
@@ -109,6 +111,80 @@ def setup_single_nemo_gym_dataset(
109111
)
110112

111113

114+
def setup_data(
115+
tokenizer: TokenizerType,
116+
data_config: Dict,
117+
env_configs: Dict,
118+
seed: int,
119+
) -> tuple[
120+
AllTaskProcessedDataset,
121+
Optional[AllTaskProcessedDataset],
122+
dict[str, EnvironmentInterface],
123+
dict[str, EnvironmentInterface],
124+
]:
125+
print("\n▶ Setting up data...")
126+
# setup train dataset
127+
data_list = []
128+
task_data_processors = {}
129+
130+
if isinstance(data_config["train"], dict):
131+
data_config["train"] = [data_config["train"]]
132+
for cfg in data_config["train"]:
133+
update_single_dataset_config(cfg, data_config["default"])
134+
data = load_response_dataset(cfg, seed)
135+
data_list.append(data)
136+
task_data_processors[data.task_name] = (data.task_spec, data.processor)
137+
138+
merged_data = concatenate_datasets([data.dataset for data in data_list])
139+
dataset = AllTaskProcessedDataset(
140+
merged_data,
141+
tokenizer,
142+
None,
143+
task_data_processors,
144+
max_seq_length=data_config["max_input_seq_length"],
145+
)
146+
print(f" ✓ Training dataset loaded with {len(dataset)} samples.")
147+
148+
# setup validation dataset
149+
val_task_data_processors = {}
150+
val_data_list = []
151+
152+
for data in data_list:
153+
if hasattr(data, "val_dataset") and data.val_dataset is not None:
154+
val_data_list.append(data.val_dataset)
155+
# bind task_name to task_data_processors
156+
task_name = data.task_name
157+
val_task_data_processors[task_name] = task_data_processors[task_name]
158+
159+
if data_config["validation"] is not None:
160+
if isinstance(data_config["validation"], dict):
161+
data_config["validation"] = [data_config["validation"]]
162+
163+
for cfg in data_config["validation"]:
164+
update_single_dataset_config(cfg, data_config["default"])
165+
val_data = load_response_dataset(cfg, seed)
166+
val_data_list.append(val_data.dataset)
167+
# bind task_name to task_data_processors
168+
val_task_data_processors[val_data.task_name] = (
169+
val_data.task_spec,
170+
val_data.processor,
171+
)
172+
173+
val_dataset = None
174+
if len(val_data_list) > 0:
175+
merged_val_data = concatenate_datasets(val_data_list)
176+
val_dataset = AllTaskProcessedDataset(
177+
merged_val_data,
178+
tokenizer,
179+
None,
180+
val_task_data_processors,
181+
max_seq_length=data_config["max_input_seq_length"],
182+
)
183+
print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")
184+
185+
return dataset, val_dataset
186+
187+
112188
# These types are directly imported from grpo_train since if something about the architecture changes we want to immediately fail.
113189
def collect_trajectories(
114190
policy: ColocatablePolicyInterface,
@@ -202,13 +278,11 @@ def main() -> None:
202278
assert _should_use_nemo_gym(config)
203279

204280
print("\n▶ Setting up data...")
205-
train_dataset = setup_single_nemo_gym_dataset(
206-
jsonl_fpath=config["data"]["train_jsonl_fpath"],
207-
tokenizer=tokenizer,
208-
)
209-
val_dataset = setup_single_nemo_gym_dataset(
210-
jsonl_fpath=config["data"]["validation_jsonl_fpath"],
281+
train_dataset, val_dataset = setup_data(
211282
tokenizer=tokenizer,
283+
data_config=config["data"],
284+
env_configs=config["env"],
285+
seed=config["grpo"]["seed"],
212286
)
213287

214288
# Validation dataset config setup.
@@ -254,13 +328,7 @@ def main() -> None:
254328
base_urls=policy_generation.dp_openai_server_base_urls,
255329
initial_global_config_dict=config["env"]["nemo_gym"],
256330
)
257-
nemo_gym = NemoGym.options(
258-
runtime_env={
259-
"py_executable": get_actor_python_env(
260-
"nemo_rl.environments.nemo_gym.NemoGym"
261-
),
262-
}
263-
).remote(nemo_gym_config)
331+
nemo_gym = create_env(env_name="nemo_gym", env_config=nemo_gym_config)
264332
# Blocking wait for NeMo-Gym to spin up
265333
ray.get(nemo_gym.health_check.remote())
266334
task_to_env = {"nemo_gym": nemo_gym}

nemo_rl/data/datasets/response_datasets/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from nemo_rl.data.datasets.response_datasets.deepscaler import DeepScalerDataset
2525
from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset
2626
from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset
27+
from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset
2728
from nemo_rl.data.datasets.response_datasets.oai_format_dataset import (
2829
OpenAIFormatDataset,
2930
)
@@ -87,6 +88,8 @@ def load_response_dataset(data_config: ResponseDatasetConfig, seed: int = 42):
8788
**data_config, # pyrefly: ignore[missing-argument] `data_path` is required for this class
8889
seed=seed,
8990
)
91+
elif dataset_name == "NemoGymDataset":
92+
base_dataset: Any = NemoGymDataset(**data_config)
9093
else:
9194
raise ValueError(
9295
f"Unsupported {dataset_name=}. "
@@ -115,4 +118,5 @@ def load_response_dataset(data_config: ResponseDatasetConfig, seed: int = 42):
115118
"SquadDataset",
116119
"Tulu3SftMixtureDataset",
117120
"HelpSteer3Dataset",
121+
"NemoGymDataset",
118122
]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Optional
16+
17+
import torch
18+
19+
from nemo_rl.data.datasets.raw_dataset import RawDataset
20+
from nemo_rl.data.datasets.utils import load_dataset_from_path
21+
22+
23+
class NemoGymDataset(RawDataset):
24+
"""Simple wrapper around the Nemo Gym dataset."""
25+
26+
def __init__(self, data_path: Optional[str] = None, **kwargs) -> None:
27+
self.task_name = "NemoGymDataset"
28+
29+
# load from jsonl
30+
if data_path is None:
31+
# Allow optional at type level for config validation; enforce at runtime for clarity
32+
raise ValueError(
33+
"NemoGymDataset requires `data_path` in data_config to load examples."
34+
)
35+
self.dataset = load_dataset_from_path(data_path)
36+
37+
# format the dataset
38+
# NOTE: HuggingFace Dataset does not persist torch.Tensor values when mapping /
# writing to Arrow — it serializes them into plain Python lists. Downstream
# sampling therefore reads back a `list` (e.g. `[]`), which trips an assertion.
39+
self.dataset = self.dataset.map(
40+
self.format_data,
41+
with_indices=True,
42+
)
43+
if "repeat" in kwargs:
44+
self.dataset = self.dataset.repeat(kwargs["repeat"])
45+
46+
def format_data(self, data: dict[str, Any], idx: int) -> dict[str, Any]:
47+
return {
48+
"message_log": [
49+
{"role": "user", "content": "", "token_ids": torch.tensor([])}
50+
],
51+
"task_name": self.task_name,
52+
"length": 0,
53+
"extra_env_info": data,
54+
"loss_multiplier": 1.0, # Fix to 1.0 to backprop on all examples
55+
"idx": idx,
56+
"stop_strings": None,
57+
# Extra vars
58+
"token_ids": [], # Just need this empty key to be compatible with the current NeMo RL GRPO impl
59+
}

nemo_rl/data/processors.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,26 @@ def multichoice_qa_processor(
538538
return output
539539

540540

541+
def nemo_gym_data_processor(
542+
datum_dict: dict[str, Any],
543+
*args,
544+
**kwargs,
545+
) -> DatumSpec:
546+
"""Process a datum dictionary (directly loaded from dataset) into a DatumSpec for Nemo Gym."""
547+
# Ensure message_log exists and contains tensor token_ids so downstream padding works
548+
if "message_log" not in datum_dict or not datum_dict["message_log"]:
549+
datum_dict["message_log"] = [
550+
{"role": "user", "content": "", "token_ids": torch.tensor([])}
551+
]
552+
else:
553+
for msg in datum_dict["message_log"]:
554+
if "token_ids" not in msg:
555+
msg["token_ids"] = torch.tensor([])
556+
elif not isinstance(msg["token_ids"], torch.Tensor):
557+
msg["token_ids"] = torch.tensor(msg["token_ids"])
558+
return cast(DatumSpec, datum_dict)
559+
560+
541561
# Processor registry. Key is the processor name, value is the processor function.
542562
# Note: We cast the literal dict to Dict[str, TaskDataProcessFnCallable] because
543563
# type checkers see each concrete function's signature as a distinct callable type.
@@ -554,6 +574,7 @@ def multichoice_qa_processor(
554574
"multichoice_qa_processor": multichoice_qa_processor,
555575
"sft_processor": sft_processor,
556576
"vlm_hf_data_processor": vlm_hf_data_processor,
577+
"nemo_gym_data_processor": nemo_gym_data_processor,
557578
},
558579
)
559580

nemo_rl/environments/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class EnvRegistryEntry(TypedDict, total=False):
4646
"vlm": {
4747
"actor_class_fqn": "nemo_rl.environments.vlm_environment.VLMEnvironment",
4848
},
49+
"nemo_gym": {
50+
"actor_class_fqn": "nemo_rl.environments.nemo_gym.NemoGym",
51+
},
4952
}
5053

5154

0 commit comments

Comments
 (0)