Merged
25 changes: 25 additions & 0 deletions docs/guides/grpo.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ data:
env_name: "math"
```

We support using multiple datasets for training and validation. Refer to `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example. Here's a sample configuration:
```yaml
data:
_override_: true # override the data config instead of merging with it
# other data settings, see `examples/configs/sft.yaml` for more details
...
# dataset settings
train:
# train dataset 1
- dataset_name: OpenMathInstruct-2
split_validation_size: 0.05 # use 5% of the training data as validation data
seed: 42 # seed for train/validation split when split_validation_size > 0
# train dataset 2
- dataset_name: DeepScaler
validation:
# validation dataset 1
- dataset_name: AIME2024
repeat: 16
# validation dataset 2
- dataset_name: DAPOMathAIME2024
# default settings for all datasets
default:
...
```

We also support using a single dataset for both training and validation: set `split_validation_size` to carve a validation split out of the training data.
This feature is supported by [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py).
To enable it for your own custom datasets or other built-in datasets, add the corresponding split logic to the dataset class, following [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py) as a reference.
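The split itself is simple to implement. Below is a minimal, self-contained sketch of how a custom dataset might honor `split_validation_size` and `seed` — the class name and the use of a plain Python list are illustrative stand-ins, not the actual NeMo RL API:

```python
import random


class MyCustomDataset:
    """Hypothetical dataset sketch supporting split_validation_size.

    Mirrors the pattern described above: when split_validation_size > 0,
    a deterministic fraction of the training examples is moved into
    `val_dataset`; otherwise `val_dataset` stays None.
    """

    def __init__(self, examples, split_validation_size=0.0, seed=42):
        self.dataset = list(examples)
        self.val_dataset = None
        if split_validation_size > 0:
            # deterministic shuffle so the split is reproducible per seed
            rng = random.Random(seed)
            indices = list(range(len(self.dataset)))
            rng.shuffle(indices)
            n_val = int(len(indices) * split_validation_size)
            val_idx = set(indices[:n_val])
            self.val_dataset = [self.dataset[i] for i in sorted(val_idx)]
            self.dataset = [
                self.dataset[i] for i in range(len(indices)) if i not in val_idx
            ]


ds = MyCustomDataset(range(100), split_validation_size=0.05, seed=42)
print(len(ds.dataset), len(ds.val_dataset))  # → 95 5
```

The real implementation would split a Hugging Face `Dataset` (e.g. via `train_test_split`) rather than a list, but the contract is the same: expose the held-out portion as `val_dataset` so the training script can pick it up.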
Expand Down
25 changes: 25 additions & 0 deletions docs/guides/sft.md
Expand Up @@ -100,6 +100,31 @@ data:
processor: "sft_processor"
```

We support using multiple datasets for training and validation. Refer to `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example. Here's a sample configuration:
```yaml
data:
_override_: true # override the data config instead of merging with it
# other data settings, see `examples/configs/sft.yaml` for more details
...
# dataset settings
train:
# train dataset 1
- dataset_name: OpenMathInstruct-2
split_validation_size: 0.05 # use 5% of the training data as validation data
seed: 42 # seed for train/validation split when split_validation_size > 0
# train dataset 2
- dataset_name: DeepScaler
validation:
# validation dataset 1
- dataset_name: AIME2024
repeat: 16
# validation dataset 2
- dataset_name: DAPOMathAIME2024
# default settings for all datasets
default:
...
```

We also support using a single dataset for both training and validation: set `split_validation_size` to carve a validation split out of the training data.
This feature is supported by [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py).
To enable it for your own custom datasets or other built-in datasets, add the corresponding split logic to the dataset class, following [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py) as a reference.
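The `default` section supplies fallback values for every dataset entry, with per-dataset keys taking precedence. The helper `update_single_dataset_config` appears elsewhere in this PR; the body below is an assumption about its behavior, shown only to illustrate the merge direction:

```python
def update_single_dataset_config(cfg: dict, default: dict) -> dict:
    """Hypothetical sketch: fill unset keys of a single dataset config
    from the shared `default` section. Values the dataset sets
    explicitly are never overwritten."""
    for key, value in default.items():
        cfg.setdefault(key, value)
    return cfg


cfg = {"dataset_name": "AIME2024", "repeat": 16}
default = {"env_name": "math", "repeat": 1}
update_single_dataset_config(cfg, default)
print(cfg)  # → {'dataset_name': 'AIME2024', 'repeat': 16, 'env_name': 'math'}
```

So in the YAML above, `AIME2024` keeps its explicit `repeat: 16` while inheriting `env_name`, `prompt_file`, and the other defaults.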
Expand Down
4 changes: 4 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Expand Up @@ -273,6 +273,10 @@ data:
system_prompt_file: null
processor: "math_hf_data_processor"
env_name: "math"

# You can also use multiple datasets by providing a list of dataset configs.
# See `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example.

# You can use custom response datasets for training and validation. For example:
# train:
# # this dataset will override input_key and use the default values for other vars
Expand Down
28 changes: 28 additions & 0 deletions examples/configs/grpo_multiple_datasets.yaml
@@ -0,0 +1,28 @@
# GRPO Algorithm Configuration
defaults: "grpo_math_1B.yaml"

data:
_override_: true # override the data config instead of merging with it

max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
shuffle: true
num_workers: 1

# dataset
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.
train:
- dataset_name: OpenMathInstruct-2
split_validation_size: 0.05 # use 5% of the training data as validation data
seed: ${grpo.seed} # seed for train/validation split when split_validation_size > 0
- dataset_name: DeepScaler
validation:
- dataset_name: AIME2024
repeat: 16
- dataset_name: DAPOMathAIME2024

# default settings for all datasets
default:
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
processor: "math_hf_data_processor"
env_name: "math"
7 changes: 5 additions & 2 deletions examples/configs/sft.yaml
Expand Up @@ -194,6 +194,10 @@ data:
prompt_file: null
system_prompt_file: null
processor: "sft_processor"

# You can also use multiple datasets by providing a list of dataset configs.
# See `examples/configs/grpo_multiple_datasets.yaml` for a full configuration example.

# You can use custom response datasets for training and validation. For example:
# train:
# # this dataset will override input_key and use the default values for other vars
Expand All @@ -212,8 +216,7 @@ data:
# processor: "sft_processor"
# See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets for more details.


## OpenAI format specific configs
# OpenAI format specific configs
# train_data_path: "/path/to/train.jsonl" # Path to training data
# val_data_path: "/path/to/val.jsonl" # Path to validation data
# chat_key: "messages" # Key for messages in the data
Expand Down
76 changes: 47 additions & 29 deletions examples/run_sft.py
Expand Up @@ -64,19 +64,30 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):

print("\n▶ Setting up data...")
# setup train dataset
if "default" in data_config:
update_single_dataset_config(data_config["train"], data_config["default"])
data = load_response_dataset(data_config["train"])
data_processor = partial(
data.processor,
add_bos=data_config["add_bos"],
add_eos=data_config["add_eos"],
add_generation_prompt=data_config["add_generation_prompt"],
)
task_data_processors = {data.task_name: (data.task_spec, data_processor)}
task_data_processors = {}
data_list = []

if isinstance(data_config["train"], dict):
data_config["train"] = [data_config["train"]]

for cfg in data_config["train"]:
# load dataset
if "default" in data_config and data_config["default"] is not None:
update_single_dataset_config(cfg, data_config["default"])
data = load_response_dataset(cfg)
data_list.append(data)
# bind task_name to task_data_processors
data_processor = partial(
data.processor,
add_bos=data_config["add_bos"],
add_eos=data_config["add_eos"],
add_generation_prompt=data_config["add_generation_prompt"],
)
task_data_processors[data.task_name] = (data.task_spec, data_processor)

merged_data = concatenate_datasets([data.dataset for data in data_list])
dataset = AllTaskProcessedDataset(
data.dataset,
merged_data,
tokenizer,
None,
task_data_processors,
Expand All @@ -89,28 +100,35 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
val_data_list = []

# validation dataset from train dataset (when train dataset's split_validation_size > 0)
if hasattr(data, "val_dataset") and data.val_dataset is not None:
val_data_list.append(data.val_dataset)
val_task_data_processors = task_data_processors.copy()
for data in data_list:
if hasattr(data, "val_dataset") and data.val_dataset is not None:
val_data_list.append(data.val_dataset)
# bind task_name to task_data_processors
task_name = data.task_name
val_task_data_processors[task_name] = task_data_processors[task_name]

# validation dataset from config
if "validation" in data_config and data_config["validation"] is not None:
if "default" in data_config:
update_single_dataset_config(
data_config["validation"], data_config["default"]
if isinstance(data_config["validation"], dict):
data_config["validation"] = [data_config["validation"]]

for cfg in data_config["validation"]:
# load dataset
if "default" in data_config and data_config["default"] is not None:
update_single_dataset_config(cfg, data_config["default"])
val_data = load_response_dataset(cfg)
val_data_list.append(val_data.dataset)
# bind task_name to task_data_processors
val_data_processor = partial(
val_data.processor,
add_bos=data_config["add_bos"],
add_eos=data_config["add_eos"],
add_generation_prompt=data_config["add_generation_prompt"],
)
val_task_data_processors[val_data.task_name] = (
val_data.task_spec,
val_data_processor,
)
val_data = load_response_dataset(data_config["validation"])
val_data_list.append(val_data.dataset)
val_data_processor = partial(
val_data.processor,
add_bos=data_config["add_bos"],
add_eos=data_config["add_eos"],
add_generation_prompt=data_config["add_generation_prompt"],
)
val_task_data_processors[val_data.task_name] = (
val_data.task_spec,
val_data_processor,
)

val_dataset = None
if len(val_data_list) > 0:
Expand Down
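The backward-compatibility step at the top of the new `setup_data` loop — wrapping a lone dict into a one-element list — keeps the single-dataset config shape working unchanged. A standalone sketch (the function name here is illustrative; in the diff the normalization is inlined):

```python
def normalize_dataset_configs(data_config: dict) -> dict:
    """Illustrative sketch: wrap a single dataset config (a dict) into a
    one-element list so downstream loops handle the single- and
    multi-dataset cases uniformly."""
    for key in ("train", "validation"):
        value = data_config.get(key)
        if isinstance(value, dict):
            data_config[key] = [value]
    return data_config


single = {"train": {"dataset_name": "OpenMathInstruct-2"}}
multi = {"train": [{"dataset_name": "A"}, {"dataset_name": "B"}]}
print(normalize_dataset_configs(single)["train"])  # → [{'dataset_name': 'OpenMathInstruct-2'}]
print(len(normalize_dataset_configs(multi)["train"]))  # → 2
```

After normalization, every dataset entry is loaded, its processor is registered under its `task_name`, and the underlying datasets are combined with `concatenate_datasets` before being handed to `AllTaskProcessedDataset`.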
4 changes: 2 additions & 2 deletions nemo_rl/data/__init__.py
Expand Up @@ -49,8 +49,8 @@ class DataConfig(TypedDict):
num_workers: NotRequired[int]
# dataset configs
# TODO: remove NotRequired once preference dataset is refactored
train: NotRequired[ResponseDatasetConfig]
validation: NotRequired[ResponseDatasetConfig | None]
train: NotRequired[ResponseDatasetConfig | list[ResponseDatasetConfig]]
validation: NotRequired[ResponseDatasetConfig | list[ResponseDatasetConfig] | None]
default: NotRequired[ResponseDatasetConfig | None]
# TODO: remove once preference dataset is refactored
dataset_name: NotRequired[str]
Expand Down
65 changes: 42 additions & 23 deletions nemo_rl/data/utils.py
Expand Up @@ -68,14 +68,27 @@ def setup_data_with_envs(

print("\n▶ Setting up data...")
# setup train dataset
if "default" in data_config:
update_single_dataset_config(data_config["train"], data_config["default"])
data = load_response_dataset(data_config["train"])
task_data_processors = {data.task_name: (data.task_spec, data.processor)}
task_to_env = {data.task_name: envs[data_config["train"]["env_name"]]}

task_data_processors = {}
task_to_env = {}
data_list = []

if isinstance(data_config["train"], dict):
data_config["train"] = [data_config["train"]]

for cfg in data_config["train"]:
# load dataset
if "default" in data_config and data_config["default"] is not None:
update_single_dataset_config(cfg, data_config["default"])
data = load_response_dataset(cfg)
data_list.append(data)
# bind task_name to task_data_processors and task_to_env
task_name = data.task_name
task_data_processors[task_name] = (data.task_spec, data.processor)
task_to_env[task_name] = envs[cfg["env_name"]]

merged_data = concatenate_datasets([data.dataset for data in data_list])
dataset = AllTaskProcessedDataset(
data.dataset,
merged_data,
tokenizer,
None,
task_data_processors,
Expand All @@ -89,26 +102,32 @@ def setup_data_with_envs(
val_data_list = []

# validation dataset from train dataset (when train dataset's split_validation_size > 0)
if hasattr(data, "val_dataset") and data.val_dataset is not None:
val_data_list.append(data.val_dataset)
val_task_data_processors = task_data_processors.copy()
val_task_to_env = task_to_env.copy()
for data in data_list:
if hasattr(data, "val_dataset") and data.val_dataset is not None:
val_data_list.append(data.val_dataset)
# bind task_name to task_data_processors and task_to_env
task_name = data.task_name
val_task_data_processors[task_name] = task_data_processors[task_name]
val_task_to_env[task_name] = task_to_env[task_name]

# validation dataset from config
if "validation" in data_config and data_config["validation"] is not None:
if "default" in data_config:
update_single_dataset_config(
data_config["validation"], data_config["default"]
if isinstance(data_config["validation"], dict):
data_config["validation"] = [data_config["validation"]]

for cfg in data_config["validation"]:
# load dataset
if "default" in data_config and data_config["default"] is not None:
update_single_dataset_config(cfg, data_config["default"])
val_data = load_response_dataset(cfg)
val_data_list.append(val_data.dataset)
# bind task_name to task_data_processors and task_to_env
task_name = val_data.task_name
val_task_data_processors[task_name] = (
val_data.task_spec,
val_data.processor,
)
val_data = load_response_dataset(data_config["validation"])
val_data_list.append(val_data.dataset)
val_task_data_processors[val_data.task_name] = (
val_data.task_spec,
val_data.processor,
)
val_task_to_env[val_data.task_name] = envs[
data_config["validation"]["env_name"]
]
val_task_to_env[task_name] = envs[cfg["env_name"]]

val_dataset = None
if len(val_data_list) > 0:
Expand Down
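Because the concatenated dataset mixes tasks, each loaded dataset must register its processor and environment under its `task_name` so examples can be routed back to the right pair. A schematic of that registration loop, using stand-in types (the real objects come from `load_response_dataset`; nothing here is the actual API):

```python
from dataclasses import dataclass


@dataclass
class LoadedDataset:
    # stand-in for the object returned by load_response_dataset
    task_name: str
    task_spec: str
    processor: str
    env_name: str


def build_task_maps(datasets, envs):
    """Sketch of the registration loop in setup_data_with_envs:
    one (task_spec, processor) pair and one environment per task_name."""
    task_data_processors, task_to_env = {}, {}
    for d in datasets:
        task_data_processors[d.task_name] = (d.task_spec, d.processor)
        task_to_env[d.task_name] = envs[d.env_name]
    return task_data_processors, task_to_env


datasets = [
    LoadedDataset("math", "spec_a", "proc_a", "math"),
    LoadedDataset("aime", "spec_b", "proc_b", "math"),
]
procs, envs_map = build_task_maps(datasets, {"math": "math_env"})
print(sorted(envs_map))  # → ['aime', 'math']
```

Note that two datasets may share an environment (both map to `math` above) while keeping distinct processors, which is exactly the multi-dataset GRPO setup the config example describes.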
23 changes: 21 additions & 2 deletions nemo_rl/utils/config.py
Expand Up @@ -27,6 +27,23 @@ def resolve_path(base_path: Path, path: str) -> Path:
return base_path / path


def merge_with_override(
base_config: DictConfig, override_config: DictConfig
) -> DictConfig:
"""Merge configs with support for _override_ marker to completely override sections."""
for key in list(override_config.keys()):
if isinstance(override_config[key], DictConfig):
if override_config[key].get("_override_", False):
# remove the _override_ marker
override_config[key].pop("_override_")
# remove the key from base_config so it won't be merged
if key in base_config:
base_config.pop(key)

merged_config = cast(DictConfig, OmegaConf.merge(base_config, override_config))
return merged_config


def load_config_with_inheritance(
config_path: Union[str, Path],
base_dir: Optional[Union[str, Path]] = None,
Expand Down Expand Up @@ -63,10 +80,12 @@ def load_config_with_inheritance(
for default in defaults:
parent_path = resolve_path(base_dir, str(default))
parent_config = load_config_with_inheritance(parent_path, base_dir)
base_config = cast(DictConfig, OmegaConf.merge(base_config, parent_config))
base_config = cast(
DictConfig, merge_with_override(base_config, parent_config)
)

# Merge with current config
config = cast(DictConfig, OmegaConf.merge(base_config, config))
config = cast(DictConfig, merge_with_override(base_config, config))

return config

Expand Down
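The `_override_` semantics introduced in `merge_with_override` can be illustrated with plain dicts. The real implementation operates on OmegaConf `DictConfig` objects and inherits OmegaConf's deep-merge behavior; this shallow-dict model is a simplification that only demonstrates the marker:

```python
def merge_with_override(base: dict, override: dict) -> dict:
    """Plain-dict sketch of the _override_ marker: a section that sets
    `_override_: true` replaces the base section wholesale instead of
    being merged into it. (Simplified: merges are shallow here.)"""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict):
            value = dict(value)
            if value.pop("_override_", False):
                merged[key] = value  # replace, don't merge
                continue
            merged[key] = {**merged.get(key, {}), **value}
        else:
            merged[key] = value
    return merged


base = {"data": {"shuffle": True, "train": {"dataset_name": "A"}}}
override = {"data": {"_override_": True, "train": [{"dataset_name": "B"}]}}
result = merge_with_override(base, override)
print(result)  # → {'data': {'train': [{'dataset_name': 'B'}]}}
```

This is why `grpo_multiple_datasets.yaml` marks its `data` section with `_override_: true`: the parent config's single-dataset `data.train` dict would otherwise be merged with the child's list, producing an invalid hybrid.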
1 change: 1 addition & 0 deletions tests/functional/L1_Functional_Tests_GPU.sh
Expand Up @@ -32,6 +32,7 @@ time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
time uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh
time uv run --no-sync bash ./tests/functional/grpo_sglang.sh
time uv run --no-sync bash ./tests/functional/grpo_multiple_datasets.sh
time uv run --no-sync bash ./tests/functional/dpo.sh
time uv run --no-sync bash ./tests/functional/rm.sh
time uv run --no-sync bash ./tests/functional/eval.sh
Expand Down