44from datasets import load_dataset , Dataset as HGDataset
55import torch
66from transformers import (
7- AutoModelForSeq2SeqLM ,
8- T5ForConditionalGeneration ,
7+ AutoModelForCausalLM ,
98 AutoTokenizer ,
10- Seq2SeqTrainer ,
11- Seq2SeqTrainingArguments ,
9+ Trainer ,
10+ TrainingArguments ,
1211 PreTrainedTokenizerFast ,
12+ DataCollatorForLanguageModeling ,
1313 DataCollatorForSeq2Seq ,
1414 PreTrainedModel ,
15+ EvalPrediction ,
1516 pipeline as hg_pipeline ,
1617)
18+ from more_itertools import flatten
1719from transformers .pipelines .pt_utils import KeyDataset
1820from renard .pipeline .core import PipelineStep
1921from renard .pipeline .progress import ProgressReporter
2022from renard .pipeline .character_unification import Character
23+ from renard .utils import make_vocab
24+ from sklearn .metrics import precision_recall_fscore_support
2125
2226#: (subject, relation, object)
2327Relation = tuple [Character , str , Character ]
@@ -30,13 +34,13 @@ def format_rel(rel: dict) -> str:
3034 return "({}, {}, {})" .format (rel ["entity1" ], rel ["relation" ], rel ["entity2" ])
3135
3236 labels = " " .join (map (format_rel , example ["relations" ]))
33- with tokenizer .as_target_tokenizer ():
34- labels_batch = tokenizer (labels )
35- example ["labels" ] = labels_batch ["input_ids" ]
37+ answer = {"role" : "assistant" , "content" : labels }
38+ example ["labels" ] = tokenizer .apply_chat_template ([answer ])
3639
3740 text = example ["chunk" ] or ""
38- text = f"extract relations: { text } "
39- example ["input_ids" ] = tokenizer (text )["input_ids" ]
41+ example ["input_ids" ] = tokenizer .apply_chat_template (
42+ GenerativeRelationExtractor .task_prompt (text )
43+ )
4044
4145 return example
4246
@@ -52,31 +56,85 @@ def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HGDataset:
5256 "synthetic_relations_in_fiction_books" ,
5357 split = "train" ,
5458 )
55- dataset = dataset .train_test_split (test_size = 0.1 )
59+ dataset = dataset .train_test_split (test_size = 0.001 )
5660 return dataset .map (ft .partial (_load_ARF_line , tokenizer = tokenizer ))
5761
5862
59- def train_t5_on_ARF (
60- t5_hg_id : str , targs : Seq2SeqTrainingArguments
61- ) -> T5ForConditionalGeneration :
62- tokenizer = AutoTokenizer .from_pretrained (t5_hg_id )
63- model = AutoModelForSeq2SeqLM .from_pretrained (t5_hg_id )
def _triple_precision_recall_f1(
    references: list[list[tuple[str, str, str]]],
    predictions: list[list[tuple[str, str, str]]],
) -> dict[str, float]:
    """Compute micro-averaged precision, recall and F1 between reference
    and predicted relation triples.

    Each example contributes one ``(y, y_hat)`` pair per reference triple
    and one per spurious predicted triple.  A "null triple" index stands
    in for "no prediction" / "no reference"; it is kept out of the scored
    label set so that missed references lower recall and spurious
    predictions lower precision without the null class itself counting.

    :param references: per-example lists of gold (subject, relation,
        object) triples.
    :param predictions: per-example lists of predicted triples, aligned
        with ``references``.
    :return: ``{"precision": ..., "recall": ..., "f1": ...}``
    """
    triple_vocab = make_vocab(list(flatten(references)) + list(flatten(predictions)))

    # the "null triple" indicates no prediction (or no reference
    # available), useful to compute precision/recall.  It is outside the
    # vocabulary, and excluded from `labels` below.
    null_triple_index = max(triple_vocab.values()) + 1

    y, y_hat = [], []
    for ref, pred in zip(references, predictions):
        ref_indices = {triple: triple_vocab[triple] for triple in ref}
        pred_indices = {triple: triple_vocab[triple] for triple in pred}
        # each reference triple is either matched by a prediction (which
        # is consumed by pop) or counted as missed
        for ref_triple, ref_index in ref_indices.items():
            y.append(ref_index)
            y_hat.append(pred_indices.pop(ref_triple, null_triple_index))
        # every prediction left at this point has no matching reference
        # (matched ones were popped above), so it is spurious
        for pred_index in pred_indices.values():
            y_hat.append(pred_index)
            y.append(null_triple_index)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y, y_hat, labels=list(triple_vocab.values()), average="micro"
    )

    return {"precision": float(precision), "recall": float(recall), "f1": float(f1)}
