@@ -637,6 +637,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
637637 batch_pooling_params = []
638638 for i in range (req_len ):
639639 request = req_dicts [i ]
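640+ # NOTE: request is expected to be a Request instance (req_dicts is annotated List[Request]); the isinstance assert is intentionally left disabled.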
640641 idx = request .idx
641642
642643 if hasattr (request , "pooling_params" ) and request .pooling_params is not None :
@@ -655,14 +656,14 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
655656 logits_info , schemata_key = self ._init_logits_processor (request )
656657 request .schemata_key = schemata_key
657658
658- if self . scheduler_config . splitwise_role == "decode" :
659- if (
660- hasattr (request , "prefill_end_index" )
661- and hasattr (request , "prompt_token_ids" )
662- and request .prefill_end_index > len (request .prompt_token_ids )
663- ):
664- if hasattr ( request , "output_token_ids" ):
665- prefill_tokens .extend (request .output_token_ids )
659+ if (
660+ self . scheduler_config . splitwise_role == "decode"
661+ and hasattr (request , "prefill_end_index" )
662+ and hasattr (request , "prompt_token_ids" )
663+ and request .prefill_end_index > len (request .prompt_token_ids )
664+ and hasattr ( request , "output_token_ids" )
665+ ):
666+ prefill_tokens .extend (request .output_token_ids )
666667
667668 prefill_start_index = request .prefill_start_index
668669 prefill_end_index = request .prefill_end_index
@@ -784,12 +785,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
784785
785786 if request .get ("bad_words_token_ids" ) is not None and len (request .get ("bad_words_token_ids" )) > 0 :
786787 bad_words_len = len (request .get ("bad_words_token_ids" ))
787- self .share_inputs ["bad_tokens_len" ][idx : idx + 1 ] = bad_words_len
788+ self .share_inputs ["bad_tokens_len" ][idx ] = bad_words_len
788789 self .share_inputs ["bad_tokens" ][idx : idx + 1 , :bad_words_len ] = np .array (
789790 request .get ("bad_words_token_ids" ), dtype = "int64"
790791 )
791792 else :
792- self .share_inputs ["bad_tokens_len" ][idx : idx + 1 ] = 1
793+ self .share_inputs ["bad_tokens_len" ][idx ] = 1
793794 self .share_inputs ["bad_tokens" ][idx : idx + 1 , :] = np .array ([- 1 ], dtype = "int64" )
794795
795796 if request .get ("stop_token_ids" ) is not None and request .get ("stop_seqs_len" ) is not None :
@@ -1007,12 +1008,12 @@ def get_attr_from_request(request, attr, default_value=None):
10071008
10081009 if request .get ("bad_words_token_ids" ) is not None and len (request .get ("bad_words_token_ids" )) > 0 :
10091010 bad_words_len = len (request .get ("bad_words_token_ids" ))
1010- self .share_inputs ["bad_tokens_len" ][idx : idx + 1 ] = bad_words_len
1011+ self .share_inputs ["bad_tokens_len" ][idx ] = bad_words_len
10111012 self .share_inputs ["bad_tokens" ][idx : idx + 1 , :bad_words_len ] = np .array (
10121013 request .get ("bad_words_token_ids" ), dtype = "int64"
10131014 )
10141015 else :
1015- self .share_inputs ["bad_tokens_len" ][idx : idx + 1 ] = 1
1016+ self .share_inputs ["bad_tokens_len" ][idx ] = 1
10161017 self .share_inputs ["bad_tokens" ][idx : idx + 1 , :] = np .array ([- 1 ], dtype = "int64" )
10171018
10181019 if request .get ("stop_token_ids" ) is not None and request .get ("stop_seqs_len" ) is not None :
@@ -1217,7 +1218,7 @@ def _init_share_inputs(self, max_num_seqs: int):
12171218 self .share_inputs ["stop_nums" ] = paddle .full ([1 ], max_num_seqs , dtype = "int64" )
12181219
12191220 self .share_inputs ["bad_tokens" ] = paddle .full ([max_num_seqs , self .model_config .vocab_size ], - 1 , dtype = "int64" )
1220- self .share_inputs ["bad_tokens_len" ] = paddle . full ([ max_num_seqs ], 1 , dtype = "int64" )
1221+ self .share_inputs ["bad_tokens_len" ] = [ - 1 ] * max_num_seqs
12211222 self .share_inputs ["next_tokens" ] = paddle .full ([max_num_seqs , 1 ], - 1 , dtype = "int64" )
12221223 self .share_inputs ["is_block_step" ] = paddle .full ([max_num_seqs ], False , dtype = "bool" )
12231224 self .share_inputs ["is_chunk_step" ] = paddle .full ([max_num_seqs ], False , dtype = "bool" ).cpu ()
@@ -1447,7 +1448,7 @@ def _prepare_inputs(self, is_dummy_or_profile_run=False) -> None:
14471448 self .share_inputs ["output_padding_offset" ].copy_ (output_padding_offset , False )
14481449
14491450 # Compute the largest bad-token count across all request slots
1450- max_bad_tokens_len = np . max (self .share_inputs ["bad_tokens_len" ]. numpy () )
1451+ max_bad_tokens_len = max (self .share_inputs ["bad_tokens_len" ])
14511452
14521453 # Initialize forward meta data
14531454 self .initialize_forward_meta (is_dummy_or_profile_run = is_dummy_or_profile_run )
0 commit comments