Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dev.commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
c72c4599012297cfbd1d57e006b544478b6bbf78
7ff046b1c8b976ff33761976796c4302ebd0a7bc
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 513 files
42 changes: 42 additions & 0 deletions src/megatron/bridge/models/gpt_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,45 @@ def _fixed_from_config(config):


_patch_yarn_concentration_factor()


def _patch_te_grouped_linear_single_grouped_weight():
    """Guard for main/dev branch submodule compat: single_grouped_weight/bias kwargs.

    MCore dev (commit 5c544844) passes ``single_grouped_weight`` and
    ``single_grouped_bias`` to TE ``GroupedLinear.__init__`` when
    ``is_te_min_version("2.14.0")``. However some TE 2.14.0 builds only
    expose a single ``single_grouped_parameter`` kwarg. Remap so both
    APIs work.

    The function is a no-op when Transformer Engine cannot be imported,
    when ``GroupedLinear.__init__`` already accepts ``single_grouped_weight``,
    or when it accepts neither kwarg. Calling it more than once installs
    the wrapper only once.

    TODO: remove guard once TE ships the split weight/bias API in a
    stable release and the CI container is updated.
    """
    import functools

    # Only the import itself is guarded, so an ImportError raised by the
    # introspection below is not silently swallowed.
    try:
        import transformer_engine.pytorch as te_pytorch
    except ImportError:
        return

    original_init = te_pytorch.GroupedLinear.__init__

    # Explicit idempotence check: ``functools.wraps`` makes the wrapper report
    # the original signature, so the signature checks below cannot detect an
    # already-installed wrapper on their own.
    if getattr(original_init, "_single_grouped_compat_patched", False):
        return

    init_params = set(inspect.signature(original_init).parameters)

    # Nothing to patch if TE already accepts the split kwargs.
    if "single_grouped_weight" in init_params:
        return

    # Nothing to patch if TE has neither API (older TE without the feature).
    if "single_grouped_parameter" not in init_params:
        return

    @functools.wraps(original_init)
    def _patched_init(self, *args, **kwargs):
        # Strip the split kwargs unconditionally (this TE build rejects them)
        # and collapse them into the single flag it does understand.
        sgw = kwargs.pop("single_grouped_weight", False)
        sgb = kwargs.pop("single_grouped_bias", False)
        if sgw or sgb:
            kwargs["single_grouped_parameter"] = True
        original_init(self, *args, **kwargs)

    _patched_init._single_grouped_compat_patched = True
    te_pytorch.GroupedLinear.__init__ = _patched_init


