Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit df7275f

Browse files
authored
transformers eval support for conll2003 (#504) (#506)
1 parent 13d966f commit df7275f

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

src/deepsparse/transformers/eval_downstream.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,59 @@ def sst2_eval(args):
225225
return sst2_metrics
226226

227227

def conll2003_eval(args):
    """Evaluate a DeepSparse token-classification pipeline on CoNLL-2003 NER.

    Loads the CoNLL-2003 validation split and the ``seqeval`` metric, runs the
    pipeline over each pre-tokenized sample (``is_split_into_words=True``), and
    accumulates predicted vs. gold NER tag ids batch by batch.

    :param args: parsed CLI namespace; reads ``onnx_filepath``, ``engine``,
        ``num_cores``, ``max_sequence_length``, and ``max_samples``
    :return: the ``seqeval`` metric object with all batches added; the caller
        is expected to invoke ``.compute()`` on it
    """
    # load conll2003 validation dataset and seqeval eval tool
    # (original comment said "qqp" -- copy-paste error, fixed)
    conll2003 = load_dataset("conll2003")["validation"]
    conll2003_metrics = load_metric("seqeval")

    # load pipeline
    token_classify = Pipeline.create(
        task="token-classification",
        model_path=args.onnx_filepath,
        engine_type=args.engine,
        num_cores=args.num_cores,
        sequence_length=args.max_sequence_length,
    )
    print(f"Engine info: {token_classify.engine}")

    # canonical CoNLL-2003 NER tag -> dataset label id
    ner_tag_map = {
        "O": 0,
        "B-PER": 1,
        "I-PER": 2,
        "B-ORG": 3,
        "I-ORG": 4,
        "B-LOC": 5,
        "I-LOC": 6,
        "B-MISC": 7,
        "I-MISC": 8,
    }
    # map both the raw dataset label ids and the pipeline's entity names
    # (via the model config's id2label) back to the canonical NER tag, so a
    # single lookup table handles predictions and references alike
    label_map = {label_id: ner_tag for ner_tag, label_id in ner_tag_map.items()}
    label_map.update(
        {
            token_classify.config.id2label[label_id]: tag
            for tag, label_id in ner_tag_map.items()
        }
    )

    for idx, sample in _enumerate_progress(conll2003, args.max_samples):
        if not sample["tokens"]:
            continue  # invalid dataset item, no tokens
        # batch size 1: feed the pre-split word list straight to the pipeline
        pred = token_classify(inputs=sample["tokens"], is_split_into_words=True)
        pred_ids = [label_map[prediction.entity] for prediction in pred.predictions[0]]
        label_ids = [label_map[ner_tag] for ner_tag in sample["ner_tags"]]

        conll2003_metrics.add_batch(
            predictions=[pred_ids],
            references=[label_ids],
        )

        # NOTE(review): breaks AFTER processing the sample at idx ==
        # max_samples, i.e. max_samples + 1 samples total -- this matches the
        # convention of the sibling *_eval functions, so it is kept as-is
        if args.max_samples and idx >= args.max_samples:
            break

    return conll2003_metrics
228281
def _enumerate_progress(dataset, max_steps):
229282
progress_bar = tqdm(dataset, total=max_steps) if max_steps else tqdm(dataset)
230283
return enumerate(progress_bar)
@@ -242,6 +295,7 @@ def _get_label2id(config_file_path):
242295
"mnli": mnli_eval,
243296
"qqp": qqp_eval,
244297
"sst2": sst2_eval,
298+
"conll2003": conll2003_eval,
245299
}
246300

247301

src/deepsparse/transformers/pipelines/token_classification.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ class TokenClassificationInput(BaseModel):
7676
"a token_classification task"
7777
)
7878
)
79+
is_split_into_words: bool = Field(
80+
default=False,
81+
description=(
82+
"True if the input is a batch size 1 list of strings representing. "
83+
"individual word tokens. Currently only supports batch size 1. "
84+
"Default is False"
85+
),
86+
)
7987

8088

8189
class TokenClassificationResult(BaseModel):
@@ -245,13 +253,17 @@ def process_inputs(
245253
and dictionary containing offset mappings and special tokens mask to
246254
be used during postprocessing
247255
"""
256+
if inputs.is_split_into_words and self.engine.batch_size != 1:
257+
raise ValueError("is_split_into_words=True only supported for batch size 1")
258+
248259
tokens = self.tokenizer(
249260
inputs.inputs,
250261
return_tensors="np",
251262
truncation=TruncationStrategy.LONGEST_FIRST.value,
252263
padding=PaddingStrategy.MAX_LENGTH.value,
253264
return_special_tokens_mask=True,
254265
return_offsets_mapping=self.tokenizer.is_fast,
266+
is_split_into_words=inputs.is_split_into_words,
255267
)
256268

257269
offset_mapping = (
@@ -260,11 +272,29 @@ def process_inputs(
260272
else [None] * len(inputs.inputs)
261273
)
262274
special_tokens_mask = tokens.pop("special_tokens_mask")
275+
276+
word_start_mask = None
277+
if inputs.is_split_into_words:
278+
# create mask for word in the split words where values are True
279+
# if they are the start of a tokenized word
280+
word_start_mask = []
281+
word_ids = tokens.word_ids(batch_index=0)
282+
previous_id = None
283+
for word_id in word_ids:
284+
if word_id is None:
285+
continue
286+
if word_id != previous_id:
287+
word_start_mask.append(True)
288+
previous_id = word_id
289+
else:
290+
word_start_mask.append(False)
291+
263292
postprocessing_kwargs = dict(
264293
inputs=inputs,
265294
tokens=tokens,
266295
offset_mapping=offset_mapping,
267296
special_tokens_mask=special_tokens_mask,
297+
word_start_mask=word_start_mask,
268298
)
269299

270300
return self.tokens_to_engine_input(tokens), postprocessing_kwargs
@@ -284,6 +314,7 @@ def process_engine_outputs(
284314
tokens = kwargs["tokens"]
285315
offset_mapping = kwargs["offset_mapping"]
286316
special_tokens_mask = kwargs["special_tokens_mask"]
317+
word_start_mask = kwargs["word_start_mask"]
287318

288319
predictions = [] # type: List[List[TokenClassificationResult]]
289320

@@ -293,6 +324,7 @@ def process_engine_outputs(
293324
scores = numpy.exp(current_entities) / numpy.exp(current_entities).sum(
294325
-1, keepdims=True
295326
)
327+
296328
pre_entities = self._gather_pre_entities(
297329
inputs.inputs[entities_index],
298330
input_ids,
@@ -303,9 +335,11 @@ def process_engine_outputs(
303335
grouped_entities = self._aggregate(pre_entities)
304336
# Filter anything that is in self.ignore_labels
305337
current_results = [] # type: List[TokenClassificationResult]
306-
for entity in grouped_entities:
307-
if entity.get("entity") in self.ignore_labels or (
308-
entity.get("entity_group") in self.ignore_labels
338+
for entity_idx, entity in enumerate(grouped_entities):
339+
if (
340+
entity.get("entity") in self.ignore_labels
341+
or (entity.get("entity_group") in self.ignore_labels)
342+
or (word_start_mask and not word_start_mask[entity_idx])
309343
):
310344
continue
311345
if entity.get("entity_group"):

0 commit comments

Comments
 (0)