
Conversation

@sufubao (Collaborator) commented Jan 18, 2026

No description provided.

yeahdongcn and others added 4 commits January 6, 2026 19:29
This PR adds support for Moore Threads (MUSA) GPU platform, expanding
LightLLM's hardware compatibility.

*NOTE:*

1. `_fwd_kernel_token_att1` has been slightly updated to ensure
compatibility with the Triton version.
2. `has_mtlink` will be used in upcoming enhancements to enable
multi-GPU support.
3. `torch` / `torch_musa` need to be upgraded to the latest versions.
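
A minimal availability probe for the upgraded `torch` / `torch_musa` pair, as a sketch only (the exact attribute layout may differ between torch_musa releases):

```python
import torch

try:
    import torch_musa  # registers the "musa" device backend with PyTorch
    HAS_MUSA = torch.musa.is_available()
except (ImportError, AttributeError):
    HAS_MUSA = False

device = torch.device("musa" if HAS_MUSA else "cpu")
print(f"MUSA available: {HAS_MUSA}, using device: {device}")
```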

### Testing Done

```bash
root@worker3218:/ws# python -m lightllm.server.api_server --model_dir /home/dist/Qwen3-0.6B/ --disable_cudagraph --host 0.0.0.0
WARNING 01-02 12:22:47 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
WARNING 01-02 12:22:47 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
INFO 01-02 12:22:48 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:22:48 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:22:48 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:22:48 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:22:48 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
INFO 01-02 12:22:48 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
INFO 01-02 12:22:48 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
WARNING 01-02 12:22:48 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
WARNING 01-02 12:22:48 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!!
INFO 01-02 12:22:48 [shm_size_check.py:21] SHM check: Available=500.00 GB,Recommended=2.32 GB.Sufficient: True
INFO 01-02 12:22:48 [api_start.py:94] zmq mode head: ipc:///tmp/_28765_0_
INFO 01-02 12:22:48 [api_start.py:96] use tgi api: False
INFO 01-02 12:22:48 [api_start.py:233] alloced ports: [10105, 10128, 10009, 10002, 10268, 10173, 10255, 10190, 10225, 10305]
INFO 01-02 12:22:48 [api_start.py:284] all start args:Namespace(run_mode='normal', host='0.0.0.0', port=8000, httpserver_workers=1, zmq_mode='ipc:///tmp/_28765_0_', pd_master_ip='0.0.0.0', pd_master_port=1212, pd_decode_rpyc_port=42000, select_p_d_node_strategy='round_robin', config_server_host=None, config_server_port=None, nixl_pd_kv_page_num=16, nixl_pd_kv_page_size=1024, model_name='default_model_name', model_dir='/home/dist/Qwen3-0.6B/', tokenizer_mode='fast', load_way='HF', max_total_token_num=None, mem_fraction=0.9, batch_max_tokens=8448, eos_id=[151645], tool_call_parser=None, reasoning_parser=None, chat_template=None, running_max_req_size=1000, nnodes=1, node_rank=0, multinode_httpmanager_port=12345, multinode_router_gloo_port=20001, tp=1, dp=1, dp_balancer='bs_balancer', max_req_total_len=16384, nccl_host='127.0.0.1', nccl_port=28765, use_config_server_to_init_nccl=False, mode=[], trust_remote_code=False, disable_log_stats=False, log_stats_interval=10, disable_shm_warning=False, router_token_ratio=0.0, router_max_new_token_len=1024, router_max_wait_tokens=1, disable_aggressive_schedule=False, use_dynamic_prompt_cache=False, disable_dynamic_prompt_cache=False, chunked_prefill_size=4096, disable_chunked_prefill=False, diverse_mode=False, token_healing_mode=False, output_constraint_mode='none', first_token_constraint_mode=False, enable_multimodal=False, enable_multimodal_audio=False, enable_mps=False, disable_custom_allreduce=False, enable_custom_allgather=False, enable_tpsp_mix_mode=False, enable_dp_prefill_balance=False, enable_prefill_microbatch_overlap=False, enable_decode_microbatch_overlap=False, enable_flashinfer_prefill=False, enable_flashinfer_decode=False, enable_fa3=False, cache_capacity=200, embed_cache_storage_size=4, data_type='bfloat16', return_all_prompt_logprobs=False, use_reward_model=False, long_truncation_mode=None, use_tgi_api=False, health_monitor=False, metric_gateway=None, job_name='lightllm', grouping_key=[], push_interval=10, visual_infer_batch_size=1, visual_send_batch_size=1, visual_gpu_ids=[0], visual_tp=1, visual_dp=1, visual_nccl_ports=[29500], enable_monitor_auth=False, disable_cudagraph=True, enable_prefill_cudagraph=False, prefll_cudagraph_max_handle_token=512, graph_max_batch_size=256, graph_split_batch_size=32, graph_grow_step_size=16, graph_max_len_in_batch=16384, quant_type='none', quant_cfg=None, vit_quant_type='none', vit_quant_cfg=None, sampling_backend='triton', penalty_counter_mode='gpu_counter', ep_redundancy_expert_config_path=None, auto_update_redundancy_expert=False, enable_fused_shared_experts=False, mtp_mode=None, mtp_draft_model_dir=None, mtp_step=0, kv_quant_calibration_config_path=None, schedule_time_interval=0.03, enable_cpu_cache=False, cpu_cache_storage_size=2, cpu_cache_token_page_size=256, enable_disk_cache=False, disk_cache_storage_size=10, disk_cache_dir=None, enable_dp_prompt_cache_fetch=False, router_port=10105, detokenization_port=10128, http_server_port=10009, visual_port=10002, audio_port=10268, cache_port=10173, metric_port=10255, multi_level_kv_cache_port=10190, pd_node_infer_rpyc_ports=[10305], pd_node_id=294623010895931863621527973304373176200, pd_p_allowed_port_min=20000, pd_p_allowed_port_max=30000)
WARNING 01-02 12:22:55 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
WARNING 01-02 12:22:55 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
INFO 01-02 12:22:55 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:22:55 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:22:55 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:22:55 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:22:55 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
INFO 01-02 12:22:55 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
2026-01-02 12:22:55 | server | 140684395422848 | INFO : server started on [0.0.0.0]:10255
INFO 01-02 12:22:55 [start_utils.py:37] init func start_metric_manager : init ok
WARNING 01-02 12:23:02 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
WARNING 01-02 12:23:02 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
WARNING 01-02 12:23:02 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
WARNING 01-02 12:23:02 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
INFO 01-02 12:23:02 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:23:02 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:23:02 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:23:02 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:23:02 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
INFO 01-02 12:23:02 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
INFO 01-02 12:23:02 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
INFO 01-02 12:23:02 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:23:02 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:23:02 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:23:02 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:23:02 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
INFO 01-02 12:23:02 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
WARNING 01-02 12:23:02 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
INFO 01-02 12:23:02 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
WARNING 01-02 12:23:03 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
INFO 01-02 12:23:03 [manager.py:36] pub_to_httpserver sendhwm 1000
WARNING 01-02 12:23:03 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!!
2026-01-02 12:23:03 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 36414) with fd 25
2026-01-02 12:23:03 | server | 140653235951168 | INFO : welcome ('127.0.0.1', 36414)
INFO 01-02 12:23:08 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
WARNING 01-02 12:23:09 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
INFO 01-02 12:23:10 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:23:10 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:23:10 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:23:10 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:23:10 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
WARNING 01-02 12:23:10 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
WARNING 01-02 12:23:10 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
INFO 01-02 12:23:10 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
WARNING 01-02 12:23:10 [nixl_kv_transporter.py:19] nixl is not installed, which is required for pd disagreggation!!!
INFO 01-02 12:23:10 [model_rpc.py:67] Initialized RPC server for rank 0.
INFO 01-02 12:23:10 [model_rpc.py:168] use ChunkedPrefillBackend
INFO 01-02 12:23:11 [basemodel.py:157] Initial quantization. The default quantization method is none
pid 39235 Loading model weights with 1 workers: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.01it/s]
INFO 01-02 12:23:12 [mem_utils.py:37] mode setting params: []
INFO 01-02 12:23:12 [mem_utils.py:57] Model kv cache using mode normal
INFO 01-02 12:23:12 [mem_manager.py:84] 69.38735313415528 GB space is available after load the model weight
INFO 01-02 12:23:12 [mem_manager.py:84] 0.109375 MB is the size of one token kv cache
INFO 01-02 12:23:12 [mem_manager.py:84] 649624 is the profiled max_total_token_num with the mem_fraction 0.9
INFO 01-02 12:23:12 [mem_manager.py:84] 
warming up:   0%|                                                                                                                                                                  | 0/12 [00:00<?, ?it/s]WARNING 01-02 12:23:23 [autotuner.py:169] No kernel config for silu_and_mul_fwd:v1 in {N=3072,out_dtype=torch.bfloat16}_MTT_S5000.json,the performance may be suboptimal!You can use LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 to enable autotune.
WARNING 01-02 12:23:23 [kernel_config.py:40] can not find config_path /ws/lightllm/common/all_kernel_configs/moe_silu_and_mul_kernel/{N=3072,out_dtype=torch.bfloat16}_MTT_S5000.json kernel name moe_silu_and_mul_kernel use default kernel setting
warming up: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:15<00:00,  1.29s/it]
INFO 01-02 12:23:30 [basemodel.py:812] begin check max_len infer
INFO 01-02 12:23:30 [basemodel.py:849] check max_len 8448 infer ok
INFO 01-02 12:23:45 [base_backend.py:185] loaded model class <class 'lightllm.models.qwen3.model.Qwen3TpPartModel'>
INFO 01-02 12:23:45 [manager.py:196] use req queue ChunkedPrefillQueue
INFO 01-02 12:23:45 [start_utils.py:37] init func start_router_process : init ok
INFO 01-02 12:23:45 [start_utils.py:37] init func start_detokenization_process : init ok
INFO 01-02 12:23:45 [api_start.py:58] start process pid 30307
INFO 01-02 12:23:45 [api_start.py:59] http server pid 54746
[2026-01-02 12:23:45 +0800] [54746] [INFO] Starting gunicorn 23.0.0
[2026-01-02 12:23:45 +0800] [54746] [INFO] Listening at: http://0.0.0.0:8000 (54746)
[2026-01-02 12:23:45 +0800] [54746] [INFO] Using worker: uvicorn.workers.UvicornWorker
[2026-01-02 12:23:45 +0800] [54966] [INFO] Booting worker with pid: 54966
WARNING 01-02 12:23:51 [sgl_utils.py:29] sgl_kernel is not installed, or the installed version did not support fa3.         Try to upgrade it.
WARNING 01-02 12:23:51 [light_utils.py:13] lightllm_kernel is not installed, you can't use the api of it.
INFO 01-02 12:23:52 [__init__.py:36] Available plugins for group vllm.platform_plugins:
INFO 01-02 12:23:52 [__init__.py:38] - musa -> vllm_musa:register
INFO 01-02 12:23:52 [__init__.py:41] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 01-02 12:23:52 [__init__.py:232] Platform plugin musa is activated
WARNING 01-02 12:23:52 [vllm_utils.py:18] vllm is not installed, you can't use the api of it.                    You can solve it by running `pip install vllm`.
INFO 01-02 12:23:52 [communication_op.py:57] deep_ep is not installed, you can't use the api of it.
INFO 01-02 12:23:52 [cache_tensor_manager.py:17] USE_GPU_TENSOR_CACHE is On
WARNING 01-02 12:23:52 [grouped_fused_moe_ep.py:28] no deepep or deep_gemm
[2026-01-02 12:23:52 +0800] [54966] [INFO] Started server process [54966]
[2026-01-02 12:23:52 +0800] [54966] [INFO] Waiting for application startup.
INFO 01-02 12:23:52 [api_http.py:359] server start up
2026-01-02 12:23:53 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 55128) with fd 26
2026-01-02 12:23:53 | server | 140653227558464 | INFO : welcome ('127.0.0.1', 55128)
2026-01-02 12:23:53 | server | 140684395422848 | INFO : accepted ('127.0.0.1', 55144) with fd 27
2026-01-02 12:23:53 | server | 140653219165760 | INFO : welcome ('127.0.0.1', 55144)
INFO 01-02 12:23:54 [req_id_generator.py:34] ReqIDGenerator init finished
INFO 01-02 12:23:54 [api_http.py:363] server start up ok, loop use is <uvloop.Loop running=True closed=False debug=False>
[2026-01-02 12:23:54 +0800] [54966] [INFO] Application startup complete.
INFO 01-02 12:23:58 [manager.py:417] recieved req X-Request-Id: X-Session-Id: start_time:2026-01-02 12:23:58 lightllm_req_id:8 
INFO 01-02 12:23:58 [manager.py:424] router recive req id 8 cost time 0.05271601676940918 s
DEBUG 01-02 12:23:58 [manager.py:322] Prefill Batch: batch_id=-1, time:1767327838.6764812s req_ids:[8] 
DEBUG 01-02 12:23:58 [manager.py:322] 
INFO 01-02 12:23:58 [manager.py:55] detokenization recv req id 8 cost time 0.0744318962097168 s
INFO 01-02 12:23:59 [manager.py:163] detoken release req id 8
INFO 01-02 12:23:59 [manager.py:611] X-Request-Id: X-Session-Id: start_time:2026-01-02 12:23:58 lightllm_req_id:8 first_token_cost:409.63053703308105ms total_cost_time:907.1474075317383ms,out_token_counter:17 mean_per_token_cost_time: 29.265698264626895ms prompt_token_num:4 gpu cache hit: False gpu_prompt_cache_len:0 gpu_prompt_cache_ratio:0.0 cpu cache hit: False cpu_prompt_cache_len:0 cpu_prompt_cache_ratio:0.0 disk cache hit: False disk_prompt_cache_len:0 disk_prompt_cache_ratio:0.0 mtp_avg_token_per_step:1.0 
127.0.0.1:38158 - "POST /generate HTTP/1.1" 200
DEBUG 01-02 12:23:59 [req_manager.py:78] freed all request size 1008
DEBUG 01-02 12:23:59 [infer_batch.py:172] free a batch state:
DEBUG 01-02 12:23:59 [infer_batch.py:172] radix refed token num 0
DEBUG 01-02 12:23:59 [infer_batch.py:172] radix hold token num 21
DEBUG 01-02 12:23:59 [infer_batch.py:172] mem manager can alloc token num 649603
DEBUG 01-02 12:23:59 [infer_batch.py:172] mem manager total size 649624
INFO 01-02 12:23:59 [batch.py:56] router release req id 8
INFO 01-02 12:23:59 [shm_req_manager.py:111] all shm req has been release ok
```

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: root <root@DESKTOP-5FJJCPK.localdomain>
@gemini-code-assist
Contributor

Summary of Changes

Hello @sufubao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the Qwen3Next model, which employs a hybrid attention mechanism combining full attention with Gated Delta Networks (GDN) linear attention. It includes substantial enhancements to the Multi-Token Prediction (MTP) system, featuring specialized memory management and optimized Triton kernels for efficient state handling during inference. The changes streamline memory allocation, introduce new inference logic for GDN layers, and provide autotuning for critical kernel operations, aiming to boost performance and memory efficiency for this new model.

Highlights

  • Qwen3Next Model Integration: Introduced comprehensive support for the Qwen3Next model, including its base and Multi-Token Prediction (MTP) variants, featuring a novel hybrid attention architecture.
  • Hybrid Attention Mechanism: Implemented Gated Delta Networks (GDN) for linear attention layers, alongside traditional full attention, enabling a more efficient and flexible model architecture.
  • Enhanced Multi-Token Prediction (MTP): Significantly upgraded MTP capabilities with specialized request management, optimized buffer allocation, and efficient state copying for linear attention states during decode.
  • Advanced Triton Kernel Optimizations: Developed new Triton kernels for core GDN operations, such as causal 1D convolutions, fused recurrent GDN, and fused gating, complete with autotuning configurations for improved performance.
  • Refactored Memory Management: Refactored the memory management system by introducing a generic TokenAllocator and a specialized MambaCacheManager for linear attention states, enhancing memory efficiency and flexibility.
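
To make the GDN highlight concrete, here is a minimal single-step reference of the gated delta rule in PyTorch; names, shapes, and gating conventions are illustrative assumptions, not LightLLM's kernel API:

```python
import torch

def gdn_recurrent_step(S, q, k, v, alpha, beta):
    """One gated delta rule decode step for a single head (illustrative only).

    S:     (d_v, d_k) recurrent state mapping keys to values
    q, k:  (d_k,) current query / key
    v:     (d_v,) current value
    alpha: scalar decay gate in (0, 1)
    beta:  scalar write strength in (0, 1)
    """
    # Decay the previous state, erase the old association stored under k,
    # then write the new (k -> v) association with strength beta.
    S = alpha * (S - beta * torch.outer(S @ k, k)) + beta * torch.outer(v, k)
    o = S @ q  # linear-attention readout for the current token
    return o, S

# Toy usage with d_k = d_v = 4
S = torch.zeros(4, 4)
o, S = gdn_recurrent_step(S, torch.randn(4), torch.randn(4), torch.randn(4), 0.9, 0.5)
```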



@gemini-code-assist bot left a comment


Code Review

This pull request adds support for the Qwen3next model, which features a hybrid attention mechanism (standard attention + Gated Delta Networks). The changes are extensive and well-structured, including a new model implementation, a hybrid memory manager for both KV cache and Mamba-style buffers, and several performance optimizations using custom Triton kernels. The core components like ReqManager and MemoryManager have been refactored to be more generic, which is a great improvement for future extensibility. Overall, this is a high-quality contribution that significantly expands the framework's capabilities. I have a few minor suggestions regarding a misleading docstring, a removed validation check, and improving logging for better user experience and debuggability.

Comment on lines 81 to 86
```python
input_ids.extend(origin_ids[start_idx:])
return input_ids
```

Severity: high

The validation check that ensures the number of image tags in the prompt matches the number of provided images has been removed. This could lead to silent errors or unexpected behavior if there's a mismatch. It's recommended to restore this check to maintain data integrity.

Suggested change

```diff
 input_ids.extend(origin_ids[start_idx:])
-return input_ids
+if multimodal_params:
+    image_cnt = len(multimodal_params.images)
+    assert image_cnt == image_id, f"invalid image tag num: {image_cnt} vs {image_id}!"
```

Comment on lines +59 to +65
"""
Copy buffers from source indices to destination indices using optimized Triton kernel.
Args:
src_buffer_indexes: Source buffer indices (1D tensor)
dst_buffer_indexes: Destination buffer indices (1D tensor)
"""

Severity: medium

The docstring here mentions using an "optimized Triton kernel", but the implementation below uses PyTorch's advanced indexing for the copy operation. To avoid confusion, the docstring should be updated to accurately reflect the implementation.

Suggested change

```diff
 """
-Copy buffers from source indices to destination indices using optimized Triton kernel.
+Copy buffers from source indices to destination indices.
 Args:
     src_buffer_indexes: Source buffer indices (1D tensor)
     dst_buffer_indexes: Destination buffer indices (1D tensor)
 """
```

Comment on lines 154 to 155
```python
if args.mtp_draft_model_dir is None:
    args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step
```

Severity: medium

When mtp_draft_model_dir is not provided, the code now defaults to using the main model directory. This is a convenient fallback, but it might be surprising to users. It would be helpful to add a log message to inform the user that this default behavior is being applied.

Suggested change

```diff
 if args.mtp_draft_model_dir is None:
+    logger.info(f"'mtp_draft_model_dir' not set, using main model dir '{args.model_dir}' as draft model.")
     args.mtp_draft_model_dir = [args.model_dir] * args.mtp_step
```

```diff
 estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i)
 paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i)
-logger.debug(
+logger.warning(
```

Severity: medium

The log level for this message about DP status was changed from debug to warning. While this makes it more visible, this information seems more suited for info or debug levels, as it reflects normal operational state rather than a potential problem. Using warning might cause unnecessary alarm in production logs.

Suggested change

```diff
-logger.warning(
+logger.info(
```

@sufubao force-pushed the qwen3next branch 4 times, most recently from b09c18e to 2c64777 on January 20, 2026 at 06:50
@sufubao (Collaborator, Author) commented Jan 21, 2026

Code review (additional unnecessary changes)

Found additional file modifications that appear unrelated to supporting qwen3next:

1. int8kv kernel file renames and refactoring - Files renamed from int8kv_* to ppl_int8kv_*, stage2 switched from Triton to CUDA kernels, various optimizations. These are general kernel improvements unrelated to qwen3next.

```python
# An int8kv flash decoding attention implementation custom-built for diverse mode,
# enabling more efficient diverse sampling.
import torch
from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
from lightllm.common.basemodel.infer_struct import InferStateInfo
from .ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
from .ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3
from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size


def token_decode_attention_flash_decoding(
    q,
    infer_state: InferStateInfo,
    cache_k,
    cache_k_scale,
    cache_v,
    cache_v_scale,
    out=None,
    alloc_tensor_func=torch.empty,
    shared_streams_dict={},
):
    if "stream1" not in shared_streams_dict:
        shared_streams_dict["stream1"] = torch.cuda.Stream()
    if "stream2" not in shared_streams_dict:
        shared_streams_dict["stream2"] = torch.cuda.Stream()
    stream1 = shared_streams_dict["stream1"]
    stream2 = shared_streams_dict["stream2"]
    q_head_num = q.shape[1]
    head_dim = q.shape[2]
    # (excerpt truncated)
```

2. MTP metrics removed - Removes mtp_avg_token_per_step metric from httpserver/manager.py, httpserver_for_pd_master/manager.py, and metrics/metrics.py. This removes existing monitoring functionality.

```python
import os
import time
from prometheus_client import CollectorRegistry, Histogram, Counter, Gauge
from prometheus_client import push_to_gateway
from prometheus_client.exposition import basic_auth_handler

MONITOR_INFO = {
    "lightllm_request_count": "The total number of requests",
    "lightllm_request_success": "The number of successful requests",
    "lightllm_request_failure": "The number of failed requests",
    "lightllm_request_duration": "Duration of the request (s)",
    "lightllm_request_validation_duration": "Validation time of the request",
    "lightllm_request_inference_duration": "Inference time of the request",
    "lightllm_request_mean_time_per_token_duration": "Per token time of the request",
    "lightllm_request_first_token_duration": "First token time of the request",
    "lightllm_request_input_length": "Length of the input tokens",
    "lightllm_request_generated_tokens": "Number of generated tokens",
    "lightllm_request_max_new_tokens": "Max new token",
    "lightllm_batch_next_size": "Batch size of the next new batch",
    "lightllm_batch_current_size": "Current batch size",
    "lightllm_batch_pause_size": "The number of pause requests",
    "lightllm_queue_size": "Queue size",
    "lightllm_request_queue_duration_bucket": "Queue duration of requests",
    "lightllm_batch_inference_count": "The number of prefill steps / decode steps",
    "lightllm_batch_inference_duration_bucket": "Inference time of prefill step / decode step",
    "lightllm_cache_length": "Length of tokens which hit prompt cache",
    "lightllm_cache_ratio": "cache length / input_length",
    "lightllm_batch_current_max_tokens": "dynamic max token used for current batch",
}


def my_auth_handler(url, method, timeout, headers, data):
    username = os.getenv("USERNAME", None)
    password = os.getenv("PASSWORD", None)
    if username is None or password is None:
        raise ValueError("USERNAME and PASSWORD must be set when the auth is opened.")
    return basic_auth_handler(url, method, timeout, headers, data, username, password)


class Monitor:
    def __init__(self, args):
        duration_buckets = []
        value = 0.001
        n_duration_buckets = 35
        for _ in range(n_duration_buckets):
            value *= 1.5
            duration_buckets.append(value)
        self.duration_buckets = duration_buckets
        self.monitor_registry = {}
        self.gateway_url = args.metric_gateway
        self.registry = CollectorRegistry()
        self.job_name = args.job_name
        self.grouping_key = {}
        if args.grouping_key:
            for item in args.grouping_key:
                key, value = item.split("=")
                self.grouping_key[key] = value
        self.auth = args.enable_monitor_auth
        self.init_metrics(args)
```
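
Should the dropped metric need to be restored, here is a hedged sketch of how `mtp_avg_token_per_step` could be re-registered with `prometheus_client` (the metric name prefix and bucket boundaries are assumptions):

```python
from prometheus_client import CollectorRegistry, Histogram

registry = CollectorRegistry()

# Hypothetical re-registration of the removed MTP acceptance metric.
mtp_avg_token_per_step = Histogram(
    "lightllm_mtp_avg_token_per_step",
    "Average accepted tokens per MTP decode step",
    buckets=[1.0, 1.25, 1.5, 1.75, 2.0, 2.5, 3.0],
    registry=registry,
)

# Observed once per finished request with its per-step average.
mtp_avg_token_per_step.observe(1.0)
```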

3. Mistral/Qwen3 MoE MTP model refactoring - Changes to mistral_mtp/* and qwen3_moe_mtp/* files including weight loading, layer numbering, and variable renames (e.g., mtp_draft_input_hiddens to deepseekv3_mtp_draft_input_hiddens). These are unrelated to qwen3next which has its own separate modules.

```python
import torch
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.mistral_mtp.layer_weights.pre_and_post_layer_weight import MistralMTPPreAndPostLayerWeight
from lightllm.common.basemodel.triton_kernel.rmsnorm import rmsnorm_forward


class MistralMTPPreLayerInfer(LlamaPreLayerInfer):
    """ """

    def __init__(self, network_config):
        super().__init__(network_config)
        return

    def _mtp_context_forward(
        self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight
    ):
        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
        assert (
            input_embdings.shape[0] == tgt_embdings.shape[0]
        ), f"shape {input_embdings.shape} != shape {tgt_embdings.shape}"
        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
        tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_)
        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
        ans_logics = self.alloc_tensor(
            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
        )
        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
        return ans_logics

    def _mtp_token_forward(
        self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: MistralMTPPreAndPostLayerWeight
    ):
        tgt_embdings = infer_state.deepseekv3_mtp_draft_input_hiddens
        assert input_embdings.shape[0] == tgt_embdings.shape[0]
        rmsnorm_forward(input_embdings, weight=layer_weight.enorm_weight_, eps=self.eps_, out=input_embdings)
        tgt_embdings = rmsnorm_forward(tgt_embdings, weight=layer_weight.final_norm_weight_, eps=self.eps_)
        rmsnorm_forward(tgt_embdings, weight=layer_weight.hnorm_weight_, eps=self.eps_, out=tgt_embdings)
        cat_embdings = torch.cat((input_embdings, tgt_embdings), dim=-1)
        ans_logics = self.alloc_tensor(
            (cat_embdings.shape[0], layer_weight.eh_proj_weight_.shape[1]), dtype=input_embdings.dtype
        )
        torch.mm(cat_embdings, layer_weight.eh_proj_weight_, out=ans_logics)
        # (excerpt truncated)
```

4. Test/benchmark files - New files test/benchmark/service/benchmark_gsm8k.py, test/test_api/test_chat.py, test/test_api/test_gsmk.py are general evaluation tools not specific to qwen3next. Unit test file renames for int8kv kernels are also unrelated.

```python
from openai import OpenAI
from datetime import datetime
import argparse
import threading
import random
import time
import sys
from typing import List

# Try to import readline for better handling of Chinese input
try:
    import readline
except ImportError:
    # readline may be unavailable on Windows; try pyreadline3 instead
    try:
        import pyreadline3 as readline
    except ImportError:
        readline = None


def safe_input(prompt: str) -> str:
    """
    A safe input function that avoids garbled characters when editing Chinese input.
    """
    if readline is not None:
        # readline is loaded; use input directly
        return input(prompt)
    else:
        # No readline available; fall back to an alternative
        sys.stdout.write(prompt)
        # (excerpt truncated)
```

5. CUDA graph tqdm progress bar - Adds UI enhancement to show progress during CUDA graph warmup. This is a general improvement unrelated to qwen3next.

```python
import os
import torch
import copy
import bisect
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
from .infer_struct import InferStateInfo

logger = init_logger(__name__)


class CudaGraph:
    # CudaGraph forward pass for the decoding stage.

    def __init__(self, max_batch_size=8, max_len_in_batch=8192):
        self.graph = {}
        self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
        self.args = get_env_start_args()
        self.mtp_step = self.args.mtp_step
        self.max_batch_size = max_batch_size
        self.graph_max_len_in_batch = max_len_in_batch
        self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap

        # gen cuda graph batch_sizes
        # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
        # and [graph_split_batch_size + graph_grow_step_size,
        # if the mtp_step is not 0, then the batch_sizes will be multiply of (mtp_step + 1)
        graph_split_batch_size = self.args.graph_split_batch_size * (self.mtp_step + 1)
        graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1)
        batch_sizes = [i * (self.mtp_step + 1) for i in range(1, graph_split_batch_size + 1)]
        for _batch_size in range(graph_split_batch_size + graph_grow_step_size, max_batch_size, graph_grow_step_size):
            batch_sizes.append(_batch_size)
        batch_sizes = list(set([e for e in batch_sizes if e < max_batch_size]))
        batch_sizes.append(max_batch_size)
        batch_sizes.sort()

        self.cuda_graph_batch_sizes = batch_sizes
        assert batch_sizes[-1] == self.max_batch_size
        logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}")

    def can_run(self, batch_size, max_len_in_batch):
        return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
```
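
The tqdm enhancement amounts to wrapping the per-batch-size capture loop in a progress bar; a hedged sketch only (the callable and loop body are assumptions, not the actual LightLLM code):

```python
from tqdm import tqdm

def capture_all_graphs(cuda_graph_batch_sizes, capture_fn):
    # Show progress while one CUDA graph is captured per batch size.
    # capture_fn is an assumed callable standing in for the real capture logic.
    for batch_size in tqdm(cuda_graph_batch_sizes, desc="warming up cuda graphs"):
        capture_fn(batch_size)
```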

6. start_args_type.py default value changes - Many default value changes (tokenizer_mode slow->fast, max_req_total_len 3072->16384, chunked_prefill_size 8192->4096, etc.) are bundled configuration changes unrelated to qwen3next.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


# Server start-up arguments
@dataclass
class StartArgs:
    run_mode: str = field(
        default="normal",
        metadata={
            "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"]
        },
    )
    host: str = field(default="127.0.0.1")
    port: int = field(default=8000)
    httpserver_workers: int = field(default=1)
    zmq_mode: str = field(
        default="ipc:///tmp/",
        metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"},
    )
    pd_master_ip: str = field(default="0.0.0.0")
    pd_master_port: int = field(default=1212)
    config_server_host: str = field(default=None)
    config_server_port: int = field(default=None)
    pd_decode_rpyc_port: int = field(default=42000)
    select_p_d_node_strategy: str = field(
        default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]}
    )
    model_name: str = field(default="default_model_name")
    model_dir: Optional[str] = field(default=None)
    tokenizer_mode: str = field(default="fast")
    load_way: str = field(default="HF")
    max_total_token_num: Optional[int] = field(default=None)
    mem_fraction: float = field(default=0.9)
    batch_max_tokens: Optional[int] = field(default=None)
    eos_id: Optional[List[int]] = field(default=None)
    tool_call_parser: Optional[str] = field(
        default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]}
    )
    reasoning_parser: Optional[str] = field(
        default=None,
        metadata={
            "choices": [
                "deepseek-r1",
                "deepseek-v3",
                "glm45",
                "gpt-oss",
                "kimi",
                "kimi_k2",
                "qwen3",
                "qwen3-thinking",
                "minimax",
                "minimax-append-think",
                "step3",
                "nano_v3",
                "interns1",
            ]
        },
    )
    chat_template: Optional[str] = field(default=None)
    running_max_req_size: int = field(default=1000)
    tp: int = field(default=1)
    dp: int = field(default=1)
    nnodes: int = field(default=1)
    node_rank: int = field(default=0)
    max_req_total_len: int = field(default=16384)
    nccl_host: str = field(default="127.0.0.1")
    nccl_port: int = field(default=28765)
    use_config_server_to_init_nccl: bool = field(default=False)
    mode: List[str] = field(default_factory=lambda: [])
    trust_remote_code: bool = field(default=False)
    disable_log_stats: bool = field(default=False)
    log_stats_interval: int = field(default=10)
    router_token_ratio: float = field(default=0.0)
    router_max_new_token_len: int = field(default=1024)
    router_max_wait_tokens: int = field(default=1)
    disable_aggressive_schedule: bool = field(default=False)
    disable_dynamic_prompt_cache: bool = field(default=False)
    chunked_prefill_size: int = field(default=4096)
    disable_chunked_prefill: bool = field(default=False)
    diverse_mode: bool = field(default=False)
    token_healing_mode: bool = field(default=False)
    output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]})
    first_token_constraint_mode: bool = field(default=False)
    enable_tpsp_mix_mode: bool = field(default=False)
    enable_dp_prefill_balance: bool = field(default=False)
    enable_decode_microbatch_overlap: bool = field(default=False)
    enable_prefill_microbatch_overlap: bool = field(default=False)
    cache_capacity: int = field(default=200)
    embed_cache_storage_size: float = field(default=4)
    data_type: Optional[str] = field(
        default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]}
    )
    return_all_prompt_logprobs: bool = field(default=False)
    use_reward_model: bool = field(default=False)
    long_truncation_mode: Optional[str] = field(default=None, metadata={"choices": [None, "head", "center"]})
    use_tgi_api: bool = field(default=False)
    health_monitor: bool = field(default=False)
    metric_gateway: Optional[str] = field(default=None)
```

7. api_start.py signal handlers - Adds graceful shutdown signal handling (SIGINT/SIGTERM in the excerpt below), which is infrastructure work unrelated to model support.

```python
import os
import sys
import time
import uuid
import subprocess
import signal
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
from lightllm.utils.start_utils import process_manager, kill_recursive
from .metrics.manager import start_metric_manager
from .embed_cache.manager import start_cache_manager
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import set_env_start_args, set_unique_server_name, get_unique_server_name
from lightllm.utils.envs_utils import get_lightllm_gunicorn_time_out_seconds, get_lightllm_gunicorn_keep_alive
from .detokenization.manager import start_detokenization_process
from .router.manager import start_router_process
from lightllm.utils.process_check import is_process_active
from lightllm.utils.multinode_utils import send_and_receive_node_ip
from lightllm.utils.shm_size_check import check_recommended_shm_size
from lightllm.server.core.objs.start_args_type import StartArgs

logger = init_logger(__name__)


def setup_signal_handlers(http_server_process, process_manager):
    def signal_handler(sig, frame):
        if sig == signal.SIGINT:
            logger.info("Received SIGINT (Ctrl+C), forcing immediate exit...")
            if http_server_process:
                kill_recursive(http_server_process)
            process_manager.terminate_all_processes()
            logger.info("All processes have been forcefully terminated.")
            sys.exit(0)
        elif sig == signal.SIGTERM:
            logger.info("Received SIGTERM, shutting down gracefully...")
            if http_server_process and http_server_process.poll() is None:
                http_server_process.send_signal(signal.SIGTERM)
                start_time = time.time()
                while (time.time() - start_time) < 60:
                    if not is_process_active(http_server_process.pid):
                        logger.info("httpserver exit")
                        break
                    time.sleep(1)
                if time.time() - start_time < 60:
                    logger.info("HTTP server has exited gracefully")
                else:
                    logger.warning("HTTP server did not exit in time, killing it...")
                    kill_recursive(http_server_process)
                # (excerpt truncated)
```

Consider splitting these into separate PRs for cleaner review and easier rollback if needed.

Generated with Claude Code


yeahdongcn and others added 5 commits January 21, 2026 04:27
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
Add autotune kernel configurations for NVIDIA H200:
- FLA chunk kernels (chunk_fwd_o, chunk_gated_delta_rule_fwd_h)
- Cumsum and dot product kernels
- Fused GDN gating and gated RMSNorm kernels
- MoE grouped matmul and alignment kernels
- SiLU activation kernels

Configs provided for both triton 3.4.0 and 3.5.1

Add support for Qwen3Next architecture including:
- New model implementation with GDN (Gated Delta Network) attention
- Mamba cache memory manager for hybrid architecture
- FLA (Flash Linear Attention) triton kernels
- Custom triton kernels for causal conv1d, gated RMSNorm, fused gating
- MTP (Multi-Token Prediction) variant support
- Allocator utilities and parameter weight management
- Hybrid radix cache for dynamic prompt handling
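
As a point of reference for the causal conv1d kernel listed above, a plain PyTorch equivalent of the operation (depthwise layout and shapes are assumptions; the Triton kernel itself is not shown here):

```python
import torch
import torch.nn.functional as F

def causal_conv1d_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Reference causal depthwise conv1d.

    x:      (batch, channels, seq_len)
    weight: (channels, kernel_size) depthwise filters
    Left-pads the sequence so position t only sees inputs at positions <= t.
    """
    channels, kernel_size = weight.shape
    x = F.pad(x, (kernel_size - 1, 0))       # pad on the left only
    w = weight.unsqueeze(1)                  # (channels, 1, kernel_size)
    return F.conv1d(x, w, groups=channels)   # depthwise convolution

# Example: batch=2, channels=4, seq_len=8, kernel_size=4
y = causal_conv1d_ref(torch.randn(2, 4, 8), torch.randn(4, 4))
```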
