-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathconvert.py
More file actions
417 lines (345 loc) · 13.2 KB
/
convert.py
File metadata and controls
417 lines (345 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
import transformers # type: ignore[import]
from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.models import Unigram, BPE
from tokenizers import decoders
from tokenizers import Tokenizer, Regex
from tokenizers.normalizers import (
StripAccents,
NFKD,
Lowercase,
Sequence,
BertNormalizer,
Precompiled,
Replace,
)
from tokenizers.pre_tokenizers import (
Digits,
WhitespaceSplit,
Metaspace,
Sequence as PSequence,
)
import json
import unicodedata
import sys
import os
import datetime
import argparse
sys.path.append(".")
from spm_parity_check import check_details # type: ignore[import]
from sentencepiece_extractor import SentencePieceExtractor # type: ignore[import]
def check_number_comma(piece: str) -> bool:
    """Tell whether ``piece`` is acceptable as-is, i.e. does NOT end in "<digit>,".

    Pieces such as ``"1,"`` or ``"▁12,"`` return ``False``; everything else
    (including strings shorter than two characters) returns ``True``.
    """
    ends_in_digit_comma = len(piece) >= 2 and piece[-1] == "," and piece[-2].isdigit()
    return not ends_in_digit_comma
def get_proto(filename: str):
    """Parse a serialized SentencePiece model file into a ``ModelProto`` message.

    The generated protobuf bindings (``sentencepiece_model_pb2``) are imported
    lazily, with the current directory added to ``sys.path``.

    Args:
        filename: Path to the SentencePiece ``.model`` file.

    Returns:
        The parsed ``sentencepiece_model_pb2.ModelProto``.

    Raises:
        Exception: If the protobuf bindings cannot be imported. The underlying
            import error is chained as ``__cause__``.
    """
    import sys

    # Extend the search path only once instead of growing it on every call.
    if "." not in sys.path:
        sys.path.append(".")
    try:
        import sentencepiece_model_pb2 as model  # type: ignore[import]
    except Exception as err:
        raise Exception(
            "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required."
        ) from err

    m = model.ModelProto()
    # Use a context manager so the file handle is always closed.
    with open(filename, "rb") as f:
        m.ParseFromString(f.read())
    return m
class Converter:
    """Abstract base for objects that rebuild a slow tokenizer as a ``tokenizers`` one.

    Concrete subclasses implement ``converted``.
    """

    def __init__(self, original_tokenizer):
        # Subclasses read the vocab file and options from this source tokenizer.
        self.original_tokenizer = original_tokenizer

    def converted(self) -> Tokenizer:
        """Return the converted tokenizer. Must be overridden by subclasses."""
        raise NotImplementedError
class SpmConverter(Converter):
    """Rebuild a SentencePiece-based slow tokenizer with the ``tokenizers`` library.

    The SentencePiece model is read from ``original_tokenizer.vocab_file``.
    Subclasses customise ``vocab``, ``unk_id``, ``normalizer`` and
    ``post_processor`` to reproduce their model's special-token layout.
    """

    def __init__(self, *args):
        super().__init__(*args)
        self.proto = get_proto(self.original_tokenizer.vocab_file)

    def vocab(self, proto):
        """Return ``(piece, score)`` pairs straight from the SentencePiece model."""
        return [(piece.piece, piece.score) for piece in proto.pieces]

    def unk_id(self, proto):
        """Return the unknown-token id recorded in the trainer spec."""
        return proto.trainer_spec.unk_id

    def tokenizer(self, proto):
        """Create the core model matching the algorithm the SentencePiece file was trained with.

        Raises:
            Exception: If the model is neither Unigram (type 1) nor BPE (type 2).
        """
        model_type = proto.trainer_spec.model_type
        if model_type == 1:  # Unigram
            return Tokenizer(Unigram(self.vocab(proto), self.unk_id(proto)))
        if model_type == 2:  # BPE: vocab and merges are recovered from the model file.
            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
            return Tokenizer(BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True))
        raise Exception(
            f"Unsupported SentencePiece model type {model_type}: only Unigram (1) and BPE (2) "
            "models can be converted"
        )

    def normalizer(self, proto):
        """Apply the model's precompiled normalization, then collapse runs of spaces."""
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        collapse_spaces = Replace(Regex(" {2,}"), " ")
        if not precompiled_charsmap:
            # An empty charsmap (identity normalization) has nothing to apply; skip
            # Precompiled instead of handing it empty bytes.
            return Sequence([collapse_spaces])
        return Sequence([Precompiled(precompiled_charsmap), collapse_spaces])

    def post_processor(self, tokenizer):
        """Return a post-processor adding special tokens, or ``None`` for none."""
        return None

    def converted(self):
        """Assemble model, normalizer, Metaspace pre-tokenizer/decoder and post-processor."""
        tokenizer = self.tokenizer(self.proto)
        tokenizer.normalizer = self.normalizer(self.proto)
        # "▁" (U+2581) is SentencePiece's word-boundary marker; prepend_scheme="always"
        # also puts it in front of the first word.
        replacement = "▁"
        prepend_scheme = "always"
        tokenizer.pre_tokenizer = Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        tokenizer.decoder = decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
        post_processor = self.post_processor(tokenizer)
        if post_processor is not None:
            tokenizer.post_processor = post_processor
        # TODO what parameters should we give ?
        parameters = {}
        return BaseTokenizer(tokenizer, parameters)
