diff --git a/roll/datasets/collator.py b/roll/datasets/collator.py index 47e3f6bb5..aa39d6afd 100644 --- a/roll/datasets/collator.py +++ b/roll/datasets/collator.py @@ -138,6 +138,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: # model_inputs for hf/deepspeed: input_id, attention_mask, pixel_values, image_grid_thw padded_features = defaultdict(list) un_padded_features = defaultdict(list) + mm_token_type_id_features = [] mm_feature_keys = set() for feature in features: # cannot process as batch directly though processor output as batch @@ -165,6 +166,8 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: model_inputs.pop(key) for key in filter(lambda k: k in model_inputs, self.padded_keys): padded_features[key].append(model_inputs.pop(key)[0]) + if "mm_token_type_ids" in model_inputs: + mm_token_type_id_features.append(torch.as_tensor(model_inputs.pop("mm_token_type_ids")[0])) # mm feature fileds can be different because of mixed data mm_feature_keys = mm_feature_keys.union(model_inputs.keys()) # to tensors except padded_keys which would be converted after padding @@ -208,6 +211,22 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: return_tensors=self.return_tensors, ) batch.update(un_padded_features) + if mm_token_type_id_features: + target_len = batch["input_ids"].shape[-1] + padded_mm_token_type_ids = [] + for token_type_ids in mm_token_type_id_features: + pad_len = target_len - token_type_ids.shape[-1] + if pad_len < 0: + raise ValueError( + f"mm_token_type_ids length {token_type_ids.shape[-1]} exceeds padded input length {target_len}" + ) + pad = torch.zeros(pad_len, dtype=token_type_ids.dtype, device=token_type_ids.device) + if self.tokenizer.padding_side == "left": + token_type_ids = torch.cat([pad, token_type_ids], dim=-1) + else: + token_type_ids = torch.cat([token_type_ids, pad], dim=-1) + padded_mm_token_type_ids.append(token_type_ids) + batch["mm_token_type_ids"] = torch.stack(padded_mm_token_type_ids, dim=0) # other custom data fields: mainly for specific position_ids currently # position_ids for qwen2-vl is optional and make sure it is a 3D tensor @@ -226,6 +245,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: kwargs[key] = fun_params[key].default extra_data = self.extra_data_provider(**kwargs) batch.update(extra_data) + batch.pop("mm_token_type_ids", None) # each field should be a tensor or np.array(val=list_data, dtype=object) # to be stored in DataProto diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py index b5c432026..054a7e393 100644 --- a/roll/models/model_providers.py +++ b/roll/models/model_providers.py @@ -277,6 +277,8 @@ def load_model( freeze_model(model, model_args) else: model = setup_lora_training(config, model, model_args, is_trainable) + if not model_args.disable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() if add_valuehead: from trl import AutoModelForCausalLMWithValueHead @@ -710,8 +712,6 @@ def get_extra_data_provider(model_name_or_path: str, processor=None): if isinstance(model_type, str) and (("qwen2" in model_type) or (model_type in ("qwen3_vl", "qwen3_vl_moe"))): import types - from transformers import BatchFeature # help define a object to accesss attr - def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs): sig = inspect.signature(fn) params = sig.parameters @@ -745,17 +745,13 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs): "<|vision_start|>" ) - dummy_self = BatchFeature( - { - "config": BatchFeature( - { - "vision_config": BatchFeature(vc), - "image_token_id": image_token_id, - "video_token_id": video_token_id, - "vision_start_token_id": vision_start_token_id, - } - ) - } + dummy_self = types.SimpleNamespace( + config=types.SimpleNamespace( + vision_config=types.SimpleNamespace(**vc), + image_token_id=image_token_id, + video_token_id=video_token_id, + vision_start_token_id=vision_start_token_id, + ) ) is_tf_ge_4_52 = is_transformers_version_greater_than("4.52.0") @@ -771,6 +767,9 @@ def _call_get_rope_index(fn, input_ids: torch.LongTensor, **candidate_kwargs): elif model_type in ("qwen3_vl", "qwen3_vl_moe"): from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel + dummy_self.get_vision_position_ids = types.MethodType( + Qwen3VLModel.get_vision_position_ids, dummy_self + ) get_rope_index = types.MethodType(Qwen3VLModel.get_rope_index, dummy_self) else: if is_tf_ge_4_52: @@ -787,8 +786,15 @@ def extra_data_provider( image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, + mm_token_type_ids: Optional[torch.Tensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, ): + if model_type in ("qwen3_vl", "qwen3_vl_moe") and mm_token_type_ids is None: + mm_token_type_ids = torch.zeros_like(input_ids) + if image_token_id is not None: + mm_token_type_ids = torch.where(input_ids == image_token_id, 1, mm_token_type_ids) + if video_token_id is not None: + mm_token_type_ids = torch.where(input_ids == video_token_id, 2, mm_token_type_ids) # Keep kwargs to be resilient to HF signature changes between versions/models. out = _call_get_rope_index( get_rope_index, @@ -797,6 +803,7 @@ def extra_data_provider( video_grid_thw=video_grid_thw, second_per_grid_ts=second_per_grid_ts, attention_mask=attention_mask, + mm_token_type_ids=mm_token_type_ids, ) rope_index = out[0] # PumpkinComment: