@@ -248,6 +248,7 @@ def apply(
248248 """
249249 Triton compute Fused MoE.
250250 """
251+ x = paddle .concat ([x , paddle .ones ([1 , layer .hidden_size ], dtype = x .dtype )])
251252 gate_out = gate (x .cast ("float32" ))
252253 token_num = x .shape [0 ]
253254 top_k = layer .top_k
@@ -395,6 +396,7 @@ def apply(
395396 if layer .reduce_results and layer .tp_size > 1 :
396397 out = tensor_model_parallel_all_reduce (out )
397398
399+ out = out [:- 1 ]
398400 return out
399401
400402
@@ -601,6 +603,7 @@ def apply(
601603 """
602604 Triton compute Fused MoE.
603605 """
606+ x = paddle .concat ([x , paddle .ones ([1 , layer .hidden_size ], dtype = x .dtype )])
604607 gate_out = gate (x .cast ("float32" ))
605608 token_num = x .shape [0 ]
606609 top_k = layer .top_k
@@ -769,6 +772,7 @@ def apply(
769772 if layer .reduce_results and layer .tp_size > 1 :
770773 out = tensor_model_parallel_all_reduce (out )
771774
775+ out = out [:- 1 ]
772776 return out
773777
774778
@@ -891,6 +895,7 @@ def apply(
891895 """
892896 Triton compute Fused MoE.
893897 """
898+ x = paddle .concat ([x , paddle .ones ([1 , layer .hidden_size ], dtype = x .dtype )])
894899 gate_out = gate (x .cast ("float32" ))
895900 token_num = x .shape [0 ]
896901 top_k = layer .top_k
@@ -1058,6 +1063,7 @@ def apply(
10581063 if layer .tp_size > 1 :
10591064 out = tensor_model_parallel_all_reduce (out )
10601065
1066+ out = out [:- 1 ]
10611067 return out
10621068
10631069
@@ -1315,6 +1321,7 @@ def apply(
13151321 """
13161322 Triton compute Fused MoE.
13171323 """
1324+ x = paddle .concat ([x , paddle .ones ([1 , layer .hidden_size ], dtype = x .dtype )])
13181325 gate_out = gate (x .cast ("float32" ))
13191326 token_num = x .shape [0 ]
13201327 top_k = layer .top_k
@@ -1462,4 +1469,5 @@ def apply(
14621469 if layer .tp_size > 1 :
14631470 out = tensor_model_parallel_all_reduce (out )
14641471
1472+ out = out [:- 1 ]
14651473 return out
0 commit comments