Skip to content

Commit af36003

Browse files
authored
[GRANITEMOESHARED] drop granitemoeshared model support (#379)
Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
1 parent 7204b77 commit af36003

File tree

4 files changed

+7
-267
lines changed

lm_engine/hf_models/model_conversion/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,11 @@
1414
_import_granitemoehybrid_config,
1515
_import_granitemoehybrid_state_dict,
1616
)
17-
from .granitemoeshared import (
18-
_export_granitemoeshared_config,
19-
_export_granitemoeshared_state_dict,
20-
_import_granitemoeshared_config,
21-
_import_granitemoeshared_state_dict,
22-
)
2317
from .llama import _export_llama_config, _export_llama_state_dict, _import_llama_config, _import_llama_state_dict
2418

2519

2620
_MODEL_IMPORT_FUNCTIONS = {
2721
"granite": (_import_granite_config, _import_llama_state_dict),
28-
"granitemoeshared": (_import_granitemoeshared_config, _import_granitemoeshared_state_dict),
2922
"granitemoehybrid": (_import_granitemoehybrid_config, _import_granitemoehybrid_state_dict),
3023
"llama": (_import_llama_config, _import_llama_state_dict),
3124
}
@@ -64,7 +57,6 @@ def import_from_huggingface(
6457

6558
_MODEL_EXPORT_FUNCTIONS = {
6659
"granite": (_export_granite_config, _export_llama_state_dict),
67-
"granitemoeshared": (_export_granitemoeshared_config, _export_granitemoeshared_state_dict),
6860
"granitemoehybrid": (_export_granitemoehybrid_config, _export_granitemoehybrid_state_dict),
6961
"llama": (_export_llama_config, _export_llama_state_dict),
7062
}

lm_engine/hf_models/model_conversion/granitemoehybrid.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2025, Mayank Mishra
33
# **************************************************
44

5+
import torch
56
from transformers import GraniteMoeHybridConfig, GraniteMoeHybridForCausalLM
67

78
from ...utils import SafeTensorsWeightsManager, divide_if_divisible
@@ -10,7 +11,12 @@
1011
split_query_key_value_tensor_for_attention,
1112
)
1213
from ..models import GPTBaseConfig
13-
from .granitemoeshared import _split_and_reorder_for_glu
14+
15+
16+
def _split_and_reorder_for_glu(weight: torch.Tensor, dim: int) -> torch.Tensor:
17+
x, y = weight.chunk(2, dim=dim)
18+
weight = torch.cat([y, x], dim=dim)
19+
return weight
1420

1521

1622
def _import_granitemoehybrid_config(original_config: GraniteMoeHybridConfig) -> GPTBaseConfig:

lm_engine/hf_models/model_conversion/granitemoeshared.py

Lines changed: 0 additions & 238 deletions
This file was deleted.

tests/hf_models/single_gpu/model_conversion_test.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,6 @@ def test_granite_model_conversion(self, device: torch.device, add_bias: bool) ->
3636
lm_engine_config=lm_engine_config, model_type="granite", device=device, exact_match=False
3737
)
3838

39-
@parameterized.expand(TestCommons.get_all_devices())
40-
def test_granitemoeshared_model_conversion(self, device: torch.device) -> None:
41-
lm_engine_config = self.get_moe_test_config(
42-
"rope",
43-
add_bias=False,
44-
shared_n_inner=64,
45-
activation_function="swiglu",
46-
normalization_function="rmsnorm",
47-
m_emb=2,
48-
m_width=2,
49-
)
50-
51-
self.model_conversion_test(
52-
lm_engine_config=lm_engine_config,
53-
model_type="granitemoeshared",
54-
device=device,
55-
exact_match=False,
56-
compare_loss=False,
57-
)
58-
5939
@parameterized.expand(TestCommons.make_args_matrix(TestCommons.get_all_devices(), [True, False]))
6040
def test_granitemoehybrid_model_conversion(self, device: torch.device, is_moe: bool) -> None:
6141
if is_moe:

0 commit comments

Comments (0)