Skip to content

Commit 197a0f7

Browse files
authored
[BugFix] fix VL fp8 bug when moe token_num is 0 (#4929)
* [BugFix] fix VL fp8 bug when moe token_num is 0
* fix bug
1 parent 0a6981f commit 197a0f7

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def apply(
248248
"""
249249
Triton compute Fused MoE.
250250
"""
251+
x = paddle.concat([x, paddle.ones([1, layer.hidden_size], dtype=x.dtype)])
251252
gate_out = gate(x.cast("float32"))
252253
token_num = x.shape[0]
253254
top_k = layer.top_k
@@ -395,6 +396,7 @@ def apply(
395396
if layer.reduce_results and layer.tp_size > 1:
396397
out = tensor_model_parallel_all_reduce(out)
397398

399+
out = out[:-1]
398400
return out
399401

400402

@@ -601,6 +603,7 @@ def apply(
601603
"""
602604
Triton compute Fused MoE.
603605
"""
606+
x = paddle.concat([x, paddle.ones([1, layer.hidden_size], dtype=x.dtype)])
604607
gate_out = gate(x.cast("float32"))
605608
token_num = x.shape[0]
606609
top_k = layer.top_k
@@ -769,6 +772,7 @@ def apply(
769772
if layer.reduce_results and layer.tp_size > 1:
770773
out = tensor_model_parallel_all_reduce(out)
771774

775+
out = out[:-1]
772776
return out
773777

774778

@@ -891,6 +895,7 @@ def apply(
891895
"""
892896
Triton compute Fused MoE.
893897
"""
898+
x = paddle.concat([x, paddle.ones([1, layer.hidden_size], dtype=x.dtype)])
894899
gate_out = gate(x.cast("float32"))
895900
token_num = x.shape[0]
896901
top_k = layer.top_k
@@ -1058,6 +1063,7 @@ def apply(
10581063
if layer.tp_size > 1:
10591064
out = tensor_model_parallel_all_reduce(out)
10601065

1066+
out = out[:-1]
10611067
return out
10621068

10631069

@@ -1315,6 +1321,7 @@ def apply(
13151321
"""
13161322
Triton compute Fused MoE.
13171323
"""
1324+
x = paddle.concat([x, paddle.ones([1, layer.hidden_size], dtype=x.dtype)])
13181325
gate_out = gate(x.cast("float32"))
13191326
token_num = x.shape[0]
13201327
top_k = layer.top_k
@@ -1462,4 +1469,5 @@ def apply(
14621469
if layer.tp_size > 1:
14631470
out = tensor_model_parallel_all_reduce(out)
14641471

1472+
out = out[:-1]
14651473
return out

0 commit comments

Comments (0)