Skip to content

Commit 7948f95

Browse files
committed
Fix MLLMPretrainHybridPackDataset
1 parent 59120b6 commit 7948f95

File tree

2 files changed

+77
-99
lines changed

2 files changed

+77
-99
lines changed

xtuner/v1/datasets/packing.py

Lines changed: 74 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,20 @@
44
import os
55
import random
66
import tempfile
7+
from collections.abc import Sequence
78
from concurrent.futures import ProcessPoolExecutor
89
from functools import cached_property, partial
910
from multiprocessing import shared_memory
1011
from pathlib import Path
11-
from typing import Sized
12+
from typing import Sized, cast
1213

1314
import numpy as np
1415
import torch
1516
import xxhash
1617
from datasets import Dataset, concatenate_datasets
1718
from torch import distributed as dist
1819
from torch.utils.data import ConcatDataset
20+
from torch.utils.data import Dataset as TorchDataset
1921
from tqdm import tqdm
2022

2123
from xtuner.v1.utils import get_logger, is_local_rank0
@@ -309,7 +311,7 @@ def get_pack_infos_by_expand_soft_split(
309311
class ExpandSoftPackDataset(_LegacySoftPackDataset):
310312
def __init__(
311313
self,
312-
datasets: list[JsonlDataset],
314+
datasets: Sequence[JsonlDataset],
313315
pack_max_length: int = 2048,
314316
global_pack: bool = False,
315317
pack_extra_buffer_size: int = 1000,
@@ -642,7 +644,7 @@ def get_state_dict(self):
642644
def load_state_dict(self, state_dict): ...
643645

644646

645-
class MLLMPretrainHybridPackDataset(_LegacySoftPackDataset):
647+
class MLLMPretrainHybridPackDataset(TorchDataset):
646648
def __init__(
647649
self,
648650
datasets: list[JsonlDataset],
@@ -653,17 +655,12 @@ def __init__(
653655
pack_extra_buffer_size: int = 1000, # for ExpandSoftPackDataset
654656
pack_chunk_size: int = 10000, # for ExpandSoftPackDataset
655657
):
656-
self.pack_extra_buffer_size = pack_extra_buffer_size
657-
self.pack_workers = pack_workers
658-
self.torch_random_generator = torch.Generator()
659-
self.pack_chunk_size = pack_chunk_size
660-
if seed is not None:
661-
self.torch_random_generator.manual_seed(seed)
662-
logger.info(f"Using {self.pack_workers} pack workers for packing datasets.")
663-
664658
self.seed = seed
665-
self.global_pack = global_pack
666659
self.pack_max_length = pack_max_length
660+
self.global_pack = global_pack
661+
self.pack_workers = pack_workers
662+
self.pack_extra_buffer_size = pack_extra_buffer_size
663+
self.pack_chunk_size = pack_chunk_size
667664

668665
hard_pack_groups = []
669666
soft_pack_groups = []
@@ -673,100 +670,81 @@ def __init__(
673670
elif isinstance(dset, JsonlDataset):
674671
hard_pack_groups.append(dset)
675672

676-
if global_pack:
677-
hard_pack_datasets: list[Sized] = []
678-
if len(hard_pack_groups) > 0:
679-
num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in hard_pack_groups]))]
680-
hard_pack_datasets = [ConcatDataset(hard_pack_groups)]
681-
682-
pack_infos_list = []
683-
for i, dataset in enumerate(hard_pack_datasets):
684-
_infos = self.get_hard_pack_infos(dataset, i, num_tokens[i])
685-
pack_infos_list.extend(_infos)
686-
hard_pack_len = len(pack_infos_list)
687-
688-
soft_pack_datasets: list[Sized] = []
689-
if len(soft_pack_groups) > 0:
690-
num_tokens = [ndarray_to_mmap(np.concatenate([dset.num_tokens for dset in soft_pack_groups]))]
691-
proxy_attn_flops = [
692-
ndarray_to_mmap(np.concatenate([dset.proxy_attn_flops for dset in soft_pack_groups]))
693-
]
694-
695-
soft_pack_datasets = [ConcatDataset(soft_pack_groups)]
696-
for i, dataset in enumerate(soft_pack_datasets):
697-
_infos = self.get_soft_pack_infos(dataset, i, num_tokens[i], proxy_attn_flops[i])
698-
pack_infos_list.extend(_infos)
699-
pack_infos = Dataset.from_list(pack_infos_list)
673+
dataset_list: list[HardPackDataset | ExpandSoftPackDataset] = []
700674

701-
else:
702-
raise NotImplementedError
675+
if hard_pack_groups:
676+
hard_pack_dataset = HardPackDataset(
677+
datasets=hard_pack_groups,
678+
pack_max_length=pack_max_length,
679+
global_pack=global_pack,
680+
seed=seed,
681+
pack_workers=pack_workers,
682+
)
683+
dataset_list.append(hard_pack_dataset)
703684

704-
self.hard_pack_datasets = hard_pack_datasets
705-
self.datasets = soft_pack_datasets
706-
self.hard_pack_len = hard_pack_len
707-
self.pack_infos = pack_infos
685+
if soft_pack_groups:
686+
soft_pack_dataset = ExpandSoftPackDataset(
687+
datasets=soft_pack_groups,
688+
pack_max_length=pack_max_length,
689+
global_pack=global_pack,
690+
pack_extra_buffer_size=pack_extra_buffer_size,
691+
pack_chunk_size=pack_chunk_size,
692+
pack_workers=pack_workers,
693+
seed=seed,
694+
)
695+
dataset_list.append(soft_pack_dataset)
708696

