Commit 9350e95

feat: add Qwen3Next model support
Add support for Qwen3Next architecture including:
- New model implementation with GDN (Gated Delta Network) attention
- Mamba cache memory manager for hybrid architecture
- FLA (Flash Linear Attention) triton kernels
- Custom triton kernels for causal conv1d, gated RMSNorm, fused gating
- MTP (Multi-Token Prediction) variant support
- Allocator utilities and parameter weight management
- Hybrid radix cache for dynamic prompt handling
1 parent eb8c8c0 commit 9350e95


63 files changed: 8908 additions & 165 deletions

lightllm/common/allocator_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+from typing import List, Union
+
+import torch
+
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class TokenAllocator:
+    def __init__(self, size, shared_can_use_token_num_name: str):
+        self.size = size
+
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._mem_state_return = torch.arange(
+            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self._return_start = 0
+        self.mark_start = 0
+        self.mark_end = self.size
+
+        self.can_use_mem_size = self.size
+
+        # Shared via shared memory; the router module reads it for precise scheduling estimates. The nccl port marks a single instance on one machine, to avoid conflicts.
+        self.shared_can_use_token_num = SharedInt(shared_can_use_token_num_name)
+
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size
+
+    def alloc(self, need_size) -> torch.Tensor:
+        if need_size > self.mark_end - self.mark_start:
+            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
+            assert False, "error alloc state"
+
+        start = self.mark_start
+        end = self.mark_start + need_size
+        self.mark_start += need_size
+
+        self.can_use_mem_size -= need_size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        # Return through a staging buffer to avoid memory races under asynchronous execution.
+        if self._return_start + need_size > self._mem_state_return.shape[0]:
+            self._return_start = 0
+        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
+        ans.copy_(self.mem_state[start:end])
+        self._return_start += need_size
+        return ans
+
+    def free(self, free_index: Union[torch.Tensor, List[int]]):
+        """Return token indices to the allocator.
+
+        Args:
+            free_index (torch.Tensor | List[int]): the token indices to release.
+        """
+        end = self.mark_start
+        start = self.mark_start - len(free_index)
+        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"
+
+        if isinstance(free_index, list):
+            free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device)
+            self.mem_state[start:end] = free_index_tensor
+        else:
+            # The copy from GPU to CPU is a stream-blocking operation.
+            self.mem_state[start:end] = free_index
+
+        self.mark_start -= len(free_index)
+
+        self.can_use_mem_size += len(free_index)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+
+        if self.can_use_mem_size == len(self.mem_state):
+            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
+        return
+
+    def free_all(self):
+        self.can_use_mem_size = len(self.mem_state)
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
+        self.mark_start = 0
+        self.mark_end = len(self.mem_state)
+
+    def resize_mem(self, new_size):
+        """
+        just for test code
+        """
+        self.size = new_size
+        self.mem_state = torch.arange(
+            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
+        )
+        self.mark_start = 0
+        self.mark_end = self.size
+        self.can_use_mem_size = self.size
+        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
+        return
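
For orientation, a minimal usage sketch of the TokenAllocator added above (not part of the commit; the shared-memory name is hypothetical and assumes the SharedInt segment can be created in this process):

# Hedged sketch: exercising the bump-pointer alloc/free cycle of TokenAllocator.
from lightllm.common.allocator_utils import TokenAllocator

allocator = TokenAllocator(size=1024, shared_can_use_token_num_name="demo_can_use_token_num")

idx = allocator.alloc(16)  # pinned CPU int32 tensor holding 16 token indices
print(idx.shape, allocator.can_use_mem_size)  # torch.Size([16]) 1008

allocator.free(idx.tolist())  # returns the indices to the top of the mark stack
allocator.free_all()          # resets the whole pool and the shared counter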

lightllm/common/basemodel/basemodel.py

Lines changed: 12 additions & 0 deletions
@@ -53,6 +53,16 @@ class TpPartBaseModel:
     # infer state class
     infer_state_class = InferStateInfo

+    @classmethod
+    def get_radix_cache_class(cls):
+        """Return the appropriate RadixCache class for this model type.
+
+        Override in subclasses that need specialized cache (e.g., hybrid models).
+        """
+        from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
+
+        return RadixCache
+
     def __init__(self, kvargs):
         self.args = get_env_start_args()
         self.run_mode = kvargs["run_mode"]
@@ -302,6 +312,7 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.prefix_total_token_num = model_input.prefix_total_token_num
         assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
         infer_state.b_req_idx = model_input.b_req_idx
+        infer_state.b_mtp_index = model_input.b_mtp_index
         infer_state.b_seq_len = model_input.b_seq_len
         if model_input.is_prefill:
             if model_input.b_ready_cache_len is not None:
@@ -1028,6 +1039,7 @@ def _gen_special_model_input(self, token_num: int):
             "Deepseek3MTPModel" in str(self.__class__)
             or "Qwen3MOEMTPModel" in str(self.__class__)
             or "MistralMTPModel" in str(self.__class__)
+            or "Qwen3NextMTPModel" in str(self.__class__)
         )
         if is_mtp_draft_model:
             special_model_input["mtp_draft_input_hiddens"] = torch.randn(
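
A hedged sketch of how a model could use the new get_radix_cache_class hook; the subclass and the HybridRadixCache import path below are illustrative, not taken from this commit:

# Illustrative only: a hybrid model overriding the hook so the server builds a
# specialized radix cache for it. The import path of HybridRadixCache is hypothetical.
from lightllm.common.basemodel.basemodel import TpPartBaseModel


class MyHybridModel(TpPartBaseModel):
    @classmethod
    def get_radix_cache_class(cls):
        from my_pkg.hybrid_radix_cache import HybridRadixCache  # hypothetical module

        return HybridRadixCache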

lightllm/common/basemodel/cuda_graph.py

Lines changed: 15 additions & 2 deletions
@@ -3,6 +3,7 @@
 import copy
 import bisect
 from typing import Optional
+from tqdm import tqdm
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
@@ -196,7 +197,12 @@ def warmup(self, model):
         model: TpPartBaseModel = model

         # decode cuda graph init
-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
+        for batch_size in progress_bar:
+            # Get available memory info
+            avail_mem, total_mem = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
             seq_len = 2
             total_token_num = batch_size * seq_len
             max_len_in_batch = self.graph_max_len_in_batch
@@ -251,7 +257,14 @@ def warmup_overlap(self, model):

         model: TpPartBaseModel = model

-        for batch_size in self.cuda_graph_batch_sizes[::-1]:
+        progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
+        for batch_size in progress_bar:
+            # Get available memory info
+            avail_mem, total_mem = torch.cuda.mem_get_info()
+            avail_mem_gb = avail_mem / (1024 ** 3)
+            progress_bar.set_description(
+                f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
+            )
             decode_batches = []
             for micro_batch_index in [0, 1]:
                 # dummy decoding, capture the cudagraph
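
The progress-reporting pattern added to warmup, shown standalone as a hedged sketch (requires a CUDA device; the capture step itself is stubbed out):

import torch
from tqdm import tqdm


def report_capture_progress(batch_sizes):
    # Mirrors the warmup loop above: refresh the tqdm description with the
    # currently free GPU memory before capturing each batch size.
    progress_bar = tqdm(batch_sizes[::-1], desc="Capturing CUDA graphs")
    for batch_size in progress_bar:
        avail_mem, total_mem = torch.cuda.mem_get_info()
        avail_mem_gb = avail_mem / (1024 ** 3)
        progress_bar.set_description(
            f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
        )
        # ... capture the CUDA graph for this batch size here ...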

lightllm/common/basemodel/infer_struct.py

Lines changed: 10 additions & 2 deletions
@@ -32,6 +32,8 @@ def __init__(self):
         self.batch_size: int = None
         self.total_token_num: int = None
         self.b_req_idx: torch.Tensor = None
+        self.b_mtp_index: torch.Tensor = None  # MTP index for each batch item (0: main, 1-mtp_step: candidates)
+        self.b_start_loc: torch.Tensor = None
         self.b_ready_cache_len: torch.Tensor = None  # only for prefill prompt cache used.

         self.b_shared_seq_len: torch.Tensor = None  # only for diverse mode used in decode phase.
@@ -98,7 +100,10 @@ def __init__(self):
         self.dp_output_split_sizes: List[List[int]] = None
         self.dp_input_split_sizes: List[List[int]] = None

-    def init_some_extra_state(self, model):
+        # Buffer dedicated to managing hybrid-attention models
+        self.buffer_indexes: torch.Tensor = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor = None):
         if self.is_prefill:
             (
                 self.b_q_seq_len,
@@ -121,6 +126,9 @@ def __init__(self):
                 self.position_ids,
             ) = gen_decode_params(self.b_seq_len)
            self.b_kv_start_loc = self.b1_cu_kv_seq_len[0:-1]
+            # max_kv_seq_len is already set in _create_inferstate from model_input.max_kv_seq_len
+            self.max_q_seq_len = self.b_q_seq_len.max().item() if self.b_q_seq_len.numel() > 0 else 1
+            self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]

     def init_att_state(self):
         if self.is_prefill:
@@ -136,7 +144,7 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
         for attr_name, attr_value in vars(new_infer_state).items():
             if isinstance(attr_value, torch.Tensor):
                 attr_ = getattr(self, attr_name, None)
-                if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
+                if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr() and attr_.shape == attr_value.shape:
                     attr_.copy_(attr_value, non_blocking=True)

         self.decode_att_state.copy_for_decode_cuda_graph(new_infer_state.decode_att_state)
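
The relaxed guard in copy_for_cuda_graph (copy only when shapes also match) can be illustrated in isolation; a small hedged sketch with standalone tensors, not code from the commit:

import torch


def copy_if_compatible(dst, src):
    # Same condition as the updated copy_for_cuda_graph: only copy when dst is a
    # distinct tensor of identical shape, so shape-mismatched buffers are skipped.
    if dst is not None and dst.data_ptr() != src.data_ptr() and dst.shape == src.shape:
        dst.copy_(src, non_blocking=True)
        return True
    return False


a = torch.zeros(4)
b = torch.arange(4, dtype=torch.float32)
c = torch.arange(6, dtype=torch.float32)
assert copy_if_compatible(a, b) is True   # same shape: copied
assert copy_if_compatible(a, c) is False  # shape mismatch: skipped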

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 28 additions & 21 deletions
@@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
     def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")

-    def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
-        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
-        input1 = None
+    def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
-
         o = self._context_attention_wrapper_run(
             q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
         )
-
         q = None
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        return o
+
+    def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        o = self.context_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

@@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings

-    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
-        q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
-        input1 = None
+    def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._get_o(o, infer_state, layer_weight)
         if self.tp_world_size_ > 1:
             all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+        return o
+
+    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        o = self.token_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

         input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
         ffn_out = self._ffn(input1, infer_state, layer_weight)
-        input1 = None
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings

-    def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
-        q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
-        input1 = None
+    def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+        q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
-
         o = self._context_attention_wrapper_run(
             q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
         )
-
         q = None
         o = self._tpsp_get_o(o, infer_state, layer_weight)
+        return o
+
+    def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

@@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return input_embdings

-    def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
-        q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
-        input1 = None
+    def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+        q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
         self._post_cache_kv(cache_kv, infer_state, layer_weight)
         o = self._token_attention_kernel(q, infer_state, layer_weight)
         q = None
         o = self._tpsp_get_o(o, infer_state, layer_weight)
+        return o
+
+    def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
         input_embdings.add_(o.view(-1, self.embed_dim_))
         o = None

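Splitting the attention math into *_attention_forward hooks lets a layer swap its attention path while inheriting the norm/residual/FFN wrapper. A hedged sketch follows; the subclass, the helper method, and the assumption that this module's template class is TransformerLayerInferTpl are illustrative, not taken from the commit:

from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import (
    TransformerLayerInferTpl,
)


class MyLinearAttnLayerInfer(TransformerLayerInferTpl):
    def context_attention_forward(self, input_embdings, infer_state, layer_weight):
        # Replace softmax attention with e.g. a linear-attention/GDN kernel;
        # the hook only has to return the attention output tensor `o`.
        return self._my_linear_attention(input_embdings, infer_state, layer_weight)  # hypothetical helper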

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -7,7 +7,11 @@
     ROWBMMWeight,
 )
 from .norm_weight import NoTpGEMMANormWeight, TpVitPadNormWeight, NoTpNormWeight, TpHeadNormWeight
+
+# NormWeight is an alias for NoTpNormWeight for backward compatibility
+NormWeight = NoTpNormWeight
 from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
 from .fused_moe_weight_ep import FusedMoeWeightEP
 from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
 from .att_sink_weight import TpAttSinkWeight
+from .parameter_weight import ParameterWeight, TpParameterWeight
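
A quick hedged check of the backward-compatibility alias (not part of the commit):

from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight, NoTpNormWeight

# The alias keeps existing imports of NormWeight resolving to NoTpNormWeight.
assert NormWeight is NoTpNormWeight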
