
Commit 31b8161

progress towards relation extraction

1 parent da1e9ea

File tree

4 files changed: +138 −42 lines

pyproject.toml

Lines changed: 3 additions & 2 deletions

@@ -22,7 +22,8 @@ dependencies = [
     "grimbert>=0.1.5",
     "datasets>=4.0.0",
     "rank-bm25>=0.2.2",
-    "accelerate>=1.10.1"
+    "accelerate>=1.10.1",
+    "scikit-learn[ui]>=1.6.1",
 ]

 [build-system]
@@ -84,4 +85,4 @@ explicit = true
 [[tool.uv.index]]
 name = "pytorch-rocm63"
 url = "https://download.pytorch.org/whl/rocm6.3"
-explicit = true
+explicit = true

renard/pipeline/relation_extraction.py

Lines changed: 103 additions & 28 deletions

@@ -4,20 +4,24 @@
 from datasets import load_dataset, Dataset as HGDataset
 import torch
 from transformers import (
-    AutoModelForSeq2SeqLM,
-    T5ForConditionalGeneration,
+    AutoModelForCausalLM,
     AutoTokenizer,
-    Seq2SeqTrainer,
-    Seq2SeqTrainingArguments,
+    Trainer,
+    TrainingArguments,
     PreTrainedTokenizerFast,
+    DataCollatorForLanguageModeling,
     DataCollatorForSeq2Seq,
     PreTrainedModel,
+    EvalPrediction,
     pipeline as hg_pipeline,
 )
+from more_itertools import flatten
 from transformers.pipelines.pt_utils import KeyDataset
 from renard.pipeline.core import PipelineStep
 from renard.pipeline.progress import ProgressReporter
 from renard.pipeline.character_unification import Character
+from renard.utils import make_vocab
+from sklearn.metrics import precision_recall_fscore_support

 #: (subject, relation, object)
 Relation = tuple[Character, str, Character]
@@ -30,13 +34,13 @@ def format_rel(rel: dict) -> str:
         return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])

     labels = " ".join(map(format_rel, example["relations"]))
-    with tokenizer.as_target_tokenizer():
-        labels_batch = tokenizer(labels)
-    example["labels"] = labels_batch["input_ids"]
+    answer = {"role": "assistant", "content": labels}
+    example["labels"] = tokenizer.apply_chat_template([answer])

     text = example["chunk"] or ""
-    text = f"extract relations: {text}"
-    example["input_ids"] = tokenizer(text)["input_ids"]
+    example["input_ids"] = tokenizer.apply_chat_template(
+        GenerativeRelationExtractor.task_prompt(text)
+    )

     return example

@@ -52,31 +56,85 @@ def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HGDataset:
         "synthetic_relations_in_fiction_books",
         split="train",
     )
-    dataset = dataset.train_test_split(test_size=0.1)
+    dataset = dataset.train_test_split(test_size=0.001)
     return dataset.map(ft.partial(_load_ARF_line, tokenizer=tokenizer))


-def train_t5_on_ARF(
-    t5_hg_id: str, targs: Seq2SeqTrainingArguments
-) -> T5ForConditionalGeneration:
-    tokenizer = AutoTokenizer.from_pretrained(t5_hg_id)
-    model = AutoModelForSeq2SeqLM.from_pretrained(t5_hg_id)
+def _triple_precision_recall_f1(
+    references: list[list[tuple[str, str, str]]],
+    predictions: list[list[tuple[str, str, str]]],
+) -> dict[str, float]:
+    triple_vocab = make_vocab(list(flatten(references)) + list(flatten(predictions)))
+
+    # the "null triple" indicates no prediction (or no reference
+    # available), useful to compute precision/recall.
+    null_triple_index = max(triple_vocab.values()) + 1
+
+    y, y_hat = [], []
+    for ref, pred in zip(references, predictions):
+        ref = {triple: triple_vocab[triple] for triple in ref}
+        pred = {triple: triple_vocab[triple] for triple in pred}
+        for ref_triple, ref_index in ref.items():
+            y.append(ref_index)
+            y_hat.append(pred.get(ref_triple, null_triple_index))
+            try:
+                del pred[ref_triple]
+            except KeyError:
+                pass
+        for pred_triple, pred_index in pred.items():
+            y_hat.append(pred_index)
+            y.append(ref.get(pred_triple, null_triple_index))
+
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        y, y_hat, labels=list(triple_vocab.values()), average="micro"
+    )
+
+    return {"precision": float(precision), "recall": float(recall), "f1": float(f1)}


+def train_model_on_ARF(
+    model: str | PreTrainedModel,
+    targs: TrainingArguments,
+    tokenizer: PreTrainedTokenizerFast | None = None,
+) -> PreTrainedModel:
+    if isinstance(model, str):
+        assert tokenizer is None
+        tokenizer = AutoTokenizer.from_pretrained(model)
+        model = AutoModelForCausalLM.from_pretrained(model)
+    assert not tokenizer is None
+    tokenizer.pad_token = tokenizer.eos_token
+    pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]

     dataset = load_ARF_dataset(tokenizer)

-    trainer = Seq2SeqTrainer(
+    def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
+        eval_preds.label_ids[eval_preds.label_ids == -100] = pad_token_i
+
+        labels_str = tokenizer.batch_decode(
+            eval_preds.label_ids, skip_special_tokens=True
+        )
+        labels = list(map(GenerativeRelationExtractor.parse_text_relations, labels_str))
+
+        pred_ids = eval_preds.predictions[0].argmax(axis=-1)
+        preds_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+        preds = list(map(GenerativeRelationExtractor.parse_text_relations, preds_str))
+
+        return _triple_precision_recall_f1(labels, preds)
+
+    trainer = Trainer(
         model,
         targs,
         train_dataset=dataset["train"],
         eval_dataset=dataset["test"],
-        data_collator=DataCollatorForSeq2Seq(tokenizer),
+        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+        compute_metrics=compute_metrics,
     )
     trainer.train()

     return model


-class T5RelationExtractor(PipelineStep):
+class GenerativeRelationExtractor(PipelineStep):
     DEFAULT_MODEL = "compnet-renard/t5-small-literary-relation-extraction"

     def __init__(
@@ -85,7 +143,9 @@ def __init__(
         batch_size: int = 1,
         device: Literal["cpu", "cuda", "auto"] = "auto",
     ):
-        self.model = T5RelationExtractor.DEFAULT_MODEL if model is None else model
+        self.model = (
+            GenerativeRelationExtractor.DEFAULT_MODEL if model is None else model
+        )
         self.hg_pipeline = None
         self.batch_size = batch_size
         if device == "auto":
@@ -96,7 +156,7 @@ def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
     def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
         super()._pipeline_init_(lang, progress_reporter, **kwargs)
         self.hg_pipeline = hg_pipeline(
-            "text2text-generation", model=self.model, device=self.device
+            "text-generation", model=self.model, device=self.device
         )

     def __call__(
@@ -108,19 +168,32 @@ def __call__(

         # chunk as in the ARF dataset
         dataset = HGDataset.from_list(
-            [{"text": T5RelationExtractor.task_prompt(sent)} for sent in sentences]
+            [
+                {
+                    "text": self.hg_pipeline.tokenizer.apply_chat_template(
+                        (GenerativeRelationExtractor.task_prompt(" ".join(sent)))
+                    )
+                }
+                for sent in sentences
+            ]
         )
         for out in self._progress_(
             self.hg_pipeline(KeyDataset(dataset, "text"), batch_size=self.batch_size),
             total=len(dataset),
         ):
             text_relations = out[0]["generated_text"]

-            raw_triples = T5RelationExtractor.parse_t5_text_relations(text_relations)
+            raw_triples = GenerativeRelationExtractor.parse_text_relations(
+                text_relations
+            )
             triples = []
             for subj, rel, obj in raw_triples:
-                subj_char = T5RelationExtractor.identify_character(subj, characters)
-                obj_char = T5RelationExtractor.identify_character(obj, characters)
+                subj_char = GenerativeRelationExtractor.identify_character(
+                    subj, characters
+                )
+                obj_char = GenerativeRelationExtractor.identify_character(
+                    obj, characters
+                )
                 if subj_char is None or obj_char is None or subj_char == obj_char:
                     continue
                 triples.append((subj_char, rel, obj_char))
@@ -129,12 +202,14 @@ def __call__(
         return {"sentence_relations": sentence_relations}

     @staticmethod
-    def task_prompt(sentence: list[str]) -> str:
-        sent_text = " ".join(sentence)
-        return f"extract relations: {sent_text}"
+    def task_prompt(text: str) -> list[dict]:
+        return [
+            {"role": "system", "content": "Extract relations from the given text."},
+            {"role": "user", "content": text},
+        ]

     @staticmethod
-    def parse_t5_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
+    def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
         return re.findall(r"\(([^,]+), ([^,]+), ([^,]+)\)", text_relations)

     @staticmethod
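
A note on the prompting change above: the T5-style "extract relations: ..." prefix is replaced by a chat-message prompt, which task_prompt now returns and which both _load_ARF_line and __call__ feed to apply_chat_template. A minimal sketch of the resulting encoding, assuming a tokenizer that ships a chat template (the model id below is illustrative, not the one used by this commit):

from transformers import AutoTokenizer
from renard.pipeline.relation_extraction import GenerativeRelationExtractor

# illustrative chat model id; any tokenizer with a chat template works
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
messages = GenerativeRelationExtractor.task_prompt("John is the father of Mary.")
input_ids = tokenizer.apply_chat_template(messages)  # token ids, as in _load_ARF_line
print(tokenizer.decode(input_ids))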

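The null-triple device in _triple_precision_recall_f1 pads the unmatched side of each reference/prediction pair, so micro-averaged precision and recall penalize both missed references and spurious predictions. A toy check of the expected behavior, assuming the module is importable after this commit (character names are illustrative):

from renard.pipeline.relation_extraction import _triple_precision_recall_f1

refs = [[("Elizabeth", "sister-of", "Jane"), ("Darcy", "friend-of", "Bingley")]]
preds = [[("Elizabeth", "sister-of", "Jane"), ("Darcy", "enemy-of", "Wickham")]]

# one exact match, one missed reference, one spurious prediction:
# micro precision = recall = f1 = 0.5
print(_triple_precision_recall_f1(refs, preds))
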
renard/utils.py

Lines changed: 18 additions & 0 deletions

@@ -132,3 +132,21 @@ def charbb2tokenbb(char_bb: BlockBounds, char2token: List[int]) -> BlockBounds:
     for char_block_start, char_block_end in char_bb[0]:
         tokens_blocks.append((char2token[char_block_start], char2token[char_block_end]))
     return (tokens_blocks, "tokens")
+
+
+def make_vocab(elements: list[T]) -> dict[T, int]:
+    """Create a vocabulary from a list of elements.
+
+    :return: a dictionary mapping each element to its index in the
+        vocabulary
+    """
+    vocab = {}
+    next_index = 0
+
+    for elt in elements:
+        vocab[elt] = vocab.get(elt, next_index)
+        elt_was_just_added = vocab[elt] == next_index
+        if elt_was_just_added:
+            next_index += 1
+
+    return vocab
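
The helper assigns indices by order of first occurrence and leaves duplicates at their original index; a quick usage sketch:

from renard.utils import make_vocab

vocab = make_vocab(["a", "b", "a", "c"])
assert vocab == {"a": 0, "b": 1, "c": 2}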

uv.lock

Lines changed: 14 additions & 12 deletions
Some generated files are not rendered by default.