class AlbertConverter(SpmConverter):
    """Converter for ALBERT tokenizers: ``[CLS] A [SEP]`` / ``[CLS] A [SEP] B [SEP]``."""

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Return the SentencePiece vocab, demoting pieces ending in "<digit>," by 100."""
        return [
            (piece.piece, piece.score if check_number_comma(piece.piece) else piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto) -> Sequence:
        """Build ALBERT's normalizer from the original tokenizer's options.

        Quotes are unified first; accents are stripped unless ``keep_accents``;
        text is lowercased if ``do_lower_case``; then the model's precompiled
        charsmap (when present) is applied and runs of spaces are collapsed.
        """
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # An empty charsmap means identity normalization; skip Precompiled then.
        if precompiled_charsmap:
            normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Add ``[CLS]``/``[SEP]``; the second segment gets token type id 1.

        ``pair`` is the complete template for a pair of inputs, so it must
        contain both ``$A`` and ``$B`` (``$1`` alone would mean "A with type 1").
        """
        return TemplateProcessing(
            single=["[CLS]", "$A", "[SEP]"],
            pair=["[CLS]", "$A", "[SEP]", "$B:1", "[SEP]:1"],
            special_tokens=[
                ("[CLS]", tokenizer.get_vocab()["[CLS]"]),
                ("[SEP]", tokenizer.get_vocab()["[SEP]"]),
            ],
        )
