@@ -76,6 +76,14 @@ class TokenClassificationInput(BaseModel):
7676 "a token_classification task"
7777 )
7878 )
79+ is_split_into_words : bool = Field (
80+ default = False ,
81+ description = (
82+ "True if the input is a batch size 1 list of strings representing. "
83+ "individual word tokens. Currently only supports batch size 1. "
84+ "Default is False"
85+ ),
86+ )
7987
8088
8189class TokenClassificationResult (BaseModel ):
@@ -245,13 +253,17 @@ def process_inputs(
245253 and dictionary containing offset mappings and special tokens mask to
246254 be used during postprocessing
247255 """
256+ if inputs .is_split_into_words and self .engine .batch_size != 1 :
257+ raise ValueError ("is_split_into_words=True only supported for batch size 1" )
258+
248259 tokens = self .tokenizer (
249260 inputs .inputs ,
250261 return_tensors = "np" ,
251262 truncation = TruncationStrategy .LONGEST_FIRST .value ,
252263 padding = PaddingStrategy .MAX_LENGTH .value ,
253264 return_special_tokens_mask = True ,
254265 return_offsets_mapping = self .tokenizer .is_fast ,
266+ is_split_into_words = inputs .is_split_into_words ,
255267 )
256268
257269 offset_mapping = (
@@ -260,11 +272,29 @@ def process_inputs(
260272 else [None ] * len (inputs .inputs )
261273 )
262274 special_tokens_mask = tokens .pop ("special_tokens_mask" )
275+
276+ word_start_mask = None
277+ if inputs .is_split_into_words :
278+ # create mask for word in the split words where values are True
279+ # if they are the start of a tokenized word
280+ word_start_mask = []
281+ word_ids = tokens .word_ids (batch_index = 0 )
282+ previous_id = None
283+ for word_id in word_ids :
284+ if word_id is None :
285+ continue
286+ if word_id != previous_id :
287+ word_start_mask .append (True )
288+ previous_id = word_id
289+ else :
290+ word_start_mask .append (False )
291+
263292 postprocessing_kwargs = dict (
264293 inputs = inputs ,
265294 tokens = tokens ,
266295 offset_mapping = offset_mapping ,
267296 special_tokens_mask = special_tokens_mask ,
297+ word_start_mask = word_start_mask ,
268298 )
269299
270300 return self .tokens_to_engine_input (tokens ), postprocessing_kwargs
@@ -284,6 +314,7 @@ def process_engine_outputs(
284314 tokens = kwargs ["tokens" ]
285315 offset_mapping = kwargs ["offset_mapping" ]
286316 special_tokens_mask = kwargs ["special_tokens_mask" ]
317+ word_start_mask = kwargs ["word_start_mask" ]
287318
288319 predictions = [] # type: List[List[TokenClassificationResult]]
289320
@@ -293,6 +324,7 @@ def process_engine_outputs(
293324 scores = numpy .exp (current_entities ) / numpy .exp (current_entities ).sum (
294325 - 1 , keepdims = True
295326 )
327+
296328 pre_entities = self ._gather_pre_entities (
297329 inputs .inputs [entities_index ],
298330 input_ids ,
@@ -303,9 +335,11 @@ def process_engine_outputs(
303335 grouped_entities = self ._aggregate (pre_entities )
304336 # Filter anything that is in self.ignore_labels
305337 current_results = [] # type: List[TokenClassificationResult]
306- for entity in grouped_entities :
307- if entity .get ("entity" ) in self .ignore_labels or (
308- entity .get ("entity_group" ) in self .ignore_labels
338+ for entity_idx , entity in enumerate (grouped_entities ):
339+ if (
340+ entity .get ("entity" ) in self .ignore_labels
341+ or (entity .get ("entity_group" ) in self .ignore_labels )
342+ or (word_start_mask and not word_start_mask [entity_idx ])
309343 ):
310344 continue
311345 if entity .get ("entity_group" ):
0 commit comments