709-
def get_hard_pack_item(self, item: int):
710-
info = self.pack_infos[item]
711-
dataset_id = info["dataset_id"]
712-
ds = self.hard_pack_datasets[dataset_id]
697+
assert dataset_list, "No datasets provided for packing."
698+
self.datasets: ConcatDataset[HardPackDataset | ExpandSoftPackDataset] = ConcatDataset(dataset_list)
713699

714-
indices = info["indices"]
715-
s_off = info["start_offset"]
716-
e_off = info["end_offset"]
700+
@cached_property
701+
def longest(self):
702+
longest_list = []
703+
for dataset in self.datasets.datasets:
704+
longest_list.extend(cast(HardPackDataset | ExpandSoftPackDataset, dataset).longest)
705+
return longest_list
717706

718-
packed_list: list[dict] = []
707+
def __getitem__(self, item: int):
708+
return self.datasets[item]
719709

720-
for i in range(len(indices)):
721-
idx = indices[i]
722-
sample = ds[idx]
723-
ids = sample["input_ids"]
724-
labs = sample.get("labels", None)
710+
def __len__(self) -> int:
711+
return len(self.datasets)
725712

726-
st = 0 if i != 0 else s_off
727-
ed = len(ids) if i != len(indices) - 1 else e_off
713+
def get_state_dict(self):
714+
return {
715+
"pack_max_length": self.pack_max_length,
716+
"seed": self.seed,
717+
"global_pack": self.global_pack,
718+
"pack_extra_buffer_size": self.pack_extra_buffer_size,
719+
"pack_chunk_size": self.pack_chunk_size,
720+
}
728721

729-
packed_list.append(
730-
{
731-
"input_ids": ids[st:ed],
732-
"labels": labs[st:ed] if labs is not None else None,
733-
"num_tokens": ed - st,
734-
}
722+
def load_state_dict(self, state_dict):
723+
if self.seed != state_dict["seed"]:
724+
raise ValueError(
725+
f"Cannot load state dict with different seed . Origin: {state_dict['seed']}, New: {self.seed}"
735726
)
736-
assert (total_num_tokens := sum(i["num_tokens"] for i in packed_list)) == self.pack_max_length, (
737-
f"Internal Error! Found size: {total_num_tokens} mismatch after hard packing."
738-
)
739-
return packed_list
740-
741-
def __getitem__(self, item: int):
742-
if item < self.hard_pack_len:
743-
return self.get_hard_pack_item(item)
744-
else:
745-
return super().__getitem__(item)
746727

747-
def get_hard_pack_infos(self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray):
748-
# shuffled indices
749-
inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
728+
if self.pack_max_length != state_dict["pack_max_length"]:
729+
raise ValueError(
730+
"Cannot load state dict with different pack_max_length "
731+
f". Origin: {state_dict['pack_max_length']}, New: {self.pack_max_length}"
732+
)
750733

751-
pack_infos_list = get_pack_infos_by_hard_split(
752-
inds, dataset_id, num_tokens, pack_max_length=self.pack_max_length, pack_workers=self.pack_workers
753-
)
754-
return pack_infos_list
734+
if self.global_pack != state_dict["global_pack"]:
735+
raise ValueError(
736+
"Cannot load state dict with different global_pack "
737+
f". Origin: {state_dict['global_pack']}, New: {self.global_pack}"
738+
)
755739

756-
def get_soft_pack_infos(
757-
self, dataset: Sized, dataset_id: int, num_tokens: np.ndarray, proxy_attn_flops: np.ndarray
758-
):
759-
# shuffled indices
760-
inds = torch.randperm(len(dataset), generator=self.torch_random_generator).tolist()
740+
if self.pack_extra_buffer_size != state_dict["pack_extra_buffer_size"]:
741+
raise ValueError(
742+
"Cannot load state dict with different pack_extra_buffer_size "
743+
f". Origin: {state_dict['pack_extra_buffer_size']}, New: {self.pack_extra_buffer_size}"
744+
)
761745

762-
pack_infos_list = get_pack_infos_by_expand_soft_split(
763-
inds,
764-
dataset_id,
765-
num_tokens,
766-
proxy_attn_flops,
767-
pack_max_length=self.pack_max_length,
768-
pack_workers=self.pack_workers,
769-
pack_chunk_size=self.pack_chunk_size,
770-
pack_extra_buffer_size=self.pack_extra_buffer_size,
771-
)
772-
return pack_infos_list
746+
if self.pack_chunk_size != state_dict["pack_chunk_size"]:
747+
raise ValueError(
748+
"Cannot load state dict with different pack_chunk_size "
749+
f". Origin: {state_dict['pack_chunk_size']}, New: {self.pack_chunk_size}"
750+
)

xtuner/v1/datasets/sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from xtuner.v1.utils import get_logger
1313

1414
from .jsonl import JsonlDataset
15-
from .packing import _LegacySoftPackDataset
15+
from .packing import MLLMPretrainHybridPackDataset, _LegacySoftPackDataset
1616

1717

1818
logger = get_logger()
@@ -49,7 +49,7 @@ class ParallelSampler(Sampler):
4949

5050
def __init__(
5151
self,
52-
dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset,
52+
dataset: TorchConcatDataset[JsonlDataset] | _LegacySoftPackDataset | MLLMPretrainHybridPackDataset,
5353
global_batch_size: int,
5454
dp_mesh: DeviceMesh | None = None,
5555
shuffle: bool = True,
@@ -173,7 +173,7 @@ class LengthGroupedSampler(Sampler):
173173

174174
def __init__(
175175
self,
176-
dataset: _LegacySoftPackDataset,
176+
dataset: _LegacySoftPackDataset | MLLMPretrainHybridPackDataset,
177177
global_batch_size: int,
178178
dp_mesh: DeviceMesh | None = None,
179179
seed: Optional[int] = None,

0 commit comments

Comments (0)