class CamembertConverter(SpmConverter):
    """Converter for CamemBERT tokenizers (RoBERTa-style ``<s> A </s></s> B </s>`` layout)."""

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Prefix four placeholder tokens, shifting every SentencePiece id by 4."""
        vocab = [
            ("<s>NOTUSED", 0.0),
            ("<pad>", 0.0),
            ("</s>NOTUSED", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces]
        return vocab

    def unk_id(self, proto) -> int:
        # "<unk>" is the 4th entry prefixed in ``vocab``.
        return 3

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Wrap inputs in ``<s> … </s>``; pairs are joined by a double ``</s>``.

        ``pair`` is the complete template for both inputs, so it references
        ``$A`` and ``$B`` explicitly.
        """
        return TemplateProcessing(
            single=["<s>", "$A", "</s>"],
            pair=["<s>", "$A", "</s>", "</s>", "$B", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )
class MBartConverter(SpmConverter):
    """Converter for mBART tokenizers: ``A </s> en_XX`` / ``A B </s> en_XX``."""

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Build the fairseq-ordered vocab followed by the 25 mBART language codes.

        The first three SentencePiece pieces are replaced by ``<s> <pad> </s> <unk>``,
        so the remaining pieces are shifted by one id.
        """
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        language_codes = (
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
            "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
            "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
            "zh_CN",
        )
        vocab += [(code, 0.0) for code in language_codes]
        return vocab

    def unk_id(self, proto) -> int:
        # "<unk>" is the 4th entry prefixed in ``vocab``.
        return 3

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Append ``</s>`` and the ``en_XX`` language code after the input(s).

        ``pair`` is the complete template for both inputs, so it references
        ``$A`` and ``$B`` explicitly.
        """
        return TemplateProcessing(
            single=["$A", "</s>", "en_XX"],
            pair=["$A", "$B", "</s>", "en_XX"],
            special_tokens=[
                ("en_XX", tokenizer.get_vocab()["en_XX"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )
class XLMRobertaConverter(SpmConverter):
    """Converter for XLM-RoBERTa tokenizers (``<s> A </s></s> B </s>`` layout)."""

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Replace the first three SentencePiece pieces with fairseq's four specials.

        This shifts every remaining piece id by one.
        """
        vocab = [
            ("<s>", 0.0),
            ("<pad>", 0.0),
            ("</s>", 0.0),
            ("<unk>", 0.0),
        ]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
        return vocab

    def unk_id(self, proto) -> int:
        # "<unk>" is the 4th entry prefixed in ``vocab``.
        return 3

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Wrap inputs in ``<s> … </s>``; pairs are joined by a double ``</s>``.

        ``pair`` is the complete template for both inputs, so it references
        ``$A`` and ``$B`` explicitly.
        """
        return TemplateProcessing(
            single=["<s>", "$A", "</s>"],
            pair=["<s>", "$A", "</s>", "</s>", "$B", "</s>"],
            special_tokens=[
                ("<s>", tokenizer.get_vocab()["<s>"]),
                ("</s>", tokenizer.get_vocab()["</s>"]),
            ],
        )
class XLNetConverter(SpmConverter):
    """Converter for XLNet tokenizers: special tokens go at the END (``A <sep> <cls>``)."""

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Return the SentencePiece vocab, demoting pieces ending in "<digit>," by 100."""
        return [
            (piece.piece, piece.score if check_number_comma(piece.piece) else piece.score - 100)
            for piece in proto.pieces
        ]

    def normalizer(self, proto) -> Sequence:
        """Build XLNet's normalizer from the original tokenizer's options.

        Quotes are unified first; accents are stripped unless ``keep_accents``;
        text is lowercased if ``do_lower_case``; then the model's precompiled
        charsmap (when present) is applied and runs of spaces are collapsed.
        """
        normalizers = [Replace("``", '"'), Replace("''", '"')]
        if not self.original_tokenizer.keep_accents:
            normalizers.append(NFKD())
            normalizers.append(StripAccents())
        if self.original_tokenizer.do_lower_case:
            normalizers.append(Lowercase())
        precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
        # An empty charsmap means identity normalization; skip Precompiled then.
        if precompiled_charsmap:
            normalizers.append(Precompiled(precompiled_charsmap))
        normalizers.append(Replace(Regex(" {2,}"), " "))
        return Sequence(normalizers)

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Append ``<sep>`` after each segment and a final ``<cls>``.

        Token type ids follow XLNet: 0 for the first segment, 1 for the second,
        2 for ``<cls>``. ``pair`` is the complete template for both inputs, so it
        references ``$A`` and ``$B`` explicitly.
        """
        return TemplateProcessing(
            single=["$A", "<sep>", "<cls>:2"],
            pair=["$A", "<sep>", "$B:1", "<sep>:1", "<cls>:2"],
            special_tokens=[
                ("<sep>", tokenizer.get_vocab()["<sep>"]),
                ("<cls>", tokenizer.get_vocab()["<cls>"]),
            ],
        )
class ReformerConverter(SpmConverter):
    """Converter for Reformer tokenizers.

    Relies entirely on the ``SpmConverter`` defaults, whose ``post_processor``
    returns ``None``, so no special tokens are added.
    """
class PegasusConverter(SpmConverter):
    """Converter for Pegasus tokenizers, whose SentencePiece ids are shifted by ``offset``."""

    # Number of reserved ids inserted between the two leading specials and the
    # SentencePiece pieces (see ``vocab``).
    offset = 103

    def vocab(self, proto) -> list[tuple[str, float]]:
        """Build ``pad, eos, unk_2 … unk_{offset+1}`` followed by SentencePiece pieces 2+.

        SentencePiece piece ``k`` (k >= 2) therefore lands at id ``k + offset``.
        """
        vocab = [
            (self.original_tokenizer.pad_token, 0.0),
            (self.original_tokenizer.eos_token, 0.0),
        ]
        # Placeholders get a very low score so the Unigram model never picks them.
        vocab += [(f"unk_{i}", -100.0) for i in range(2, 2 + self.offset)]
        vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
        return vocab

    def unk_id(self, proto) -> int:
        """Return the SentencePiece unk id shifted by ``offset`` (matching ``vocab``)."""
        return proto.trainer_spec.unk_id + self.offset

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Append the eos token after the input(s).

        ``pair`` is the complete template for both inputs, so it references
        ``$A`` and ``$B`` explicitly.
        """
        eos = self.original_tokenizer.eos_token
        return TemplateProcessing(
            single=["$A", eos],
            pair=["$A", "$B", eos],
            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
        )
class T5Converter(SpmConverter):
    """Converter for T5 tokenizers: each segment is terminated by ``</s>``."""

    def post_processor(self, tokenizer) -> TemplateProcessing:
        """Append ``</s>`` after every input segment.

        ``pair`` is the complete template for both inputs, so it references
        ``$A`` and ``$B`` explicitly.
        """
        return TemplateProcessing(
            single=["$A", "</s>"],
            pair=["$A", "</s>", "$B", "</s>"],
            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
        )
# Registry: class name of a `transformers` slow tokenizer -> converter class that rebuilds it.
CONVERTERS = dict(
    AlbertTokenizer=AlbertConverter,
    CamembertTokenizer=CamembertConverter,
    XLMRobertaTokenizer=XLMRobertaConverter,
    MBartTokenizer=MBartConverter,
    XLNetTokenizer=XLNetConverter,
    ReformerTokenizer=ReformerConverter,
    PegasusTokenizer=PegasusConverter,
    T5Tokenizer=T5Converter,
)
def check(pretrained, filename) -> tuple[str, float]:
    """Convert ``pretrained``'s slow tokenizer and verify it on every line of ``filename``.

    Each stripped line is encoded by the original ``transformers`` tokenizer and
    by the converted one. Differences are passed to ``check_details``; unless it
    accepts them, an ``AssertionError`` is raised. On success the converted
    tokenizer is saved to ``<pretrained with "/" replaced by "-">.json``.

    Args:
        pretrained: Model name or path accepted by ``AutoTokenizer.from_pretrained``.
        filename: UTF-8 text file whose lines are encoded.

    Returns:
        ``("OK", speedup)`` where ``speedup`` is the total time of the original
        tokenizer divided by that of the converted one (``nan`` when no time was
        measured for the converted tokenizer, e.g. an empty file).

    Raises:
        KeyError: If no converter is registered for the tokenizer class.
        AssertionError: On an encoding mismatch not accepted by ``check_details``.
    """
    import time

    transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
    class_name = type(transformer_tokenizer).__name__
    try:
        converter_class = CONVERTERS[class_name]
    except KeyError:
        raise KeyError(f"No converter registered for {class_name!r} (model {pretrained!r})") from None
    tokenizer = converter_class(transformer_tokenizer).converted()

    # perf_counter is monotonic and high-resolution; wall-clock datetime.now() can
    # jump and is too coarse on some platforms to time a single encode call.
    trans_total = 0.0
    tok_total = 0.0
    with open(filename, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            start = time.perf_counter()
            ids = transformer_tokenizer.encode(line)
            trans = time.perf_counter()
            tok_ids = tokenizer.encode(line).ids
            tok = time.perf_counter()
            trans_total += trans - start
            tok_total += tok - trans
            if ids != tok_ids and not check_details(line, ids, tok_ids, transformer_tokenizer, tokenizer):
                # Explicit raise (not `assert`) so the check still runs under `python -O`.
                raise AssertionError(f"Error in line {i}: {line} {ids} != {tok_ids}")
    tokenizer.save(f"{pretrained.replace('/', '-')}.json")
    speedup = trans_total / tok_total if tok_total > 0 else float("nan")
    return ("OK", speedup)
def main():
    """Check every requested model and print a Markdown table of status and speedup."""
    pretraineds = [
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "camembert-base",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "facebook/mbart-large-en-ro",
        "facebook/mbart-large-cc25",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "google/reformer-crime-and-punishment",
        "t5-small",
        "google/pegasus-large",
    ]
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filename",
        required=True,
        type=str,
        help="The filename that we are going to encode in both versions to check that conversion worked",
    )
    parser.add_argument(
        "--models",
        # Tolerate spaces and empty entries, e.g. "t5-small, albert-base-v2,".
        type=lambda s: [name.strip() for name in s.split(",") if name.strip()],
        default=pretraineds,
        help=f"The pretrained tokenizers you want to test against, (default: {pretraineds})",
    )
    args = parser.parse_args()
    print(args.filename)

    model_len = 50
    status_len = 6
    speedup_len = 8
    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
    print(f"|{'-' * model_len}|{'-' * status_len}|{'-' * speedup_len}|")
    for pretrained in args.models:
        status, speedup = check(pretrained, args.filename)
        print(f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|")


if __name__ == "__main__":
    main()