Skip to content

Commit 359d152

Browse files
committed
do not use additional special tokens for relation extraction
1 parent 0a0bab0 commit 359d152

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

renard/pipeline/relation_extraction.py

Lines changed: 3 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -26,16 +26,12 @@
2626
#: (subject, relation, object)
2727
Relation = tuple[Character, str, Character]
2828

29-
seq2seq_special_tokens = ["<triplet>", "<subj>", "<rel>", "<obj>", "</triplet>"]
30-
3129

3230
def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
3331
relations = ast.literal_eval(example["relations"] or "[]")
3432

3533
def format_rel(rel: dict) -> str:
36-
return "<triplet> <subj> {} <rel> {} <obj> {} </triplet>".format(
37-
rel["entity1"], rel["relation"], rel["entity2"]
38-
)
34+
return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])
3935

4036
labels = " ".join(map(format_rel, relations))
4137

@@ -108,7 +104,6 @@ def train_model_on_ARF(
108104
assert not tokenizer is None
109105
tokenizer.pad_token = tokenizer.eos_token
110106
pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]
111-
tokenizer.add_special_tokens({"additional_special_tokens": seq2seq_special_tokens})
112107

113108
dataset = load_ARF_dataset(tokenizer)
114109

@@ -207,12 +202,12 @@ def __call__(
207202

208203
@staticmethod
209204
def task_prompt(text: str) -> str:
210-
return f"Extract relations from the given text: {text}"
205+
return f"Extract triplets (subject, relation, object) from the given text: '{text}'"
211206

212207
@staticmethod
213208
def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
214209
triplets = re.findall(
215-
r"<triplet> ?<subj>([^<]+)<rel>([^<]+)<obj>([^<]+)</triplet>",
210+
r"\(([^,]+), ?([^,]+), ?([^,]+)\)",
216211
text_relations,
217212
)
218213
triplets = [

0 commit comments

Comments (0)