93+
94+
def train_model_on_ARF(
    model: str | PreTrainedModel,
    targs: TrainingArguments,
    tokenizer: PreTrainedTokenizerFast | None = None,
) -> PreTrainedModel:
    """Fine-tune a causal language model on the ARF relation extraction
    dataset.

    :param model: either a HuggingFace model id (in which case the
        tokenizer is loaded from the same id and ``tokenizer`` must be
        ``None``) or an already-instantiated model (in which case the
        matching ``tokenizer`` must be supplied).
    :param targs: HuggingFace training arguments.
    :param tokenizer: tokenizer matching ``model``, required when
        ``model`` is a ``PreTrainedModel``.
    :return: the trained model.
    """
    if isinstance(model, str):
        assert tokenizer is None
        tokenizer = AutoTokenizer.from_pretrained(model)
        model = AutoModelForCausalLM.from_pretrained(model)
    assert tokenizer is not None
    # causal LMs often ship without a pad token: reuse EOS for padding
    tokenizer.pad_token = tokenizer.eos_token
    pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]

    dataset = load_ARF_dataset(tokenizer)

    def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
        # -100 marks positions ignored by the loss: map them back to the
        # pad token so batch_decode can process them
        eval_preds.label_ids[eval_preds.label_ids == -100] = pad_token_i

        labels_str = tokenizer.batch_decode(
            eval_preds.label_ids, skip_special_tokens=True
        )
        labels = list(map(GenerativeRelationExtractor.parse_text_relations, labels_str))

        # greedy decoding of the logits
        pred_ids = eval_preds.predictions[0].argmax(axis=-1)
        preds_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        preds = list(map(GenerativeRelationExtractor.parse_text_relations, preds_str))

        return _triple_precision_recall_f1(labels, preds)

    trainer = Trainer(
        model,
        targs,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        compute_metrics=compute_metrics,
    )
    trainer.train()

    return model
77135
78136
79- class T5RelationExtractor (PipelineStep ):
137+ class GenerativeRelationExtractor (PipelineStep ):
80138 DEFAULT_MODEL = "compnet-renard/t5-small-literary-relation-extraction"
81139
82140 def __init__ (
@@ -85,7 +143,9 @@ def __init__(
85143 batch_size : int = 1 ,
86144 device : Literal ["cpu" , "cuda" , "auto" ] = "auto" ,
87145 ):
88- self .model = T5RelationExtractor .DEFAULT_MODEL if model is None else model
146+ self .model = (
147+ GenerativeRelationExtractor .DEFAULT_MODEL if model is None else model
148+ )
89149 self .hg_pipeline = None
90150 self .batch_size = batch_size
91151 if device == "auto" :
@@ -96,7 +156,7 @@ def __init__(
96156 def _pipeline_init_ (self , lang : str , progress_reporter : ProgressReporter , ** kwargs ):
97157 super ()._pipeline_init_ (lang , progress_reporter , ** kwargs )
98158 self .hg_pipeline = hg_pipeline (
99- "text2text -generation" , model = self .model , device = self .device
159+ "text -generation" , model = self .model , device = self .device
100160 )
101161
102162 def __call__ (
@@ -108,19 +168,32 @@ def __call__(
108168
109169 # chunk as in the ARF dataset
110170 dataset = HGDataset .from_list (
111- [{"text" : T5RelationExtractor .task_prompt (sent )} for sent in sentences ]
171+ [
172+ {
173+ "text" : self .hg_pipeline .tokenizer .apply_chat_template (
174+ (GenerativeRelationExtractor .task_prompt (" " .join (sent )))
175+ )
176+ }
177+ for sent in sentences
178+ ]
112179 )
113180 for out in self ._progress_ (
114181 self .hg_pipeline (KeyDataset (dataset , "text" ), batch_size = self .batch_size ),
115182 total = len (dataset ),
116183 ):
117184 text_relations = out [0 ]["generated_text" ]
118185
119- raw_triples = T5RelationExtractor .parse_t5_text_relations (text_relations )
186+ raw_triples = GenerativeRelationExtractor .parse_text_relations (
187+ text_relations
188+ )
120189 triples = []
121190 for subj , rel , obj in raw_triples :
122- subj_char = T5RelationExtractor .identify_character (subj , characters )
123- obj_char = T5RelationExtractor .identify_character (obj , characters )
191+ subj_char = GenerativeRelationExtractor .identify_character (
192+ subj , characters
193+ )
194+ obj_char = GenerativeRelationExtractor .identify_character (
195+ obj , characters
196+ )
124197 if subj_char is None or obj_char is None or subj_char == obj_char :
125198 continue
126199 triples .append ((subj_char , rel , obj_char ))
@@ -129,12 +202,14 @@ def __call__(
129202 return {"sentence_relations" : sentence_relations }
130203
131204 @staticmethod
132- def task_prompt (sentence : list [str ]) -> str :
133- sent_text = " " .join (sentence )
134- return f"extract relations: { sent_text } "
205+ def task_prompt (text : str ) -> list [dict ]:
206+ return [
207+ {"role" : "system" , "content" : "Extract relations from the given text." },
208+ {"role" : "user" , "content" : text },
209+ ]
135210
136211 @staticmethod
137- def parse_t5_text_relations (text_relations : str ) -> list [tuple [str , str , str ]]:
212+ def parse_text_relations (text_relations : str ) -> list [tuple [str , str , str ]]:
138213 return re .findall (r"\(([^,]+), ([^,]+), ([^,]+)\)" , text_relations )
139214
140215 @staticmethod
0 commit comments