@@ -4,12 +4,11 @@
 from datasets import load_dataset, Dataset as HGDataset
 import torch
 from transformers import (
-    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
     AutoTokenizer,
     Trainer,
     TrainingArguments,
     PreTrainedTokenizerFast,
-    DataCollatorForLanguageModeling,
     DataCollatorForSeq2Seq,
     PreTrainedModel,
     EvalPrediction,
@@ -34,13 +33,14 @@ def format_rel(rel: dict) -> str:
         return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])

     labels = " ".join(map(format_rel, example["relations"]))
-    answer = {"role": "assistant", "content": labels}
-    example["labels"] = tokenizer.apply_chat_template([answer])
+    with tokenizer.as_target_tokenizer():
+        labels_batch = tokenizer(labels)
+    example["labels"] = labels_batch["input_ids"]

     text = example["chunk"] or ""
-    example["input_ids"] = tokenizer.apply_chat_template(
-        GenerativeRelationExtractor.task_prompt(text)
-    )
+    example["input_ids"] = tokenizer(GenerativeRelationExtractor.task_prompt(text))[
+        "input_ids"
+    ]

     return example

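For context, the new preprocessing tokenizes source and target text directly instead of going through a chat template. A minimal sketch of what it produces, using an illustrative T5 checkpoint (the actual model is supplied by the caller); note that recent transformers releases deprecate the as_target_tokenizer() context manager in favor of the text_target= argument:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint

    # Source side: the plain-text task prompt is tokenized directly.
    input_ids = tokenizer(
        "Extract relations from the given text: Marie Curie discovered polonium."
    )["input_ids"]

    # Target side: text_target= supersedes the as_target_tokenizer() context manager.
    labels = tokenizer(text_target="(Marie Curie, discovered, polonium)")["input_ids"]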
@@ -100,7 +100,7 @@ def train_model_on_ARF(
     if isinstance(model, str):
         assert tokenizer is None
         tokenizer = AutoTokenizer.from_pretrained(model)
-        model = AutoModelForCausalLM.from_pretrained(model)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model)
     assert tokenizer is not None
     tokenizer.pad_token = tokenizer.eos_token
     pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]
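One caveat worth noting: seq2seq tokenizers such as T5's already define a dedicated pad token, so re-assigning tokenizer.pad_token = tokenizer.eos_token is mostly a holdover from the causal-LM setup. A quick check, with the checkpoint name assumed purely for illustration:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
    print(tokenizer.pad_token, tokenizer.eos_token)  # prints: <pad> </s>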
@@ -126,7 +126,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
         targs,
         train_dataset=dataset["train"],
         eval_dataset=dataset["test"],
-        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+        data_collator=DataCollatorForSeq2Seq(tokenizer),
         compute_metrics=compute_metrics,
     )
     trainer.train()
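Unlike DataCollatorForLanguageModeling, DataCollatorForSeq2Seq pads input_ids with the tokenizer's pad token but pads labels with -100, so padded target positions are ignored by the loss. A minimal sketch with made-up token ids:

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    tokenizer = AutoTokenizer.from_pretrained("t5-small")  # illustrative checkpoint
    collator = DataCollatorForSeq2Seq(tokenizer)

    batch = collator([
        {"input_ids": [100, 200, 1], "labels": [7, 8, 1]},
        {"input_ids": [100, 1], "labels": [7, 1]},
    ])
    # batch["labels"][1] ends in -100: the loss skips the padded position.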
@@ -156,7 +156,10 @@ def __init__(
     def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
         super()._pipeline_init_(lang, progress_reporter, **kwargs)
         self.hg_pipeline = hg_pipeline(
-            "text-generation", model=self.model, device=self.device
+            "text2text-generation",
+            torch_dtype=torch.bfloat16,
+            model=self.model,
+            device=self.device,
         )

     def __call__(
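The "text2text-generation" task routes through encoder-decoder generation, matching the AutoModelForSeq2SeqLM switch above, and bfloat16 roughly halves inference memory on hardware that supports it. A hedged sketch of an equivalent standalone call (the flan-t5 checkpoint is an assumption for illustration only):

    import torch
    from transformers import pipeline

    pipe = pipeline(
        "text2text-generation",
        model="google/flan-t5-small",  # assumed checkpoint, illustration only
        torch_dtype=torch.bfloat16,
    )
    out = pipe("Extract relations from the given text: Marie Curie discovered polonium.")
    print(out[0]["generated_text"])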
@@ -169,11 +172,7 @@ def __call__(
         # chunk as in the ARF dataset
         dataset = HGDataset.from_list(
             [
-                {
-                    "text": self.hg_pipeline.tokenizer.apply_chat_template(
-                        (GenerativeRelationExtractor.task_prompt(" ".join(sent)))
-                    )
-                }
+                {"text": GenerativeRelationExtractor.task_prompt(" ".join(sent))}
                 for sent in sentences
             ]
         )
@@ -202,11 +201,8 @@ def __call__(
         return {"sentence_relations": sentence_relations}

     @staticmethod
-    def task_prompt(text: str) -> list[dict]:
-        return [
-            {"role": "system", "content": "Extract relations from the given text."},
-            {"role": "user", "content": text},
-        ]
+    def task_prompt(text: str) -> str:
+        return f"Extract relations from the given text: {text}"

     @staticmethod
     def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
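The body of parse_text_relations lies outside this hunk; a hypothetical parser for the "(entity1, relation, entity2)" strings emitted by format_rel could look like the following sketch (not the repository's actual implementation):

    import re

    def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
        # Hypothetical sketch: match each "(a, b, c)" group produced by format_rel.
        pattern = re.compile(r"\(([^,()]+), ([^,()]+), ([^,()]+)\)")
        return [(e1, rel, e2) for e1, rel, e2 in pattern.findall(text_relations)]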