44import os
55import random
66import tempfile
7+ from collections .abc import Sequence
78from concurrent .futures import ProcessPoolExecutor
89from functools import cached_property , partial
910from multiprocessing import shared_memory
1011from pathlib import Path
11- from typing import Sized
12+ from typing import Sized , cast
1213
1314import numpy as np
1415import torch
1516import xxhash
1617from datasets import Dataset , concatenate_datasets
1718from torch import distributed as dist
1819from torch .utils .data import ConcatDataset
20+ from torch .utils .data import Dataset as TorchDataset
1921from tqdm import tqdm
2022
2123from xtuner .v1 .utils import get_logger , is_local_rank0
@@ -309,7 +311,7 @@ def get_pack_infos_by_expand_soft_split(
309311class ExpandSoftPackDataset (_LegacySoftPackDataset ):
310312 def __init__ (
311313 self ,
312- datasets : list [JsonlDataset ],
314+ datasets : Sequence [JsonlDataset ],
313315 pack_max_length : int = 2048 ,
314316 global_pack : bool = False ,
315317 pack_extra_buffer_size : int = 1000 ,
@@ -642,7 +644,7 @@ def get_state_dict(self):
    def load_state_dict(self, state_dict): ...  # NOTE(review): deliberate no-op — presumably this dataset holds no resumable state; confirm against its get_state_dict
643645
644646
645- class MLLMPretrainHybridPackDataset (_LegacySoftPackDataset ):
647+ class MLLMPretrainHybridPackDataset (TorchDataset ):
646648 def __init__ (
647649 self ,
648650 datasets : list [JsonlDataset ],
@@ -653,17 +655,12 @@ def __init__(
653655 pack_extra_buffer_size : int = 1000 , # for ExpandSoftPackDataset
654656 pack_chunk_size : int = 10000 , # for ExpandSoftPackDataset
655657 ):
656- self .pack_extra_buffer_size = pack_extra_buffer_size
657- self .pack_workers = pack_workers
658- self .torch_random_generator = torch .Generator ()
659- self .pack_chunk_size = pack_chunk_size
660- if seed is not None :
661- self .torch_random_generator .manual_seed (seed )
662- logger .info (f"Using { self .pack_workers } pack workers for packing datasets." )
663-
664658 self .seed = seed
665- self .global_pack = global_pack
666659 self .pack_max_length = pack_max_length
660+ self .global_pack = global_pack
661+ self .pack_workers = pack_workers
662+ self .pack_extra_buffer_size = pack_extra_buffer_size
663+ self .pack_chunk_size = pack_chunk_size
667664
668665 hard_pack_groups = []
669666 soft_pack_groups = []
@@ -673,100 +670,81 @@ def __init__(
673670 elif isinstance (dset , JsonlDataset ):
674671 hard_pack_groups .append (dset )
675672
676- if global_pack :
677- hard_pack_datasets : list [Sized ] = []
678- if len (hard_pack_groups ) > 0 :
679- num_tokens = [ndarray_to_mmap (np .concatenate ([dset .num_tokens for dset in hard_pack_groups ]))]
680- hard_pack_datasets = [ConcatDataset (hard_pack_groups )]
681-
682- pack_infos_list = []
683- for i , dataset in enumerate (hard_pack_datasets ):
684- _infos = self .get_hard_pack_infos (dataset , i , num_tokens [i ])
685- pack_infos_list .extend (_infos )
686- hard_pack_len = len (pack_infos_list )
687-
688- soft_pack_datasets : list [Sized ] = []
689- if len (soft_pack_groups ) > 0 :
690- num_tokens = [ndarray_to_mmap (np .concatenate ([dset .num_tokens for dset in soft_pack_groups ]))]
691- proxy_attn_flops = [
692- ndarray_to_mmap (np .concatenate ([dset .proxy_attn_flops for dset in soft_pack_groups ]))
693- ]
694-
695- soft_pack_datasets = [ConcatDataset (soft_pack_groups )]
696- for i , dataset in enumerate (soft_pack_datasets ):
697- _infos = self .get_soft_pack_infos (dataset , i , num_tokens [i ], proxy_attn_flops [i ])
698- pack_infos_list .extend (_infos )
699- pack_infos = Dataset .from_list (pack_infos_list )
673+ dataset_list : list [HardPackDataset | ExpandSoftPackDataset ] = []
700674
701- else :
702- raise NotImplementedError
675+ if hard_pack_groups :
676+ hard_pack_dataset = HardPackDataset (
677+ datasets = hard_pack_groups ,
678+ pack_max_length = pack_max_length ,
679+ global_pack = global_pack ,
680+ seed = seed ,
681+ pack_workers = pack_workers ,
682+ )
683+ dataset_list .append (hard_pack_dataset )
703684
704- self .hard_pack_datasets = hard_pack_datasets
705- self .datasets = soft_pack_datasets
706- self .hard_pack_len = hard_pack_len
707- self .pack_infos = pack_infos
685+ if soft_pack_groups :
686+ soft_pack_dataset = ExpandSoftPackDataset (
687+ datasets = soft_pack_groups ,
688+ pack_max_length = pack_max_length ,
689+ global_pack = global_pack ,
690+ pack_extra_buffer_size = pack_extra_buffer_size ,
691+ pack_chunk_size = pack_chunk_size ,
692+ pack_workers = pack_workers ,
693+ seed = seed ,
694+ )
695+ dataset_list .append (soft_pack_dataset )
708696
709- def get_hard_pack_item (self , item : int ):
710- info = self .pack_infos [item ]
711- dataset_id = info ["dataset_id" ]
712- ds = self .hard_pack_datasets [dataset_id ]
697+ assert dataset_list , "No datasets provided for packing."
698+ self .datasets : ConcatDataset [HardPackDataset | ExpandSoftPackDataset ] = ConcatDataset (dataset_list )
713699
714- indices = info ["indices" ]
715- s_off = info ["start_offset" ]
716- e_off = info ["end_offset" ]
700+ @cached_property
701+ def longest (self ):
702+ longest_list = []
703+ for dataset in self .datasets .datasets :
704+ longest_list .extend (cast (HardPackDataset | ExpandSoftPackDataset , dataset ).longest )
705+ return longest_list
717706
718- packed_list : list [dict ] = []
707+ def __getitem__ (self , item : int ):
708+ return self .datasets [item ]
719709
720- for i in range (len (indices )):
721- idx = indices [i ]
722- sample = ds [idx ]
723- ids = sample ["input_ids" ]
724- labs = sample .get ("labels" , None )
710+ def __len__ (self ) -> int :
711+ return len (self .datasets )
725712
726- st = 0 if i != 0 else s_off
727- ed = len (ids ) if i != len (indices ) - 1 else e_off
713+ def get_state_dict (self ):
714+ return {
715+ "pack_max_length" : self .pack_max_length ,
716+ "seed" : self .seed ,
717+ "global_pack" : self .global_pack ,
718+ "pack_extra_buffer_size" : self .pack_extra_buffer_size ,
719+ "pack_chunk_size" : self .pack_chunk_size ,
720+ }
728721
729- packed_list .append (
730- {
731- "input_ids" : ids [st :ed ],
732- "labels" : labs [st :ed ] if labs is not None else None ,
733- "num_tokens" : ed - st ,
734- }
722+ def load_state_dict (self , state_dict ):
723+ if self .seed != state_dict ["seed" ]:
724+ raise ValueError (
725+ f"Cannot load state dict with different seed . Origin: { state_dict ['seed' ]} , New: { self .seed } "
735726 )
736- assert (total_num_tokens := sum (i ["num_tokens" ] for i in packed_list )) == self .pack_max_length , (
737- f"Internal Error! Found size: { total_num_tokens } mismatch after hard packing."
738- )
739- return packed_list
740-
741- def __getitem__ (self , item : int ):
742- if item < self .hard_pack_len :
743- return self .get_hard_pack_item (item )
744- else :
745- return super ().__getitem__ (item )
746727
747- def get_hard_pack_infos (self , dataset : Sized , dataset_id : int , num_tokens : np .ndarray ):
748- # shuffled indices
749- inds = torch .randperm (len (dataset ), generator = self .torch_random_generator ).tolist ()
728+ if self .pack_max_length != state_dict ["pack_max_length" ]:
729+ raise ValueError (
730+ "Cannot load state dict with different pack_max_length "
731+ f". Origin: { state_dict ['pack_max_length' ]} , New: { self .pack_max_length } "
732+ )
750733
751- pack_infos_list = get_pack_infos_by_hard_split (
752- inds , dataset_id , num_tokens , pack_max_length = self .pack_max_length , pack_workers = self .pack_workers
753- )
754- return pack_infos_list
734+ if self .global_pack != state_dict ["global_pack" ]:
735+ raise ValueError (
736+ "Cannot load state dict with different global_pack "
737+ f". Origin: { state_dict ['global_pack' ]} , New: { self .global_pack } "
738+ )
755739
756- def get_soft_pack_infos (
757- self , dataset : Sized , dataset_id : int , num_tokens : np . ndarray , proxy_attn_flops : np . ndarray
758- ):
759- # shuffled indices
760- inds = torch . randperm ( len ( dataset ), generator = self . torch_random_generator ). tolist ( )
740+ if self . pack_extra_buffer_size != state_dict [ "pack_extra_buffer_size" ]:
741+ raise ValueError (
742+ "Cannot load state dict with different pack_extra_buffer_size "
743+ f". Origin: { state_dict [ 'pack_extra_buffer_size' ] } , New: { self . pack_extra_buffer_size } "
744+ )
761745
762- pack_infos_list = get_pack_infos_by_expand_soft_split (
763- inds ,
764- dataset_id ,
765- num_tokens ,
766- proxy_attn_flops ,
767- pack_max_length = self .pack_max_length ,
768- pack_workers = self .pack_workers ,
769- pack_chunk_size = self .pack_chunk_size ,
770- pack_extra_buffer_size = self .pack_extra_buffer_size ,
771- )
772- return pack_infos_list
746+ if self .pack_chunk_size != state_dict ["pack_chunk_size" ]:
747+ raise ValueError (
748+ "Cannot load state dict with different pack_chunk_size "
749+ f". Origin: { state_dict ['pack_chunk_size' ]} , New: { self .pack_chunk_size } "
750+ )
0 commit comments