Skip to content

Commit d5ff35d

Browse files
committed
WIP: moving toward use of t5gemma for relation extraction
1 parent 31b8161 commit d5ff35d

File tree

3 files changed

+85
-26
lines changed

3 files changed

+85
-26
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ readme = "README.md"
1010
requires-python = ">=3.9,<3.13"
1111
dependencies = [
1212
"torch>=2.7.0",
13-
"transformers>=4.56.1",
13+
"transformers>=4.57.1",
1414
"nltk>=3.9.1",
1515
"tqdm>=4.67.1",
1616
"networkx>=3.2",
@@ -24,6 +24,8 @@ dependencies = [
2424
"rank-bm25>=0.2.2",
2525
"accelerate>=1.10.1",
2626
"scikit-learn[ui]>=1.6.1",
27+
"tiktoken>=0.12.0",
28+
"protobuf>=6.33.2",
2729
]
2830

2931
[build-system]

renard/pipeline/relation_extraction.py

Lines changed: 16 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,11 @@
44
from datasets import load_dataset, Dataset as HGDataset
55
import torch
66
from transformers import (
7-
AutoModelForCausalLM,
7+
AutoModelForSeq2SeqLM,
88
AutoTokenizer,
99
Trainer,
1010
TrainingArguments,
1111
PreTrainedTokenizerFast,
12-
DataCollatorForLanguageModeling,
1312
DataCollatorForSeq2Seq,
1413
PreTrainedModel,
1514
EvalPrediction,
@@ -34,13 +33,14 @@ def format_rel(rel: dict) -> str:
3433
return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])
3534

3635
labels = " ".join(map(format_rel, example["relations"]))
37-
answer = {"role": "assistant", "content": labels}
38-
example["labels"] = tokenizer.apply_chat_template([answer])
36+
with tokenizer.as_target_tokenizer():
37+
labels_batch = tokenizer(labels)
38+
example["labels"] = labels_batch["input_ids"]
3939

4040
text = example["chunk"] or ""
41-
example["input_ids"] = tokenizer.apply_chat_template(
42-
GenerativeRelationExtractor.task_prompt(text)
43-
)
41+
example["input_ids"] = tokenizer(GenerativeRelationExtractor.task_prompt(text))[
42+
"input_ids"
43+
]
4444

4545
return example
4646

@@ -100,7 +100,7 @@ def train_model_on_ARF(
100100
if isinstance(model, str):
101101
assert tokenizer is None
102102
tokenizer = AutoTokenizer.from_pretrained(model)
103-
model = AutoModelForCausalLM.from_pretrained(model)
103+
model = AutoModelForSeq2SeqLM.from_pretrained(model)
104104
assert not tokenizer is None
105105
tokenizer.pad_token = tokenizer.eos_token
106106
pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]
@@ -126,7 +126,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
126126
targs,
127127
train_dataset=dataset["train"],
128128
eval_dataset=dataset["test"],
129-
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
129+
data_collator=DataCollatorForSeq2Seq(tokenizer),
130130
compute_metrics=compute_metrics,
131131
)
132132
trainer.train()
@@ -156,7 +156,10 @@ def __init__(
156156
def _pipeline_init_(self, lang: str, progress_reporter: ProgressReporter, **kwargs):
157157
super()._pipeline_init_(lang, progress_reporter, **kwargs)
158158
self.hg_pipeline = hg_pipeline(
159-
"text-generation", model=self.model, device=self.device
159+
"text2text-generation",
160+
torch_dtype=torch.bfloat16,
161+
model=self.model,
162+
device=self.device,
160163
)
161164

162165
def __call__(
@@ -169,11 +172,7 @@ def __call__(
169172
# chunk as in the ARF dataset
170173
dataset = HGDataset.from_list(
171174
[
172-
{
173-
"text": self.hg_pipeline.tokenizer.apply_chat_template(
174-
(GenerativeRelationExtractor.task_prompt(" ".join(sent)))
175-
)
176-
}
175+
{"text": GenerativeRelationExtractor.task_prompt(" ".join(sent))}
177176
for sent in sentences
178177
]
179178
)
@@ -202,11 +201,8 @@ def __call__(
202201
return {"sentence_relations": sentence_relations}
203202

204203
@staticmethod
205-
def task_prompt(text: str) -> list[dict]:
206-
return [
207-
{"role": "system", "content": "Extract relations from the given text."},
208-
{"role": "user", "content": text},
209-
]
204+
def task_prompt(text: str) -> str:
205+
return f"Extract relations from the given text: {text}"
210206

211207
@staticmethod
212208
def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:

0 commit comments

Comments (0)