@@ -26,16 +26,12 @@
 #: (subject, relation, object)
 Relation = tuple[Character, str, Character]
 
-seq2seq_special_tokens = ["<triplet>", "<subj>", "<rel>", "<obj>", "</triplet>"]
-
 
 def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
     relations = ast.literal_eval(example["relations"] or "[]")
 
     def format_rel(rel: dict) -> str:
-        return "<triplet> <subj> {} <rel> {} <obj> {} </triplet>".format(
-            rel["entity1"], rel["relation"], rel["entity2"]
-        )
+        return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])
 
     labels = " ".join(map(format_rel, relations))
 
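Note: with this change the serialized target for each relation is just a parenthesized tuple rather than a special-token sequence. A minimal sketch of the new format (the sample relation dict below is made up for illustration):

    format_rel({"entity1": "Alice", "relation": "mother of", "entity2": "Bob"})
    # -> "(Alice, mother of, Bob)"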
@@ -108,7 +104,6 @@ def train_model_on_ARF(
     assert not tokenizer is None
     tokenizer.pad_token = tokenizer.eos_token
     pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]
-    tokenizer.add_special_tokens({"additional_special_tokens": seq2seq_special_tokens})
 
     dataset = load_ARF_dataset(tokenizer)
 
@@ -207,12 +202,12 @@ def __call__(
 
     @staticmethod
     def task_prompt(text: str) -> str:
-        return f"Extract relations from the given text: {text}"
+        return f"Extract triplets (subject, relation, object) from the given text: '{text}'"
 
     @staticmethod
     def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
         triplets = re.findall(
-            r"<triplet> ?<subj>([^<]+)<rel>([^<]+)<obj>([^<]+)</triplet>",
+            r"\(([^,]+), ?([^,]+), ?([^,]+)\)",
             text_relations,
         )
         triplets = [
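As a sanity check on the new pattern, here is a self-contained sketch of the parsing side. The regex is taken from this diff, but the final strip-comprehension and the sample strings are assumptions, since the diff truncates before the list comprehension:

    import re

    def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
        triplets = re.findall(
            r"\(([^,]+), ?([^,]+), ?([^,]+)\)",
            text_relations,
        )
        # Guess at the truncated comprehension: strip stray whitespace per field.
        # Note: fields must not themselves contain commas for this pattern to match.
        return [tuple(part.strip() for part in t) for t in triplets]

    out = parse_text_relations("(Alice, mother of, Bob) (Bob, knows, Carol)")
    assert out == [("Alice", "mother of", "Bob"), ("Bob", "knows", "Carol")]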