_patch_te_grouped_linear_single_grouped_weight()
8 changes: 7 additions & 1 deletion src/megatron/bridge/training/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Input/output checkpointing."""

import contextlib
import inspect
import os
import random
import shutil
Expand Down Expand Up @@ -944,6 +945,11 @@ def save_checkpoint(
checkpointing_context["save_strategy"] = save_strategy
end_ckpt = time()
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
# Guard for main/dev branch submodule compat: async_strategy was removed in mcore dev.
_save_params = set(inspect.signature(dist_checkpointing.save).parameters)
_save_optional_kwargs: dict[str, Any] = {}
if "async_strategy" in _save_params:
_save_optional_kwargs["async_strategy"] = ckpt_cfg.async_strategy
async_save_request = dist_checkpointing.save(
state_dict,
checkpoint_name,
Expand All @@ -952,7 +958,7 @@ def save_checkpoint(
validate_access_integrity=validate_sharding_integrity,
preprocess_common_before_consistancy_check=preprocess_common_state_dict_fn,
content_metadata=_clean_metadata_for_serialization(sharded_sd_metadata),
async_strategy=ckpt_cfg.async_strategy,
**_save_optional_kwargs,
)
# [ModelOpt]: save sharded modelopt_state (skip if model is empty, e.g., low-memory save mode)
if model:
Expand Down
10 changes: 9 additions & 1 deletion src/megatron/bridge/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import datetime
import inspect
import os
import time
import warnings
Expand Down Expand Up @@ -676,14 +677,20 @@ def _initialize_distributed(
if parallel_state.model_parallel_is_initialized():
print("model parallel is already initialized")
else:
# Guard for main/dev branch submodule compat: hybrid_context_parallel was added in the dev branch.
# TODO: remove guard once the addition lands in main and Bridge pins the new main commit.
_init_mp_params = set(inspect.signature(parallel_state.initialize_model_parallel).parameters)
_optional_kwargs = {}
if "hybrid_context_parallel" in _init_mp_params:
_optional_kwargs["hybrid_context_parallel"] = model_config.hybrid_context_parallel

parallel_state.initialize_model_parallel(
tensor_model_parallel_size=model_config.tensor_model_parallel_size,
pipeline_model_parallel_size=model_config.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size=model_config.virtual_pipeline_model_parallel_size,
pipeline_model_parallel_comm_backend=model_config.pipeline_model_parallel_comm_backend,
context_parallel_size=model_config.context_parallel_size,
hierarchical_context_parallel_sizes=model_config.hierarchical_context_parallel_sizes,
hybrid_context_parallel=model_config.hybrid_context_parallel,
expert_model_parallel_size=model_config.expert_model_parallel_size,
num_distributed_optimizer_instances=num_distributed_optimizer_instances,
expert_tensor_parallel_size=model_config.expert_tensor_parallel_size,
Expand All @@ -696,6 +703,7 @@ def _initialize_distributed(
use_sharp=dist_config.use_sharp,
high_priority_stream_groups=dist_config.high_priority_stream_groups,
sharp_enabled_group=dist_config.sharp_enabled_group,
**_optional_kwargs,
)
if get_rank_safe() == 0:
print(
Expand Down
37 changes: 25 additions & 12 deletions src/megatron/bridge/training/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@
from typing import Any, Optional

import torch
from megatron.core.energy_monitor import EnergyMonitor
from megatron.core.timers import Timers
from megatron.core.utils import StragglerDetector
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.tensorboard.writer import SummaryWriter


# TODO: Remove try/except once `get_async_strategy` lands in mcore dev.
# The function was added to mcore main but has not yet been merged into dev.
# TODO: Remove try/except guards once these land in mcore dev.
try:
from megatron.core.dist_checkpointing.strategies.torch import get_async_strategy
except ImportError:
get_async_strategy = None # type: ignore[assignment]

try:
from megatron.core.energy_monitor import EnergyMonitor
except ImportError:
EnergyMonitor = None # type: ignore[assignment]

from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.nvrx_straggler import NVRxStragglerDetectionManager
from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer
Expand Down Expand Up @@ -144,7 +147,7 @@ def __init__(self) -> None:
self._async_calls_queue: Optional[Any] = None
self._nvrx_straggler_manager: Optional[NVRxStragglerDetectionManager] = None
self._nvrx_straggler_created: bool = False
self._energy_monitor: Optional[EnergyMonitor] = None
self._energy_monitor: Optional[Any] = None
self._energy_monitor_created: bool = False

@property
Expand Down Expand Up @@ -402,13 +405,22 @@ def initialize_async_checkpoint_worker(self) -> None:
and self.cfg.checkpoint.save is not None
and self.cfg.checkpoint.async_save
):
if get_async_strategy is None:
raise RuntimeError(
"get_async_strategy is required for async checkpointing but is not available "
"in the current mcore version. Please use mcore main or a newer mcore dev branch."
if get_async_strategy is not None:
# mcore main path: get_async_strategy selects nvrx vs mcore backend
async_strategy, async_modules = get_async_strategy(self.cfg.checkpoint.async_strategy)
async_calls_queue_cls = async_modules["AsyncCallsQueue"]
get_write_results_queue_fn = async_modules["get_write_results_queue"]
else:
# mcore dev path: nvrx modules merged into core, no strategy selector
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue
from megatron.core.dist_checkpointing.strategies.filesystem_async import (
get_write_results_queue,
)
async_strategy, async_modules = get_async_strategy(self.cfg.checkpoint.async_strategy)
async_calls_queue_cls = async_modules["AsyncCallsQueue"]

async_strategy = None
async_calls_queue_cls = AsyncCallsQueue
get_write_results_queue_fn = get_write_results_queue

self._async_calls_queue = async_calls_queue_cls(persistent=self.cfg.checkpoint.use_persistent_ckpt_worker)

if self.cfg.checkpoint.use_persistent_ckpt_worker:
Expand All @@ -419,7 +431,7 @@ def initialize_async_checkpoint_worker(self) -> None:
if async_strategy == "mcore":
warmup_kwargs["mp_mode"] = "spawn"
self._async_calls_queue.warmup_persistent_caller(get_rank_safe(), **warmup_kwargs)
async_modules["get_write_results_queue"](self.cfg.checkpoint.async_write_results_mp_mode)
get_write_results_queue_fn(self.cfg.checkpoint.async_write_results_mp_mode)

@property
def async_calls_queue(self) -> Optional[Any]:
Expand All @@ -440,13 +452,14 @@ def nvrx_straggler_manager(self) -> Optional[NVRxStragglerDetectionManager]:
return self._nvrx_straggler_manager

@property
def energy_monitor(self) -> Optional[EnergyMonitor]:
def energy_monitor(self) -> Optional[Any]:
"""The EnergyMonitor instance for tracking energy consumption."""
if (
not self._energy_monitor_created
and self._energy_monitor is None
and self.cfg is not None
and self.cfg.logger.log_energy
and EnergyMonitor is not None
):
self._energy_monitor = EnergyMonitor()
self._energy_monitor_created = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,9 @@ def _fuse_moe_expert_weights(model_dir: Path, num_experts: int) -> None:
if not keys_to_remove:
return

new_state_dict = {k: v for k, v in state_dict.items() if k not in keys_to_remove}
# Clone kept tensors so they are no longer backed by the mmap of weights_path,
# which will be overwritten by save_file below.
new_state_dict = {k: v.clone() for k, v in state_dict.items() if k not in keys_to_remove}

for prefix, experts in layers.items():
gate_up = torch.stack(
Expand Down
Loading
Loading