# GitHub Repository: yiming-wange/cs224n-2023-solution
# Path: blob/main/minBERT/tokenizer.py
1
from typing import List, Optional, Tuple, Dict, Union, Any, overload, Sequence, NamedTuple
2
import collections
3
import os
4
import re
5
import unicodedata
6
import itertools
7
import requests
8
import copy
9
import json
10
from contextlib import contextmanager
11
from collections import OrderedDict, UserDict
12
from enum import Enum
13
import numpy as np
14
from utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
15
from tokenizers import AddedToken
16
from tokenizers import Encoding as EncodingFast
17
18
19
VERY_LARGE_INTEGER = int(1e30) # This is used to set the max input length for a model with infinite size input
20
LARGE_INTEGER = int(1e20) # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
21
22
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
23
ADDED_TOKENS_FILE = "added_tokens.json"
24
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
25
FULL_TOKENIZER_FILE = "tokenizer.json"
26
27
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
28
PRETRAINED_VOCAB_FILES_MAP = {
29
"vocab_file": {
30
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
31
}
32
}
33
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
34
"bert-base-uncased": 512
35
}
36
PRETRAINED_INIT_CONFIGURATION = {
37
"bert-base-uncased": {"do_lower_case": True}
38
}
39
40
41
TextInput = str
42
PreTokenizedInput = List[str]
43
EncodedInput = List[int]
44
TextInputPair = Tuple[str, str]
45
PreTokenizedInputPair = Tuple[List[str], List[str]]
46
EncodedInputPair = Tuple[List[int], List[int]]
47
48
49
class ExplicitEnum(Enum):
50
@classmethod
51
def _missing_(cls, value):
52
raise ValueError(
53
"%r is not a valid %s, please select one of %s"
54
% (value, cls.__name__, str(list(cls._value2member_map_.keys())))
55
)
56
57
58
class TruncationStrategy(ExplicitEnum):
59
ONLY_FIRST = "only_first"
60
ONLY_SECOND = "only_second"
61
LONGEST_FIRST = "longest_first"
62
DO_NOT_TRUNCATE = "do_not_truncate"
63
64
65
class PaddingStrategy(ExplicitEnum):
66
LONGEST = "longest"
67
MAX_LENGTH = "max_length"
68
DO_NOT_PAD = "do_not_pad"
69
70
71
class TensorType(ExplicitEnum):
72
PYTORCH = "pt"
73
TENSORFLOW = "tf"
74
NUMPY = "np"
75
JAX = "jax"
76
77
78
class CharSpan(NamedTuple):
79
start: int
80
end: int
81
82
83
class TokenSpan(NamedTuple):
84
start: int
85
end: int
86
87
88
def to_py_obj(obj):
89
"""
90
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
91
"""
92
if isinstance(obj, (dict, BatchEncoding)):
93
return {k: to_py_obj(v) for k, v in obj.items()}
94
elif isinstance(obj, (list, tuple)):
95
return [to_py_obj(o) for o in obj]
96
elif is_tf_available() and _is_tensorflow(obj):
97
return obj.numpy().tolist()
98
elif is_torch_available() and _is_torch(obj):
99
return obj.detach().cpu().tolist()
100
elif isinstance(obj, np.ndarray):
101
return obj.tolist()
102
else:
103
return obj
104
105
106
def _is_torch(x):
107
import torch
108
return isinstance(x, torch.Tensor)
109
110
111
def _is_torch_device(x):
112
import torch
113
return isinstance(x, torch.device)
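# The helpers below are referenced later in this file (`to_py_obj`, `convert_to_tensors`,
# `pad`) but are neither defined nor imported above. Minimal sketches, assuming the same
# semantics as the corresponding private helpers in HuggingFace transformers.
def _is_tensorflow(x):
    import tensorflow as tf
    return isinstance(x, tf.Tensor)


def _is_numpy(x):
    return isinstance(x, np.ndarray)


def _is_jax(x):
    import jax.numpy as jnp
    return isinstance(x, jnp.ndarray)


def is_flax_available():
    # `is_flax_available` is used in `convert_to_tensors` but is not part of the `utils`
    # import at the top of this file; conservative fallback (assumption).
    try:
        import jax  # noqa: F401
        import flax  # noqa: F401
        return True
    except ImportError:
        return False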
114
115
116
def _is_end_of_word(text):
117
last_char = text[-1]
118
return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
119
120
121
def _is_start_of_word(text):
122
first_char = text[0]
123
return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
124
125
126
def _is_punctuation(char):
127
cp = ord(char)
128
# We treat all non-letter/number ASCII as punctuation.
129
# Characters such as "^", "$", and "`" are not in the Unicode
130
# Punctuation class but we treat them as punctuation anyways, for
131
# consistency.
132
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
133
return True
134
cat = unicodedata.category(char)
135
if cat.startswith("P"):
136
return True
137
return False
138
139
140
def _is_whitespace(char):
141
# \t, \n, and \r are technically control characters but we treat them
142
# as whitespace since they are generally considered as such.
143
if char == " " or char == "\t" or char == "\n" or char == "\r":
144
return True
145
cat = unicodedata.category(char)
146
if cat == "Zs":
147
return True
148
return False
149
150
151
def _is_control(char):
152
# These are technically control characters but we count them as whitespace
153
# characters.
154
if char == "\t" or char == "\n" or char == "\r":
155
return False
156
cat = unicodedata.category(char)
157
if cat.startswith("C"):
158
return True
159
return False
160
161
162
def load_vocab(vocab_file):
163
vocab = collections.OrderedDict()
164
with open(vocab_file, "r", encoding="utf-8") as reader:
165
tokens = reader.readlines()
166
for index, token in enumerate(tokens):
167
token = token.rstrip("\n")
168
vocab[token] = index
169
return vocab
170
171
172
def whitespace_tokenize(text):
173
text = text.strip()
174
if not text:
175
return []
176
tokens = text.split()
177
return tokens
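# Illustrative usage of the two helpers above (hypothetical file name and input):
#   vocab = load_vocab("vocab.txt")          # OrderedDict mapping token -> index, in file order
#   whitespace_tokenize(" hello  world ")    # -> ["hello", "world"]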
178
179
180
class BatchEncoding(UserDict):
181
def __init__(
182
self,
183
data: Optional[Dict[str, Any]] = None,
184
encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
185
tensor_type: Union[None, str, TensorType] = None,
186
prepend_batch_axis: bool = False,
187
n_sequences: Optional[int] = None,
188
):
189
super().__init__(data)
190
191
if isinstance(encoding, EncodingFast):
192
encoding = [encoding]
193
194
self._encodings = encoding
195
196
if n_sequences is None and encoding is not None and len(encoding):
197
n_sequences = encoding[0].n_sequences
198
199
self._n_sequences = n_sequences
200
201
self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
202
203
@property
204
def n_sequences(self) -> Optional[int]:
205
return self._n_sequences
206
207
@property
208
def is_fast(self) -> bool:
209
return self._encodings is not None
210
211
def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:
212
if isinstance(item, str):
213
return self.data[item]
214
elif self._encodings is not None:
215
return self._encodings[item]
216
else:
217
raise KeyError(
218
"Indexing with integers (to access backend Encoding for a given batch index) "
219
"is not available when using Python based tokenizers"
220
)
221
222
def __getattr__(self, item: str):
223
try:
224
return self.data[item]
225
except KeyError:
226
raise AttributeError
227
228
def __getstate__(self):
229
return {"data": self.data, "encodings": self._encodings}
230
231
def __setstate__(self, state):
232
if "data" in state:
233
self.data = state["data"]
234
235
if "encodings" in state:
236
self._encodings = state["encodings"]
237
238
def keys(self):
239
return self.data.keys()
240
241
def values(self):
242
return self.data.values()
243
244
def items(self):
245
return self.data.items()
246
247
# After this point:
248
# Extended properties and methods only available for fast (Rust-based) tokenizers
249
# provided by HuggingFace tokenizers library.
250
251
@property
252
def encodings(self) -> Optional[List[EncodingFast]]:
253
return self._encodings
254
255
def tokens(self, batch_index: int = 0) -> List[str]:
256
if not self._encodings:
257
raise ValueError("tokens() is not available when using Python-based tokenizers")
258
return self._encodings[batch_index].tokens
259
260
def sequence_ids(self, batch_index: int = 0) -> List[Optional[int]]:
261
if not self._encodings:
262
raise ValueError("sequence_ids() is not available when using Python-based tokenizers")
263
return self._encodings[batch_index].sequence_ids
264
265
def words(self, batch_index: int = 0) -> List[Optional[int]]:
266
if not self._encodings:
267
raise ValueError("words() is not available when using Python-based tokenizers")
268
return self.word_ids(batch_index)
269
270
def word_ids(self, batch_index: int = 0) -> List[Optional[int]]:
271
if not self._encodings:
272
raise ValueError("word_ids() is not available when using Python-based tokenizers")
273
return self._encodings[batch_index].word_ids
274
275
def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
276
if not self._encodings:
277
raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
278
if token_index is not None:
279
batch_index = batch_or_token_index
280
else:
281
batch_index = 0
282
token_index = batch_or_token_index
283
if batch_index < 0:
284
batch_index = self._batch_size + batch_index
285
if token_index < 0:
286
token_index = self._seq_len + token_index
287
return self._encodings[batch_index].token_to_sequence(token_index)
288
289
def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
290
if not self._encodings:
291
raise ValueError("token_to_word() is not available when using Python based tokenizers")
292
if token_index is not None:
293
batch_index = batch_or_token_index
294
else:
295
batch_index = 0
296
token_index = batch_or_token_index
297
if batch_index < 0:
298
batch_index = self._batch_size + batch_index
299
if token_index < 0:
300
token_index = self._seq_len + token_index
301
return self._encodings[batch_index].token_to_word(token_index)
302
303
def word_to_tokens(
304
self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
305
) -> Optional[TokenSpan]:
306
if not self._encodings:
307
raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
308
if word_index is not None:
309
batch_index = batch_or_word_index
310
else:
311
batch_index = 0
312
word_index = batch_or_word_index
313
if batch_index < 0:
314
batch_index = self._batch_size + batch_index
315
if word_index < 0:
316
word_index = self._seq_len + word_index
317
span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
318
return TokenSpan(*span) if span is not None else None
319
320
def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> CharSpan:
321
if not self._encodings:
322
raise ValueError("token_to_chars() is not available when using Python based tokenizers")
323
if token_index is not None:
324
batch_index = batch_or_token_index
325
else:
326
batch_index = 0
327
token_index = batch_or_token_index
328
return CharSpan(*(self._encodings[batch_index].token_to_chars(token_index)))
329
330
def char_to_token(
331
self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
332
) -> int:
333
if not self._encodings:
334
raise ValueError("char_to_token() is not available when using Python based tokenizers")
335
if char_index is not None:
336
batch_index = batch_or_char_index
337
else:
338
batch_index = 0
339
char_index = batch_or_char_index
340
return self._encodings[batch_index].char_to_token(char_index, sequence_index)
341
342
def word_to_chars(
343
self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
344
) -> CharSpan:
345
if not self._encodings:
346
raise ValueError("word_to_chars() is not available when using Python based tokenizers")
347
if word_index is not None:
348
batch_index = batch_or_word_index
349
else:
350
batch_index = 0
351
word_index = batch_or_word_index
352
return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))
353
354
def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
355
if not self._encodings:
356
raise ValueError("char_to_word() is not available when using Python based tokenizers")
357
if char_index is not None:
358
batch_index = batch_or_char_index
359
else:
360
batch_index = 0
361
char_index = batch_or_char_index
362
return self._encodings[batch_index].char_to_word(char_index, sequence_index)
363
364
def convert_to_tensors(
365
self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
366
):
367
if tensor_type is None:
368
return self
369
370
# Convert to TensorType
371
if not isinstance(tensor_type, TensorType):
372
tensor_type = TensorType(tensor_type)
373
374
# Get a function reference for the correct framework
375
if tensor_type == TensorType.TENSORFLOW:
376
if not is_tf_available():
377
raise ImportError(
378
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
379
)
380
import tensorflow as tf
381
382
as_tensor = tf.constant
383
is_tensor = tf.is_tensor
384
elif tensor_type == TensorType.PYTORCH:
385
if not is_torch_available():
386
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
387
import torch
388
389
as_tensor = torch.tensor
390
is_tensor = torch.is_tensor
391
elif tensor_type == TensorType.JAX:
392
if not is_flax_available():
393
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
394
import jax.numpy as jnp # noqa: F811
395
396
as_tensor = jnp.array
397
is_tensor = _is_jax
398
else:
399
as_tensor = np.asarray
400
is_tensor = _is_numpy
401
# (mfuntowicz: This code is unreachable)
402
# else:
403
# raise ImportError(
404
# "Unable to convert output to tensors format {}".format(tensor_type)
405
# )
406
407
# Do the tensor conversion in batch
408
for key, value in self.items():
409
try:
410
if prepend_batch_axis:
411
value = [value]
412
413
if not is_tensor(value):
414
tensor = as_tensor(value)
415
416
# Removing this for now in favor of controlling the shape with `prepend_batch_axis`
417
# # at-least2d
418
# if tensor.ndim > 2:
419
# tensor = tensor.squeeze(0)
420
# elif tensor.ndim < 2:
421
# tensor = tensor[None, :]
422
423
self[key] = tensor
424
except: # noqa E722
425
if key == "overflowing_tokens":
426
raise ValueError(
427
"Unable to create tensor returning overflowing tokens of different lengths. "
428
"Please see if a fast version of this tokenizer is available to have this feature available."
429
)
430
raise ValueError(
431
"Unable to create tensor, you should probably activate truncation and/or padding "
432
"with 'padding=True' 'truncation=True' to have batched tensors with the same length."
433
)
434
435
return self
436
437
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
438
# This check catches things like APEX blindly calling "to" on all inputs to a module
439
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
440
# into a HalfTensor
441
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
442
self.data = {k: v.to(device=device) for k, v in self.data.items()}
443
return self
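# Usage sketch for BatchEncoding (illustrative, assuming a concrete tokenizer subclass such
# as the BertTokenizer defined elsewhere in this project):
#   enc = tokenizer("hello world", return_tensors="pt")   # __call__ returns a BatchEncoding
#   enc["input_ids"]      # tensor of token ids
#   enc.to("cuda:0")      # moves every tensor to the device and returns self (PyTorch only)
#   enc.tokens(0)         # fast (Rust-backed) tokenizers only; raises ValueError otherwise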
444
445
446
class SpecialTokensMixin:
447
SPECIAL_TOKENS_ATTRIBUTES = [
448
"bos_token",
449
"eos_token",
450
"unk_token",
451
"sep_token",
452
"pad_token",
453
"cls_token",
454
"mask_token",
455
"additional_special_tokens",
456
]
457
458
def __init__(self, verbose=True, **kwargs):
459
self._bos_token = None
460
self._eos_token = None
461
self._unk_token = None
462
self._sep_token = None
463
self._pad_token = None
464
self._cls_token = None
465
self._mask_token = None
466
self._pad_token_type_id = 0
467
self._additional_special_tokens = []
468
self.verbose = verbose
469
470
# We directly set the hidden value to allow initialization with special tokens
471
# which are not yet in the vocabulary. Necessary for serialization/de-serialization
472
# TODO clean this up at some point (probably by switching to fast tokenizers)
473
for key, value in kwargs.items():
474
if value is None:
475
continue
476
if key in self.SPECIAL_TOKENS_ATTRIBUTES:
477
if key == "additional_special_tokens":
478
assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
479
assert all(isinstance(t, str) for t in value), "One of the tokens is not a string"
480
setattr(self, key, value)
481
elif isinstance(value, (str, AddedToken)):
482
setattr(self, key, value)
483
else:
484
raise TypeError(
485
"special token {} has to be either str or AddedToken but got: {}".format(key, type(value))
486
)
487
488
def sanitize_special_tokens(self) -> int:
489
return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
490
491
def add_special_tokens(self, special_tokens_dict: Dict[str, Union[str, AddedToken]]) -> int:
492
if not special_tokens_dict:
493
return 0
494
495
added_tokens = 0
496
for key, value in special_tokens_dict.items():
497
assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
498
499
setattr(self, key, value)
500
501
if key == "additional_special_tokens":
502
assert isinstance(value, (list, tuple)) and all(
503
isinstance(t, (str, AddedToken)) for t in value
504
), f"Tokens {value} for key {key} should all be str or AddedToken instances"
505
added_tokens += self.add_tokens(value, special_tokens=True)
506
else:
507
assert isinstance(
508
value, (str, AddedToken)
509
), f"Token {value} for key {key} should be a str or an AddedToken instance"
510
added_tokens += self.add_tokens([value], special_tokens=True)
511
512
return added_tokens
513
514
def add_tokens(
515
self, new_tokens: Union[str, AddedToken, List[Union[str, AddedToken]]], special_tokens: bool = False
516
) -> int:
517
if not new_tokens:
518
return 0
519
520
if not isinstance(new_tokens, (list, tuple)):
521
new_tokens = [new_tokens]
522
523
return self._add_tokens(new_tokens, special_tokens=special_tokens)
524
525
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
526
raise NotImplementedError
527
528
@property
529
def bos_token(self) -> str:
530
if self._bos_token is None and self.verbose:
531
return None
532
return str(self._bos_token)
533
534
@property
535
def eos_token(self) -> str:
536
if self._eos_token is None and self.verbose:
537
return None
538
return str(self._eos_token)
539
540
@property
541
def unk_token(self) -> str:
542
if self._unk_token is None and self.verbose:
543
return None
544
return str(self._unk_token)
545
546
@property
547
def sep_token(self) -> str:
548
if self._sep_token is None and self.verbose:
549
return None
550
return str(self._sep_token)
551
552
@property
553
def pad_token(self) -> str:
554
if self._pad_token is None and self.verbose:
555
return None
556
return str(self._pad_token)
557
558
@property
559
def cls_token(self) -> str:
560
if self._cls_token is None and self.verbose:
561
return None
562
return str(self._cls_token)
563
564
@property
565
def mask_token(self) -> str:
566
if self._mask_token is None and self.verbose:
567
return None
568
return str(self._mask_token)
569
570
@property
571
def additional_special_tokens(self) -> List[str]:
572
if self._additional_special_tokens is None and self.verbose:
573
return None
574
return [str(tok) for tok in self._additional_special_tokens]
575
576
@bos_token.setter
577
def bos_token(self, value):
578
self._bos_token = value
579
580
@eos_token.setter
581
def eos_token(self, value):
582
self._eos_token = value
583
584
@unk_token.setter
585
def unk_token(self, value):
586
self._unk_token = value
587
588
@sep_token.setter
589
def sep_token(self, value):
590
self._sep_token = value
591
592
@pad_token.setter
593
def pad_token(self, value):
594
self._pad_token = value
595
596
@cls_token.setter
597
def cls_token(self, value):
598
self._cls_token = value
599
600
@mask_token.setter
601
def mask_token(self, value):
602
self._mask_token = value
603
604
@additional_special_tokens.setter
605
def additional_special_tokens(self, value):
606
self._additional_special_tokens = value
607
608
@property
609
def bos_token_id(self) -> Optional[int]:
610
if self._bos_token is None:
611
return None
612
return self.convert_tokens_to_ids(self.bos_token)
613
614
@property
615
def eos_token_id(self) -> Optional[int]:
616
if self._eos_token is None:
617
return None
618
return self.convert_tokens_to_ids(self.eos_token)
619
620
@property
621
def unk_token_id(self) -> Optional[int]:
622
if self._unk_token is None:
623
return None
624
return self.convert_tokens_to_ids(self.unk_token)
625
626
@property
627
def sep_token_id(self) -> Optional[int]:
628
if self._sep_token is None:
629
return None
630
return self.convert_tokens_to_ids(self.sep_token)
631
632
@property
633
def pad_token_id(self) -> Optional[int]:
634
if self._pad_token is None:
635
return None
636
return self.convert_tokens_to_ids(self.pad_token)
637
638
@property
639
def pad_token_type_id(self) -> int:
640
return self._pad_token_type_id
641
642
@property
643
def cls_token_id(self) -> Optional[int]:
644
if self._cls_token is None:
645
return None
646
return self.convert_tokens_to_ids(self.cls_token)
647
648
@property
649
def mask_token_id(self) -> Optional[int]:
650
if self._mask_token is None:
651
return None
652
return self.convert_tokens_to_ids(self.mask_token)
653
654
@property
655
def additional_special_tokens_ids(self) -> List[int]:
656
return self.convert_tokens_to_ids(self.additional_special_tokens)
657
658
@bos_token_id.setter
659
def bos_token_id(self, value):
660
self._bos_token = self.convert_tokens_to_ids(value)
661
662
@eos_token_id.setter
663
def eos_token_id(self, value):
664
self._eos_token = self.convert_tokens_to_ids(value)
665
666
@unk_token_id.setter
667
def unk_token_id(self, value):
668
self._unk_token = self.convert_tokens_to_ids(value)
669
670
@sep_token_id.setter
671
def sep_token_id(self, value):
672
self._sep_token = self.convert_tokens_to_ids(value)
673
674
@pad_token_id.setter
675
def pad_token_id(self, value):
676
self._pad_token = self.convert_tokens_to_ids(value)
677
678
@cls_token_id.setter
679
def cls_token_id(self, value):
680
self._cls_token = self.convert_tokens_to_ids(value)
681
682
@mask_token_id.setter
683
def mask_token_id(self, value):
684
self._mask_token = self.convert_tokens_to_ids(value)
685
686
@additional_special_tokens_ids.setter
687
def additional_special_tokens_ids(self, values):
688
self._additional_special_tokens = [self.convert_tokens_to_ids(value) for value in values]
689
690
@property
691
def special_tokens_map(self) -> Dict[str, Union[str, List[str]]]:
692
set_attr = {}
693
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
694
attr_value = getattr(self, "_" + attr)
695
if attr_value:
696
set_attr[attr] = str(attr_value)
697
return set_attr
698
699
@property
700
def special_tokens_map_extended(self) -> Dict[str, Union[str, AddedToken, List[Union[str, AddedToken]]]]:
701
set_attr = {}
702
for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
703
attr_value = getattr(self, "_" + attr)
704
if attr_value:
705
set_attr[attr] = attr_value
706
return set_attr
707
708
@property
709
def all_special_tokens(self) -> List[str]:
710
all_toks = [str(s) for s in self.all_special_tokens_extended]
711
return all_toks
712
713
@property
714
def all_special_tokens_extended(self) -> List[Union[str, AddedToken]]:
715
all_toks = []
716
set_attr = self.special_tokens_map_extended
717
for attr_value in set_attr.values():
718
all_toks = all_toks + (list(attr_value) if isinstance(attr_value, (list, tuple)) else [attr_value])
719
all_toks = list(OrderedDict.fromkeys(all_toks))
720
return all_toks
721
722
@property
723
def all_special_ids(self) -> List[int]:
724
all_toks = self.all_special_tokens
725
all_ids = self.convert_tokens_to_ids(all_toks)
726
return all_ids
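# Example of how this mixin is typically used (hypothetical token values):
#   num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
#   tokenizer.pad_token_id   # id of "[PAD]" once it is in the vocabulary
# add_special_tokens both sets the attribute (e.g. `pad_token`) and registers the token in
# the vocabulary via `add_tokens(..., special_tokens=True)`.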
727
728
729
class PreTrainedTokenizerBase(SpecialTokensMixin):
730
vocab_files_names: Dict[str, str] = {}
731
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
732
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
733
max_model_input_sizes: Dict[str, Optional[int]] = {}
734
735
# first name has to correspond to main model input name
736
# to make sure `tokenizer.pad(...)` works correctly
737
model_input_names: List[str] = ["input_ids", "token_type_ids", "attention_mask"]
738
padding_side: str = "right"
739
slow_tokenizer_class = None
740
741
def __init__(self, **kwargs):
742
# inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
743
self.init_inputs = ()
744
self.init_kwargs = copy.deepcopy(kwargs)
745
self.name_or_path = kwargs.pop("name_or_path", "")
746
747
# For backward compatibility, we fall back to setting model_max_length from max_len if provided
748
model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
749
self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
750
751
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
752
self.padding_side = kwargs.pop("padding_side", self.padding_side)
753
assert self.padding_side in [
754
"right",
755
"left",
756
], f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
757
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
758
759
self.deprecation_warnings = (
760
{}
761
) # Used to store whether we have already issued a given deprecation warning (avoids over-logging).
762
763
super().__init__(**kwargs)
764
765
@property
766
def max_len_single_sentence(self) -> int:
767
return self.model_max_length - self.num_special_tokens_to_add(pair=False)
768
769
@property
770
def max_len_sentences_pair(self) -> int:
771
return self.model_max_length - self.num_special_tokens_to_add(pair=True)
772
773
@max_len_single_sentence.setter
774
def max_len_single_sentence(self, value) -> int:
775
# For backward compatibility, allow trying to set 'max_len_single_sentence'.
776
if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
777
self.deprecation_warnings["max_len_single_sentence"] = True
778
else:
779
raise ValueError(
780
"Setting 'max_len_single_sentence' is now deprecated. " "This value is automatically set up."
781
)
782
783
@max_len_sentences_pair.setter
784
def max_len_sentences_pair(self, value) -> int:
785
# For backward compatibility, allow trying to set 'max_len_sentences_pair'.
786
if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
787
self.deprecation_warnings["max_len_sentences_pair"] = True
788
else:
789
raise ValueError(
790
"Setting 'max_len_sentences_pair' is now deprecated. " "This value is automatically set up."
791
)
792
793
def __repr__(self) -> str:
794
return (
795
f"{'PreTrainedTokenizerFast' if self.is_fast else 'PreTrainedTokenizer'}(name_or_path='{self.name_or_path}', "
796
f"vocab_size={self.vocab_size}, model_max_len={self.model_max_length}, is_fast={self.is_fast}, "
797
f"padding_side='{self.padding_side}', special_tokens={self.special_tokens_map_extended})"
798
)
799
800
def get_vocab(self) -> Dict[str, int]:
801
raise NotImplementedError()
802
803
@classmethod
804
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
805
cache_dir = kwargs.pop("cache_dir", None)
806
force_download = kwargs.pop("force_download", False)
807
resume_download = kwargs.pop("resume_download", False)
808
proxies = kwargs.pop("proxies", None)
809
local_files_only = kwargs.pop("local_files_only", False)
810
use_auth_token = kwargs.pop("use_auth_token", None)
811
revision = kwargs.pop("revision", None)
812
subfolder = kwargs.pop("subfolder", None)
813
814
s3_models = list(cls.max_model_input_sizes.keys())
815
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
816
vocab_files = {}
817
init_configuration = {}
818
if pretrained_model_name_or_path in s3_models:
819
# Get the vocabulary from AWS S3 bucket
820
for file_id, map_list in cls.pretrained_vocab_files_map.items():
821
vocab_files[file_id] = map_list[pretrained_model_name_or_path]
822
if (
823
cls.pretrained_init_configuration
824
and pretrained_model_name_or_path in cls.pretrained_init_configuration
825
):
826
init_configuration = cls.pretrained_init_configuration[pretrained_model_name_or_path].copy()
827
else:
828
# Get the vocabulary from local files
829
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
830
if len(cls.vocab_files_names) > 1:
831
raise ValueError(
832
"Calling {}.from_pretrained() with the path to a single file or url is not supported."
833
"Use a model identifier or the path to a directory instead.".format(cls.__name__)
834
)
835
file_id = list(cls.vocab_files_names.keys())[0]
836
vocab_files[file_id] = pretrained_model_name_or_path
837
else:
838
# At this point pretrained_model_name_or_path is either a directory or a model identifier name
839
additional_files_names = {
840
"added_tokens_file": ADDED_TOKENS_FILE,
841
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
842
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
843
"tokenizer_file": FULL_TOKENIZER_FILE,
844
}
845
# Look for the tokenizer files
846
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items():
847
if os.path.isdir(pretrained_model_name_or_path):
848
if subfolder is not None:
849
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
850
else:
851
full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
852
if not os.path.exists(full_file_name):
853
full_file_name = None
854
else:
855
full_file_name = hf_bucket_url(
856
pretrained_model_name_or_path,
857
filename=file_name,
858
subfolder=subfolder,
859
revision=revision,
860
mirror=None,
861
)
862
863
vocab_files[file_id] = full_file_name
864
865
# Get files from url, cache, or disk depending on the case
866
resolved_vocab_files = {}
867
unresolved_files = []
868
for file_id, file_path in vocab_files.items():
869
if file_path is None:
870
resolved_vocab_files[file_id] = None
871
else:
872
try:
873
try:
874
resolved_vocab_files[file_id] = cached_path(
875
file_path,
876
cache_dir=cache_dir,
877
force_download=force_download,
878
proxies=proxies,
879
resume_download=resume_download,
880
local_files_only=local_files_only,
881
use_auth_token=use_auth_token,
882
)
883
except FileNotFoundError as error:
884
if local_files_only:
885
unresolved_files.append(file_id)
886
else:
887
raise error
888
889
except requests.exceptions.HTTPError as err:
890
if "404 Client Error" in str(err):
891
resolved_vocab_files[file_id] = None
892
else:
893
raise err
894
895
if all(full_file_name is None for full_file_name in resolved_vocab_files.values()):
896
msg = (
897
f"Can't load tokenizer for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
898
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
899
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing relevant tokenizer files\n\n"
900
)
901
raise EnvironmentError(msg)
902
903
# Note: this loop has no effect as written (in the upstream implementation it logged which files were resolved).
for file_id, file_path in vocab_files.items():
904
if file_id not in resolved_vocab_files:
905
continue
906
907
return cls._from_pretrained(
908
resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
909
)
910
911
@classmethod
912
def _from_pretrained(
913
cls, resolved_vocab_files, pretrained_model_name_or_path, init_configuration, *init_inputs, **kwargs
914
):
915
# We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
916
# file or if `from_slow` is set to True.
917
from_slow = kwargs.get("from_slow", False)
918
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
919
if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None:
920
slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
921
copy.deepcopy(resolved_vocab_files),
922
pretrained_model_name_or_path,
923
copy.deepcopy(init_configuration),
924
*init_inputs,
925
**(copy.deepcopy(kwargs)),
926
)
927
else:
928
slow_tokenizer = None
929
930
# Prepare tokenizer initialization kwargs
931
# Did we save some inputs and kwargs to reload?
932
tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
933
if tokenizer_config_file is not None:
934
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
935
init_kwargs = json.load(tokenizer_config_handle)
936
saved_init_inputs = init_kwargs.pop("init_inputs", ())
937
if not init_inputs:
938
init_inputs = saved_init_inputs
939
else:
940
init_kwargs = init_configuration
941
942
# Update with newly provided kwargs
943
init_kwargs.update(kwargs)
944
945
# Convert AddedTokens serialized as dict to class instances
946
def convert_added_tokens(obj: Union[AddedToken, Any]):
947
if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
948
obj.pop("__type")
949
return AddedToken(**obj)
950
elif isinstance(obj, (list, tuple)):
951
return list(convert_added_tokens(o) for o in obj)
952
elif isinstance(obj, dict):
953
return {k: convert_added_tokens(v) for k, v in obj.items()}
954
return obj
955
956
init_kwargs = convert_added_tokens(init_kwargs)
957
958
# Set max length if needed
959
if pretrained_model_name_or_path in cls.max_model_input_sizes:
960
# if we're using a pretrained model, ensure the tokenizer
961
# won't index sequences longer than the number of positional embeddings
962
model_max_length = cls.max_model_input_sizes[pretrained_model_name_or_path]
963
if model_max_length is not None and isinstance(model_max_length, (int, float)):
964
init_kwargs["model_max_length"] = min(init_kwargs.get("model_max_length", int(1e30)), model_max_length)
965
966
# Merge resolved_vocab_files arguments in init_kwargs.
967
added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
968
for args_name, file_path in resolved_vocab_files.items():
969
if args_name not in init_kwargs:
970
init_kwargs[args_name] = file_path
971
972
if slow_tokenizer is not None:
973
init_kwargs["__slow_tokenizer"] = slow_tokenizer
974
975
init_kwargs["name_or_path"] = pretrained_model_name_or_path
976
977
# Instantiate tokenizer.
978
try:
979
tokenizer = cls(*init_inputs, **init_kwargs)
980
except OSError:
981
raise OSError(
982
"Unable to load vocabulary from file. "
983
"Please check that the provided vocabulary is accessible and not corrupted."
984
)
985
986
# Save inputs and kwargs for saving and re-loading with ``save_pretrained``
987
# Removed: Now done at the base class level
988
# tokenizer.init_inputs = init_inputs
989
# tokenizer.init_kwargs = init_kwargs
990
991
# If there is a complementary special token map, load it
992
special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
993
if special_tokens_map_file is not None:
994
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
995
special_tokens_map = json.load(special_tokens_map_handle)
996
for key, value in special_tokens_map.items():
997
if isinstance(value, dict):
998
value = AddedToken(**value)
999
elif isinstance(value, list):
1000
value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
1001
setattr(tokenizer, key, value)
1002
1003
# Add supplementary tokens.
1004
special_tokens = tokenizer.all_special_tokens
1005
if added_tokens_file is not None:
1006
with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
1007
added_tok_encoder = json.load(added_tokens_handle)
1008
1009
# Sort added tokens by index
1010
added_tok_encoder_sorted = list(sorted(added_tok_encoder.items(), key=lambda x: x[1]))
1011
1012
for token, index in added_tok_encoder_sorted:
1013
assert index == len(tokenizer), (
1014
f"Non-consecutive added token '{token}' found. "
1015
f"Should have index {len(tokenizer)} but has index {index} in saved vocabulary."
1016
)
1017
tokenizer.add_tokens(token, special_tokens=bool(token in special_tokens))
1018
1019
# Check all our special tokens are registered as "no split" token (we don't cut them) and are in the vocab
1020
added_tokens = tokenizer.sanitize_special_tokens()
1021
1022
return tokenizer
1023
1024
def save_pretrained(
1025
self,
1026
save_directory: Union[str, os.PathLike],
1027
legacy_format: bool = True,
1028
filename_prefix: Optional[str] = None,
1029
) -> Tuple[str]:
1030
if os.path.isfile(save_directory):
1031
return
1032
os.makedirs(save_directory, exist_ok=True)
1033
1034
special_tokens_map_file = os.path.join(
1035
save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
1036
)
1037
tokenizer_config_file = os.path.join(
1038
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
1039
)
1040
1041
tokenizer_config = copy.deepcopy(self.init_kwargs)
1042
if len(self.init_inputs) > 0:
1043
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
1044
for file_id in self.vocab_files_names.keys():
1045
tokenizer_config.pop(file_id, None)
1046
1047
# Sanitize AddedTokens
1048
def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
1049
if isinstance(obj, AddedToken):
1050
out = obj.__getstate__()
1051
if add_type_field:
1052
out["__type"] = "AddedToken"
1053
return out
1054
elif isinstance(obj, (list, tuple)):
1055
return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
1056
elif isinstance(obj, dict):
1057
return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
1058
return obj
1059
1060
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
1061
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
1062
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
1063
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
1064
1065
# Sanitize AddedTokens in special_tokens_map
1066
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
1067
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
1068
f.write(json.dumps(write_dict, ensure_ascii=False))
1069
1070
file_names = (tokenizer_config_file, special_tokens_map_file)
1071
1072
return self._save_pretrained(
1073
save_directory=save_directory,
1074
file_names=file_names,
1075
legacy_format=legacy_format,
1076
filename_prefix=filename_prefix,
1077
)
1078
1079
def _save_pretrained(
1080
self,
1081
save_directory: Union[str, os.PathLike],
1082
file_names: Tuple[str],
1083
legacy_format: bool = True,
1084
filename_prefix: Optional[str] = None,
1085
) -> Tuple[str]:
1086
if not legacy_format:
1087
raise ValueError(
1088
"Only fast tokenizers (instances of PretrainedTokenizerFast) can be saved in non legacy format."
1089
)
1090
1091
save_directory = str(save_directory)
1092
1093
added_tokens_file = os.path.join(
1094
save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
1095
)
1096
added_vocab = self.get_added_vocab()
1097
if added_vocab:
1098
with open(added_tokens_file, "w", encoding="utf-8") as f:
1099
out_str = json.dumps(added_vocab, ensure_ascii=False)
1100
f.write(out_str)
1101
1102
vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
1103
1104
return file_names + vocab_files + (added_tokens_file,)
1105
1106
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
1107
raise NotImplementedError
1108
1109
def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]:
1110
raise NotImplementedError
1111
1112
def encode(
1113
self,
1114
text: Union[TextInput, PreTokenizedInput, EncodedInput],
1115
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
1116
add_special_tokens: bool = True,
1117
padding: Union[bool, str, PaddingStrategy] = False,
1118
truncation: Union[bool, str, TruncationStrategy] = False,
1119
max_length: Optional[int] = None,
1120
stride: int = 0,
1121
return_tensors: Optional[Union[str, TensorType]] = None,
1122
**kwargs
1123
) -> List[int]:
1124
encoded_inputs = self.encode_plus(
1125
text,
1126
text_pair=text_pair,
1127
add_special_tokens=add_special_tokens,
1128
padding=padding,
1129
truncation=truncation,
1130
max_length=max_length,
1131
stride=stride,
1132
return_tensors=return_tensors,
1133
**kwargs,
1134
)
1135
1136
return encoded_inputs["input_ids"]
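# e.g. (illustrative): `tokenizer.encode("hello world")` returns only the list of input ids;
# use `__call__` or `encode_plus` when attention masks / token type ids are also needed.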
1137
1138
def num_special_tokens_to_add(self, pair: bool = False) -> int:
1139
raise NotImplementedError
1140
1141
def _get_padding_truncation_strategies(
1142
self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
1143
):
1144
old_truncation_strategy = kwargs.pop("truncation_strategy", "do_not_truncate")
1145
old_pad_to_max_length = kwargs.pop("pad_to_max_length", False)
1146
1147
# Backward compatibility for previous behavior, maybe we should deprecate it:
1148
# If you only set max_length, it activates truncation for max_length
1149
if max_length is not None and padding is False and truncation is False:
1150
if verbose:
1151
self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
1152
truncation = "longest_first"
1153
1154
# Get padding strategy
1155
if padding is False and old_pad_to_max_length:
1156
if max_length is None:
1157
padding_strategy = PaddingStrategy.LONGEST
1158
else:
1159
padding_strategy = PaddingStrategy.MAX_LENGTH
1160
elif padding is not False:
1161
if padding is True:
1162
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
1163
elif not isinstance(padding, PaddingStrategy):
1164
padding_strategy = PaddingStrategy(padding)
1165
elif isinstance(padding, PaddingStrategy):
1166
padding_strategy = padding
1167
else:
1168
padding_strategy = PaddingStrategy.DO_NOT_PAD
1169
1170
# Get truncation strategy
1171
if truncation is False and old_truncation_strategy != "do_not_truncate":
1172
truncation_strategy = TruncationStrategy(old_truncation_strategy)
1173
elif truncation is not False:
1174
if truncation is True:
1175
truncation_strategy = (
1176
TruncationStrategy.LONGEST_FIRST
1177
) # Default to truncate the longest sequences in pairs of inputs
1178
elif not isinstance(truncation, TruncationStrategy):
1179
truncation_strategy = TruncationStrategy(truncation)
1180
elif isinstance(truncation, TruncationStrategy):
1181
truncation_strategy = truncation
1182
else:
1183
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
1184
1185
# Set max length if needed
1186
if max_length is None:
1187
if padding_strategy == PaddingStrategy.MAX_LENGTH:
1188
if self.model_max_length > LARGE_INTEGER:
1189
if verbose:
1190
self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
1191
padding_strategy = PaddingStrategy.DO_NOT_PAD
1192
else:
1193
max_length = self.model_max_length
1194
1195
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
1196
if self.model_max_length > LARGE_INTEGER:
1197
if verbose:
1198
self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
1199
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
1200
else:
1201
max_length = self.model_max_length
1202
1203
# Test if we have a padding token
1204
if padding_strategy != PaddingStrategy.DO_NOT_PAD and (not self.pad_token or self.pad_token_id < 0):
1205
raise ValueError(
1206
"Asking to pad but the tokenizer does not have a padding token. "
1207
"Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
1208
"or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
1209
)
1210
1211
# Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
1212
if (
1213
truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
1214
and padding_strategy != PaddingStrategy.DO_NOT_PAD
1215
and pad_to_multiple_of is not None
1216
and max_length is not None
1217
and (max_length % pad_to_multiple_of != 0)
1218
):
1219
raise ValueError(
1220
f"Truncation and padding are both activated but "
1221
f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
1222
)
1223
1224
return padding_strategy, truncation_strategy, max_length, kwargs
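# How the user-facing arguments map to strategies (summary of the logic above):
#   padding=True or "longest"          -> PaddingStrategy.LONGEST
#   padding="max_length"               -> PaddingStrategy.MAX_LENGTH (uses max_length or model_max_length)
#   truncation=True or "longest_first" -> TruncationStrategy.LONGEST_FIRST
#   max_length set, padding=False, truncation=False -> truncation defaults to "longest_first" (legacy behavior)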
1225
1226
def __call__(
1227
self,
1228
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
1229
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
1230
add_special_tokens: bool = True,
1231
padding: Union[bool, str, PaddingStrategy] = False,
1232
truncation: Union[bool, str, TruncationStrategy] = False,
1233
max_length: Optional[int] = None,
1234
stride: int = 0,
1235
is_split_into_words: bool = False,
1236
pad_to_multiple_of: Optional[int] = None,
1237
return_tensors: Optional[Union[str, TensorType]] = None,
1238
return_token_type_ids: Optional[bool] = None,
1239
return_attention_mask: Optional[bool] = None,
1240
return_overflowing_tokens: bool = False,
1241
return_special_tokens_mask: bool = False,
1242
return_offsets_mapping: bool = False,
1243
return_length: bool = False,
1244
verbose: bool = True,
1245
**kwargs
1246
) -> BatchEncoding:
1247
# Input type checking for clearer error
1248
assert isinstance(text, str) or (
1249
isinstance(text, (list, tuple))
1250
and (
1251
len(text) == 0
1252
or (
1253
isinstance(text[0], str)
1254
or (isinstance(text[0], (list, tuple)) and (len(text[0]) == 0 or isinstance(text[0][0], str)))
1255
)
1256
)
1257
), (
1258
"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
1259
"or `List[List[str]]` (batch of pretokenized examples)."
1260
)
1261
1262
assert (
1263
text_pair is None
1264
or isinstance(text_pair, str)
1265
or (
1266
isinstance(text_pair, (list, tuple))
1267
and (
1268
len(text_pair) == 0
1269
or (
1270
isinstance(text_pair[0], str)
1271
or (
1272
isinstance(text_pair[0], (list, tuple))
1273
and (len(text_pair[0]) == 0 or isinstance(text_pair[0][0], str))
1274
)
1275
)
1276
)
1277
)
1278
), (
1279
"text_pair input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
1280
"or `List[List[str]]` (batch of pretokenized examples)."
1281
)
1282
1283
is_batched = bool(
1284
(not is_split_into_words and isinstance(text, (list, tuple)))
1285
or (
1286
is_split_into_words and isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
1287
)
1288
)
1289
1290
if is_batched:
1291
batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
1292
return self.batch_encode_plus(
1293
batch_text_or_text_pairs=batch_text_or_text_pairs,
1294
add_special_tokens=add_special_tokens,
1295
padding=padding,
1296
truncation=truncation,
1297
max_length=max_length,
1298
stride=stride,
1299
is_split_into_words=is_split_into_words,
1300
pad_to_multiple_of=pad_to_multiple_of,
1301
return_tensors=return_tensors,
1302
return_token_type_ids=return_token_type_ids,
1303
return_attention_mask=return_attention_mask,
1304
return_overflowing_tokens=return_overflowing_tokens,
1305
return_special_tokens_mask=return_special_tokens_mask,
1306
return_offsets_mapping=return_offsets_mapping,
1307
return_length=return_length,
1308
verbose=verbose,
1309
**kwargs,
1310
)
1311
else:
1312
return self.encode_plus(
1313
text=text,
1314
text_pair=text_pair,
1315
add_special_tokens=add_special_tokens,
1316
padding=padding,
1317
truncation=truncation,
1318
max_length=max_length,
1319
stride=stride,
1320
is_split_into_words=is_split_into_words,
1321
pad_to_multiple_of=pad_to_multiple_of,
1322
return_tensors=return_tensors,
1323
return_token_type_ids=return_token_type_ids,
1324
return_attention_mask=return_attention_mask,
1325
return_overflowing_tokens=return_overflowing_tokens,
1326
return_special_tokens_mask=return_special_tokens_mask,
1327
return_offsets_mapping=return_offsets_mapping,
1328
return_length=return_length,
1329
verbose=verbose,
1330
**kwargs,
1331
)
1332
1333
def encode_plus(
1334
self,
1335
text: Union[TextInput, PreTokenizedInput, EncodedInput],
1336
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
1337
add_special_tokens: bool = True,
1338
padding: Union[bool, str, PaddingStrategy] = False,
1339
truncation: Union[bool, str, TruncationStrategy] = False,
1340
max_length: Optional[int] = None,
1341
stride: int = 0,
1342
is_split_into_words: bool = False,
1343
pad_to_multiple_of: Optional[int] = None,
1344
return_tensors: Optional[Union[str, TensorType]] = None,
1345
return_token_type_ids: Optional[bool] = None,
1346
return_attention_mask: Optional[bool] = None,
1347
return_overflowing_tokens: bool = False,
1348
return_special_tokens_mask: bool = False,
1349
return_offsets_mapping: bool = False,
1350
return_length: bool = False,
1351
verbose: bool = True,
1352
**kwargs
1353
) -> BatchEncoding:
1354
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
1355
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
1356
padding=padding,
1357
truncation=truncation,
1358
max_length=max_length,
1359
pad_to_multiple_of=pad_to_multiple_of,
1360
verbose=verbose,
1361
**kwargs,
1362
)
1363
1364
return self._encode_plus(
1365
text=text,
1366
text_pair=text_pair,
1367
add_special_tokens=add_special_tokens,
1368
padding_strategy=padding_strategy,
1369
truncation_strategy=truncation_strategy,
1370
max_length=max_length,
1371
stride=stride,
1372
is_split_into_words=is_split_into_words,
1373
pad_to_multiple_of=pad_to_multiple_of,
1374
return_tensors=return_tensors,
1375
return_token_type_ids=return_token_type_ids,
1376
return_attention_mask=return_attention_mask,
1377
return_overflowing_tokens=return_overflowing_tokens,
1378
return_special_tokens_mask=return_special_tokens_mask,
1379
return_offsets_mapping=return_offsets_mapping,
1380
return_length=return_length,
1381
verbose=verbose,
1382
**kwargs,
1383
)
1384
1385
def _encode_plus(
1386
self,
1387
text: Union[TextInput, PreTokenizedInput, EncodedInput],
1388
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
1389
add_special_tokens: bool = True,
1390
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
1391
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
1392
max_length: Optional[int] = None,
1393
stride: int = 0,
1394
is_split_into_words: bool = False,
1395
pad_to_multiple_of: Optional[int] = None,
1396
return_tensors: Optional[Union[str, TensorType]] = None,
1397
return_token_type_ids: Optional[bool] = None,
1398
return_attention_mask: Optional[bool] = None,
1399
return_overflowing_tokens: bool = False,
1400
return_special_tokens_mask: bool = False,
1401
return_offsets_mapping: bool = False,
1402
return_length: bool = False,
1403
verbose: bool = True,
1404
**kwargs
1405
) -> BatchEncoding:
1406
raise NotImplementedError
1407
1408
def batch_encode_plus(
1409
self,
1410
batch_text_or_text_pairs: Union[
1411
List[TextInput],
1412
List[TextInputPair],
1413
List[PreTokenizedInput],
1414
List[PreTokenizedInputPair],
1415
List[EncodedInput],
1416
List[EncodedInputPair],
1417
],
1418
add_special_tokens: bool = True,
1419
padding: Union[bool, str, PaddingStrategy] = False,
1420
truncation: Union[bool, str, TruncationStrategy] = False,
1421
max_length: Optional[int] = None,
1422
stride: int = 0,
1423
is_split_into_words: bool = False,
1424
pad_to_multiple_of: Optional[int] = None,
1425
return_tensors: Optional[Union[str, TensorType]] = None,
1426
return_token_type_ids: Optional[bool] = None,
1427
return_attention_mask: Optional[bool] = None,
1428
return_overflowing_tokens: bool = False,
1429
return_special_tokens_mask: bool = False,
1430
return_offsets_mapping: bool = False,
1431
return_length: bool = False,
1432
verbose: bool = True,
1433
**kwargs
1434
) -> BatchEncoding:
1435
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
1436
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
1437
padding=padding,
1438
truncation=truncation,
1439
max_length=max_length,
1440
pad_to_multiple_of=pad_to_multiple_of,
1441
verbose=verbose,
1442
**kwargs,
1443
)
1444
1445
return self._batch_encode_plus(
1446
batch_text_or_text_pairs=batch_text_or_text_pairs,
1447
add_special_tokens=add_special_tokens,
1448
padding_strategy=padding_strategy,
1449
truncation_strategy=truncation_strategy,
1450
max_length=max_length,
1451
stride=stride,
1452
is_split_into_words=is_split_into_words,
1453
pad_to_multiple_of=pad_to_multiple_of,
1454
return_tensors=return_tensors,
1455
return_token_type_ids=return_token_type_ids,
1456
return_attention_mask=return_attention_mask,
1457
return_overflowing_tokens=return_overflowing_tokens,
1458
return_special_tokens_mask=return_special_tokens_mask,
1459
return_offsets_mapping=return_offsets_mapping,
1460
return_length=return_length,
1461
verbose=verbose,
1462
**kwargs,
1463
)
1464
1465
def _batch_encode_plus(
1466
self,
1467
batch_text_or_text_pairs: Union[
1468
List[TextInput],
1469
List[TextInputPair],
1470
List[PreTokenizedInput],
1471
List[PreTokenizedInputPair],
1472
List[EncodedInput],
1473
List[EncodedInputPair],
1474
],
1475
add_special_tokens: bool = True,
1476
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
1477
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
1478
max_length: Optional[int] = None,
1479
stride: int = 0,
1480
is_split_into_words: bool = False,
1481
pad_to_multiple_of: Optional[int] = None,
1482
return_tensors: Optional[Union[str, TensorType]] = None,
1483
return_token_type_ids: Optional[bool] = None,
1484
return_attention_mask: Optional[bool] = None,
1485
return_overflowing_tokens: bool = False,
1486
return_special_tokens_mask: bool = False,
1487
return_offsets_mapping: bool = False,
1488
return_length: bool = False,
1489
verbose: bool = True,
1490
**kwargs
1491
) -> BatchEncoding:
1492
raise NotImplementedError
1493
1494
def pad(
1495
self,
1496
encoded_inputs: Union[
1497
BatchEncoding,
1498
List[BatchEncoding],
1499
Dict[str, EncodedInput],
1500
Dict[str, List[EncodedInput]],
1501
List[Dict[str, EncodedInput]],
1502
],
1503
padding: Union[bool, str, PaddingStrategy] = True,
1504
max_length: Optional[int] = None,
1505
pad_to_multiple_of: Optional[int] = None,
1506
return_attention_mask: Optional[bool] = None,
1507
return_tensors: Optional[Union[str, TensorType]] = None,
1508
verbose: bool = True,
1509
) -> BatchEncoding:
1510
# If we have a list of dicts, let's convert it in a dict of lists
1511
# We do this to allow using this method as a collate_fn function in PyTorch Dataloader
1512
if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], (dict, BatchEncoding)):
1513
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
1514
1515
# The model's main input name, usually `input_ids`, has to be passed for padding
1516
if self.model_input_names[0] not in encoded_inputs:
1517
raise ValueError(
1518
"You should supply an encoding or a list of encodings to this method"
1519
f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
1520
)
1521
1522
required_input = encoded_inputs[self.model_input_names[0]]
1523
1524
if not required_input:
1525
if return_attention_mask:
1526
encoded_inputs["attention_mask"] = []
1527
return encoded_inputs
1528
1529
# If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
1530
# and rebuild them afterwards if no return_tensors is specified
1531
# Note that we lose the specific device the tensor may be on for PyTorch
1532
1533
first_element = required_input[0]
1534
if isinstance(first_element, (list, tuple)):
1535
# first_element might be an empty list/tuple in some edge cases, so we grab the first non-empty element.
1536
index = 0
1537
while len(required_input[index]) == 0:
1538
index += 1
1539
if index < len(required_input):
1540
first_element = required_input[index][0]
1541
# At this point, if `first_element` is still a list/tuple, it's an empty one, so there is nothing to do.
1542
if not isinstance(first_element, (int, list, tuple)):
1543
if is_tf_available() and _is_tensorflow(first_element):
1544
return_tensors = "tf" if return_tensors is None else return_tensors
1545
elif is_torch_available() and _is_torch(first_element):
1546
return_tensors = "pt" if return_tensors is None else return_tensors
1547
elif isinstance(first_element, np.ndarray):
1548
return_tensors = "np" if return_tensors is None else return_tensors
1549
else:
1550
raise ValueError(
1551
f"type of {first_element} unknown: {type(first_element)}. "
1552
f"Should be one of a python, numpy, pytorch or tensorflow object."
1553
)
1554
1555
for key, value in encoded_inputs.items():
1556
encoded_inputs[key] = to_py_obj(value)
1557
1558
# Convert padding_strategy in PaddingStrategy
1559
padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
1560
padding=padding, max_length=max_length, verbose=verbose
1561
)
1562
1563
required_input = encoded_inputs[self.model_input_names[0]]
1564
if required_input and not isinstance(required_input[0], (list, tuple)):
1565
encoded_inputs = self._pad(
1566
encoded_inputs,
1567
max_length=max_length,
1568
padding_strategy=padding_strategy,
1569
pad_to_multiple_of=pad_to_multiple_of,
1570
return_attention_mask=return_attention_mask,
1571
)
1572
return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
1573
1574
batch_size = len(required_input)
1575
assert all(
1576
len(v) == batch_size for v in encoded_inputs.values()
1577
), "Some items in the output dictionary have a different batch size than others."
1578
1579
if padding_strategy == PaddingStrategy.LONGEST:
1580
max_length = max(len(inputs) for inputs in required_input)
1581
padding_strategy = PaddingStrategy.MAX_LENGTH
1582
1583
batch_outputs = {}
1584
for i in range(batch_size):
1585
inputs = dict((k, v[i]) for k, v in encoded_inputs.items())
1586
outputs = self._pad(
1587
inputs,
1588
max_length=max_length,
1589
padding_strategy=padding_strategy,
1590
pad_to_multiple_of=pad_to_multiple_of,
1591
return_attention_mask=return_attention_mask,
1592
)
1593
1594
for key, value in outputs.items():
1595
if key not in batch_outputs:
1596
batch_outputs[key] = []
1597
batch_outputs[key].append(value)
1598
1599
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
1600
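# Illustrative usage sketch (not part of the original file, shown as comments to keep the
# module importable): because `pad` accepts a list of feature dicts, it can be used directly
# as a collate_fn for a PyTorch DataLoader. Assumes the `from_pretrained` loader defined
# earlier in this file and the standard bert-base-uncased vocab.
#
#   tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#   features = [{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 7592, 2088, 999, 102]}]
#   batch = tokenizer.pad(features, padding="longest", return_tensors="pt")
#   # batch["input_ids"] has shape (2, 5); the shorter row is right-padded with pad_token_id (0)
#   # and batch["attention_mask"] marks real tokens with 1 and padding with 0.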
1601
def create_token_type_ids_from_sequences(
1602
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
1603
) -> List[int]:
1604
if token_ids_1 is None:
1605
return len(token_ids_0) * [0]
1606
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
1607
1608
def build_inputs_with_special_tokens(
1609
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
1610
) -> List[int]:
1611
if token_ids_1 is None:
1612
return token_ids_0
1613
return token_ids_0 + token_ids_1
1614
1615
def prepare_for_model(
1616
self,
1617
ids: List[int],
1618
pair_ids: Optional[List[int]] = None,
1619
add_special_tokens: bool = True,
1620
padding: Union[bool, str, PaddingStrategy] = False,
1621
truncation: Union[bool, str, TruncationStrategy] = False,
1622
max_length: Optional[int] = None,
1623
stride: int = 0,
1624
pad_to_multiple_of: Optional[int] = None,
1625
return_tensors: Optional[Union[str, TensorType]] = None,
1626
return_token_type_ids: Optional[bool] = None,
1627
return_attention_mask: Optional[bool] = None,
1628
return_overflowing_tokens: bool = False,
1629
return_special_tokens_mask: bool = False,
1630
return_offsets_mapping: bool = False,
1631
return_length: bool = False,
1632
verbose: bool = True,
1633
prepend_batch_axis: bool = False,
1634
**kwargs
1635
) -> BatchEncoding:
1636
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
1637
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
1638
padding=padding,
1639
truncation=truncation,
1640
max_length=max_length,
1641
pad_to_multiple_of=pad_to_multiple_of,
1642
verbose=verbose,
1643
**kwargs,
1644
)
1645
1646
pair = bool(pair_ids is not None)
1647
len_ids = len(ids)
1648
len_pair_ids = len(pair_ids) if pair else 0
1649
1650
if return_token_type_ids and not add_special_tokens:
1651
raise ValueError(
1652
"Asking to return token_type_ids while setting add_special_tokens to False "
1653
"results in an undefined behavior. Please set add_special_tokens to True or "
1654
"set return_token_type_ids to None."
1655
)
1656
1657
# Load from model defaults
1658
if return_token_type_ids is None:
1659
return_token_type_ids = "token_type_ids" in self.model_input_names
1660
if return_attention_mask is None:
1661
return_attention_mask = "attention_mask" in self.model_input_names
1662
1663
encoded_inputs = {}
1664
1665
# Compute the total size of the returned encodings
1666
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
1667
1668
# Truncation: Handle max sequence length
1669
overflowing_tokens = []
1670
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
1671
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
1672
ids,
1673
pair_ids=pair_ids,
1674
num_tokens_to_remove=total_len - max_length,
1675
truncation_strategy=truncation_strategy,
1676
stride=stride,
1677
)
1678
1679
if return_overflowing_tokens:
1680
encoded_inputs["overflowing_tokens"] = overflowing_tokens
1681
encoded_inputs["num_truncated_tokens"] = total_len - max_length
1682
1683
# Add special tokens
1684
if add_special_tokens:
1685
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
1686
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
1687
else:
1688
sequence = ids + pair_ids if pair else ids
1689
token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
1690
1691
# Build output dictionary
1692
encoded_inputs["input_ids"] = sequence
1693
if return_token_type_ids:
1694
encoded_inputs["token_type_ids"] = token_type_ids
1695
if return_special_tokens_mask:
1696
if add_special_tokens:
1697
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
1698
else:
1699
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
1700
1701
# Check lengths
1702
self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
1703
1704
# Padding
1705
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
1706
encoded_inputs = self.pad(
1707
encoded_inputs,
1708
max_length=max_length,
1709
padding=padding_strategy.value,
1710
pad_to_multiple_of=pad_to_multiple_of,
1711
return_attention_mask=return_attention_mask,
1712
)
1713
1714
if return_length:
1715
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
1716
1717
batch_outputs = BatchEncoding(
1718
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
1719
)
1720
1721
return batch_outputs
1722
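# Illustrative sketch (assumes a bert-base-uncased BertTokenizer instance named `tokenizer`):
# `prepare_for_model` takes ids that were already converted with `convert_tokens_to_ids` and
# adds special tokens, truncation, padding and the auxiliary masks in a single pass.
#
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("hello world"))  # [7592, 2088]
#   enc = tokenizer.prepare_for_model(ids, add_special_tokens=True)
#   # enc["input_ids"]      == [101, 7592, 2088, 102]   ([CLS] hello world [SEP])
#   # enc["token_type_ids"] == [0, 0, 0, 0], plus an attention_mask of ones.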
1723
def truncate_sequences(
1724
self,
1725
ids: List[int],
1726
pair_ids: Optional[List[int]] = None,
1727
num_tokens_to_remove: int = 0,
1728
truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
1729
stride: int = 0,
1730
) -> Tuple[List[int], List[int], List[int]]:
1731
if num_tokens_to_remove <= 0:
1732
return ids, pair_ids, []
1733
1734
if not isinstance(truncation_strategy, TruncationStrategy):
1735
truncation_strategy = TruncationStrategy(truncation_strategy)
1736
1737
overflowing_tokens = []
1738
if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
1739
for _ in range(num_tokens_to_remove):
1740
if pair_ids is None or len(ids) > len(pair_ids):
1741
if not overflowing_tokens:
1742
window_len = min(len(ids), stride + 1)
1743
else:
1744
window_len = 1
1745
overflowing_tokens.extend(ids[-window_len:])
1746
ids = ids[:-1]
1747
else:
1748
if not overflowing_tokens:
1749
window_len = min(len(pair_ids), stride + 1)
1750
else:
1751
window_len = 1
1752
overflowing_tokens.extend(pair_ids[-window_len:])
1753
pair_ids = pair_ids[:-1]
1754
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
1755
if len(ids) > num_tokens_to_remove:
1756
window_len = min(len(ids), stride + num_tokens_to_remove)
1757
overflowing_tokens = ids[-window_len:]
1758
ids = ids[:-num_tokens_to_remove]
1759
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
1760
if len(pair_ids) > num_tokens_to_remove:
1761
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
1762
overflowing_tokens = pair_ids[-window_len:]
1763
pair_ids = pair_ids[:-num_tokens_to_remove]
1764
1765
return (ids, pair_ids, overflowing_tokens)
1766
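# Illustrative sketch of the strategies above. With "longest_first" one token is removed at a
# time from whichever sequence is currently longer, and the removed tokens are collected as
# overflow (these values assume stride=0):
#
#   ids, pair_ids, overflow = tokenizer.truncate_sequences(
#       [1, 2, 3, 4, 5], pair_ids=[6, 7], num_tokens_to_remove=2, truncation_strategy="longest_first"
#   )
#   # ids == [1, 2, 3], pair_ids == [6, 7], overflow == [5, 4]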
1767
def _pad(
1768
self,
1769
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
1770
max_length: Optional[int] = None,
1771
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
1772
pad_to_multiple_of: Optional[int] = None,
1773
return_attention_mask: Optional[bool] = None,
1774
) -> dict:
1775
# Load from model defaults
1776
if return_attention_mask is None:
1777
return_attention_mask = "attention_mask" in self.model_input_names
1778
1779
required_input = encoded_inputs[self.model_input_names[0]]
1780
1781
if padding_strategy == PaddingStrategy.LONGEST:
1782
max_length = len(required_input)
1783
1784
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
1785
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
1786
1787
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
1788
1789
if needs_to_be_padded:
1790
difference = max_length - len(required_input)
1791
if self.padding_side == "right":
1792
if return_attention_mask:
1793
encoded_inputs["attention_mask"] = [1] * len(required_input) + [0] * difference
1794
if "token_type_ids" in encoded_inputs:
1795
encoded_inputs["token_type_ids"] = (
1796
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
1797
)
1798
if "special_tokens_mask" in encoded_inputs:
1799
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
1800
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
1801
elif self.padding_side == "left":
1802
if return_attention_mask:
1803
encoded_inputs["attention_mask"] = [0] * difference + [1] * len(required_input)
1804
if "token_type_ids" in encoded_inputs:
1805
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
1806
"token_type_ids"
1807
]
1808
if "special_tokens_mask" in encoded_inputs:
1809
encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
1810
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
1811
else:
1812
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
1813
elif return_attention_mask and "attention_mask" not in encoded_inputs:
1814
encoded_inputs["attention_mask"] = [1] * len(required_input)
1815
1816
return encoded_inputs
1817
1818
def convert_tokens_to_string(self, tokens: List[str]) -> str:
1819
raise NotImplementedError
1820
1821
def batch_decode(
1822
self,
1823
sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
1824
skip_special_tokens: bool = False,
1825
clean_up_tokenization_spaces: bool = True,
1826
**kwargs
1827
) -> List[str]:
1828
return [
1829
self.decode(
1830
seq,
1831
skip_special_tokens=skip_special_tokens,
1832
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
1833
**kwargs,
1834
)
1835
for seq in sequences
1836
]
1837
1838
def decode(
1839
self,
1840
token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
1841
skip_special_tokens: bool = False,
1842
clean_up_tokenization_spaces: bool = True,
1843
**kwargs
1844
) -> str:
1845
# Convert inputs to python lists
1846
token_ids = to_py_obj(token_ids)
1847
1848
return self._decode(
1849
token_ids=token_ids,
1850
skip_special_tokens=skip_special_tokens,
1851
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
1852
**kwargs,
1853
)
1854
1855
def _decode(
1856
self,
1857
token_ids: Union[int, List[int]],
1858
skip_special_tokens: bool = False,
1859
clean_up_tokenization_spaces: bool = True,
1860
**kwargs
1861
) -> str:
1862
raise NotImplementedError
1863
1864
def get_special_tokens_mask(
1865
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
1866
) -> List[int]:
1867
assert already_has_special_tokens and token_ids_1 is None, (
1868
"You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
1869
"Please use a slow (full python) tokenizer to activate this argument."
1870
"Or set `return_special_tokens_mask=True` when calling the encoding method "
1871
"to get the special tokens mask in any tokenizer. "
1872
)
1873
1874
all_special_ids = self.all_special_ids # cache the property
1875
1876
special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
1877
1878
return special_tokens_mask
1879
1880
@staticmethod
1881
def clean_up_tokenization(out_string: str) -> str:
1882
"""
1883
Clean up simple English tokenization artifacts such as spaces before punctuation and contracted forms.
1884
Args:
1885
out_string (:obj:`str`): The text to clean up.
1886
Returns:
1887
:obj:`str`: The cleaned-up string.
1888
"""
1889
out_string = (
1890
out_string.replace(" .", ".")
1891
.replace(" ?", "?")
1892
.replace(" !", "!")
1893
.replace(" ,", ",")
1894
.replace(" ' ", "'")
1895
.replace(" n't", "n't")
1896
.replace(" 'm", "'m")
1897
.replace(" 's", "'s")
1898
.replace(" 've", "'ve")
1899
.replace(" 're", "'re")
1900
)
1901
return out_string
1902
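# Illustrative sketch: the replacements above undo the extra spaces that a naive
# whitespace-joined detokenization leaves around punctuation and English contractions.
#
#   PreTrainedTokenizerBase.clean_up_tokenization("do n't stop , please !")
#   # -> "don't stop, please!"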
1903
def _eventual_warn_about_too_long_sequence(self, ids: List[int], max_length: Optional[int], verbose: bool):
1904
if max_length is None and len(ids) > self.model_max_length and verbose:
1905
self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
1906
1907
@contextmanager
1908
def as_target_tokenizer(self):
1909
yield
1910
1911
def prepare_seq2seq_batch(
1912
self,
1913
src_texts: List[str],
1914
tgt_texts: Optional[List[str]] = None,
1915
max_length: Optional[int] = None,
1916
max_target_length: Optional[int] = None,
1917
padding: str = "longest",
1918
return_tensors: Optional[str] = None,
1919
truncation: bool = True,
1920
**kwargs,
1921
) -> BatchEncoding:
1922
# mBART-specific kwargs that should be ignored by other models.
1923
kwargs.pop("src_lang", None)
1924
kwargs.pop("tgt_lang", None)
1925
if max_length is None:
1926
max_length = self.model_max_length
1927
model_inputs = self(
1928
src_texts,
1929
add_special_tokens=True,
1930
return_tensors=return_tensors,
1931
max_length=max_length,
1932
padding=padding,
1933
truncation=truncation,
1934
**kwargs,
1935
)
1936
if tgt_texts is None:
1937
return model_inputs
1938
# Process tgt_texts
1939
if max_target_length is None:
1940
max_target_length = max_length
1941
with self.as_target_tokenizer():
1942
labels = self(
1943
tgt_texts,
1944
add_special_tokens=True,
1945
return_tensors=return_tensors,
1946
padding=padding,
1947
max_length=max_target_length,
1948
truncation=truncation,
1949
**kwargs,
1950
)
1951
model_inputs["labels"] = labels["input_ids"]
1952
return model_inputs
1953
1954
1955
class PreTrainedTokenizer(PreTrainedTokenizerBase):
1956
def __init__(self, **kwargs):
1957
super().__init__(**kwargs)
1958
# Added tokens - We store this for both slow and fast tokenizers
1959
# until the serialization of Fast tokenizers is updated
1960
self.added_tokens_encoder: Dict[str, int] = {}
1961
self.added_tokens_decoder: Dict[int, str] = {}
1962
self.unique_no_split_tokens: List[str] = []
1963
1964
@property
1965
def is_fast(self) -> bool:
1966
return False
1967
1968
@property
1969
def vocab_size(self) -> int:
1970
"""
1971
:obj:`int`: Size of the base vocabulary (without the added tokens).
1972
"""
1973
raise NotImplementedError
1974
1975
def get_added_vocab(self) -> Dict[str, int]:
1976
"""
1977
Returns the added tokens in the vocabulary as a dictionary of token to index.
1978
Returns:
1979
:obj:`Dict[str, int]`: The added tokens.
1980
"""
1981
return self.added_tokens_encoder
1982
1983
def __len__(self):
1984
"""
1985
Size of the full vocabulary with the added tokens.
1986
"""
1987
return self.vocab_size + len(self.added_tokens_encoder)
1988
1989
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
1990
"""
1991
Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
1992
it with indices starting from the length of the current vocabulary.
1993
Args:
1994
new_tokens (:obj:`List[str]`or :obj:`List[tokenizers.AddedToken]`):
1995
Token(s) to add to the vocabulary. A token is only added if it's not already in the vocabulary (tested by
1996
checking whether the tokenizer assigns the index of the ``unk_token`` to it).
1997
special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
1998
Whether or not the tokens should be added as special tokens.
1999
Returns:
2000
:obj:`int`: The number of tokens actually added to the vocabulary.
2001
Examples::
2002
# Let's see how to increase the vocabulary of Bert model and tokenizer
2003
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
2004
model = BertModel.from_pretrained('bert-base-uncased')
2005
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
2006
print('We have added', num_added_toks, 'tokens')
2007
# Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
2008
model.resize_token_embeddings(len(tokenizer))
2009
"""
2010
new_tokens = [str(tok) for tok in new_tokens]
2011
2012
tokens_to_add = []
2013
for token in new_tokens:
2014
assert isinstance(token, str)
2015
if not special_tokens and hasattr(self, "do_lower_case") and self.do_lower_case:
2016
token = token.lower()
2017
if (
2018
token != self.unk_token
2019
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
2020
and token not in tokens_to_add
2021
):
2022
tokens_to_add.append(token)
2023
2024
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
2025
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
2026
self.added_tokens_encoder.update(added_tok_encoder)
2027
self.added_tokens_decoder.update(added_tok_decoder)
2028
2029
# Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
2030
if special_tokens:
2031
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(new_tokens)))
2032
else:
2033
# Or on the newly added tokens
2034
self.unique_no_split_tokens = sorted(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
2035
2036
return len(tokens_to_add)
2037
2038
def num_special_tokens_to_add(self, pair: bool = False) -> int:
2039
"""
2040
Returns the number of added tokens when encoding a sequence with special tokens.
2041
.. note::
2042
This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not
2043
put this inside your training loop.
2044
Args:
2045
pair (:obj:`bool`, `optional`, defaults to :obj:`False`):
2046
Whether the number of added tokens should be computed in the case of a sequence pair or a single
2047
sequence.
2048
Returns:
2049
:obj:`int`: Number of special tokens added to sequences.
2050
"""
2051
token_ids_0 = []
2052
token_ids_1 = []
2053
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
2054
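# Illustrative sketch (assumes the BertTokenizer defined below): BERT adds [CLS] and [SEP]
# around a single sequence and an extra [SEP] for a pair, so
#
#   tokenizer.num_special_tokens_to_add(pair=False)  # -> 2
#   tokenizer.num_special_tokens_to_add(pair=True)   # -> 3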
2055
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
2056
"""
2057
Converts a string into a sequence of tokens, using the tokenizer.
2058
Splits into words for word-based vocabularies or into sub-words for sub-word-based vocabularies
2059
(BPE/SentencePiece/WordPiece). Takes care of added tokens.
2060
Args:
2061
text (:obj:`str`):
2062
The sequence to be encoded.
2063
**kwargs (additional keyword arguments):
2064
Passed along to the model-specific ``prepare_for_tokenization`` preprocessing method.
2065
Returns:
2066
:obj:`List[str]`: The list of tokens.
2067
"""
2068
# Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
2069
all_special_tokens_extended = dict(
2070
(str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
2071
)
2072
2073
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
2074
2075
# TODO: should this be in the base class?
2076
if hasattr(self, "do_lower_case") and self.do_lower_case:
2077
# convert non-special tokens to lowercase
2078
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
2079
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
2080
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
2081
2082
def split_on_token(tok, text):
2083
result = []
2084
tok_extended = all_special_tokens_extended.get(tok, None)
2085
split_text = text.split(tok)
2086
full_word = ""
2087
for i, sub_text in enumerate(split_text):
2088
# AddedToken can control whitespace stripping around them.
2089
# We use them for GPT2 and Roberta to have different behavior depending on the special token
2090
# Cf. https://github.com/huggingface/transformers/pull/2778
2091
# and https://github.com/huggingface/transformers/issues/3788
2092
if isinstance(tok_extended, AddedToken):
2093
if tok_extended.single_word:
2094
# Try to avoid splitting on token
2095
if (
2096
i < len(split_text) - 1
2097
and not _is_end_of_word(sub_text)
2098
and not _is_start_of_word(split_text[i + 1])
2099
):
2100
# Don't extract the special token
2101
full_word += sub_text + tok
2102
elif full_word:
2103
full_word += sub_text
2104
result.append(full_word)
2105
full_word = ""
2106
continue
2107
# Strip white spaces on the right
2108
if tok_extended.rstrip and i > 0:
2109
# A bit counter-intuitive but we strip the left of the string
2110
# since tok_extended.rstrip means the special token is eating all white spaces on its right
2111
sub_text = sub_text.lstrip()
2112
# Strip white spaces on the left
2113
if tok_extended.lstrip and i < len(split_text) - 1:
2114
sub_text = sub_text.rstrip() # Opposite here
2115
else:
2116
# We strip left and right by default
2117
if i < len(split_text) - 1:
2118
sub_text = sub_text.rstrip()
2119
if i > 0:
2120
sub_text = sub_text.lstrip()
2121
2122
if i == 0 and not sub_text:
2123
result.append(tok)
2124
elif i == len(split_text) - 1:
2125
if sub_text:
2126
result.append(sub_text)
2127
else:
2128
pass
2129
else:
2130
if sub_text:
2131
result.append(sub_text)
2132
result.append(tok)
2133
return result
2134
2135
def split_on_tokens(tok_list, text):
2136
if not text.strip():
2137
return []
2138
if not tok_list:
2139
return self._tokenize(text)
2140
2141
tokenized_text = []
2142
text_list = [text]
2143
for tok in tok_list:
2144
tokenized_text = []
2145
for sub_text in text_list:
2146
if sub_text not in self.unique_no_split_tokens:
2147
tokenized_text.extend(split_on_token(tok, sub_text))
2148
else:
2149
tokenized_text.append(sub_text)
2150
text_list = tokenized_text
2151
2152
return list(
2153
itertools.chain.from_iterable(
2154
(
2155
self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
2156
for token in tokenized_text
2157
)
2158
)
2159
)
2160
2161
no_split_token = self.unique_no_split_tokens
2162
tokenized_text = split_on_tokens(no_split_token, text)
2163
return tokenized_text
2164
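# Illustrative sketch: `tokenize` first splits the text around tokens registered in
# `unique_no_split_tokens` and only runs the model-specific `_tokenize` on the remaining
# pieces. Assuming "new_tok" was registered through the public `add_tokens` wrapper defined
# earlier in the class hierarchy and a bert-base-uncased vocab:
#
#   tokenizer.tokenize("hello new_tok world")
#   # -> ['hello', 'new_tok', 'world']   (the added token is never split by WordPiece)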
2165
def _tokenize(self, text, **kwargs):
2166
"""
2167
Converts a string into a sequence of tokens (strings), using the tokenizer. Splits into words for word-based
2168
vocabularies or into sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).
2169
Do NOT take care of added tokens.
2170
"""
2171
raise NotImplementedError
2172
2173
def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
2174
"""
2175
Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
2176
vocabulary.
2177
Args:
2178
tokens (:obj:`str` or :obj:`List[str]`): One or several token(s) to convert to token id(s).
2179
Returns:
2180
:obj:`int` or :obj:`List[int]`: The token id or list of token ids.
2181
"""
2182
if tokens is None:
2183
return None
2184
2185
if isinstance(tokens, str):
2186
return self._convert_token_to_id_with_added_voc(tokens)
2187
2188
ids = []
2189
for token in tokens:
2190
ids.append(self._convert_token_to_id_with_added_voc(token))
2191
return ids
2192
2193
def _convert_token_to_id_with_added_voc(self, token):
2194
if token is None:
2195
return None
2196
2197
if token in self.added_tokens_encoder:
2198
return self.added_tokens_encoder[token]
2199
return self._convert_token_to_id(token)
2200
2201
def _convert_token_to_id(self, token):
2202
raise NotImplementedError
2203
2204
def _encode_plus(
2205
self,
2206
text: Union[TextInput, PreTokenizedInput, EncodedInput],
2207
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
2208
add_special_tokens: bool = True,
2209
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
2210
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
2211
max_length: Optional[int] = None,
2212
stride: int = 0,
2213
is_split_into_words: bool = False,
2214
pad_to_multiple_of: Optional[int] = None,
2215
return_tensors: Optional[Union[str, TensorType]] = None,
2216
return_token_type_ids: Optional[bool] = None,
2217
return_attention_mask: Optional[bool] = None,
2218
return_overflowing_tokens: bool = False,
2219
return_special_tokens_mask: bool = False,
2220
return_offsets_mapping: bool = False,
2221
return_length: bool = False,
2222
verbose: bool = True,
2223
**kwargs
2224
) -> BatchEncoding:
2225
def get_input_ids(text):
2226
if isinstance(text, str):
2227
tokens = self.tokenize(text, **kwargs)
2228
return self.convert_tokens_to_ids(tokens)
2229
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
2230
if is_split_into_words:
2231
tokens = list(
2232
itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
2233
)
2234
return self.convert_tokens_to_ids(tokens)
2235
else:
2236
return self.convert_tokens_to_ids(text)
2237
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
2238
return text
2239
else:
2240
if is_split_into_words:
2241
raise ValueError(
2242
f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_split_into_words=True`."
2243
)
2244
else:
2245
raise ValueError(
2246
f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
2247
)
2248
2249
if return_offsets_mapping:
2250
raise NotImplementedError(
2251
"return_offset_mapping is not available when using Python tokenizers."
2252
"To use this feature, change your tokenizer to one deriving from "
2253
"transformers.PreTrainedTokenizerFast."
2254
"More information on available tokenizers at "
2255
"https://github.com/huggingface/transformers/pull/2674"
2256
)
2257
2258
first_ids = get_input_ids(text)
2259
second_ids = get_input_ids(text_pair) if text_pair is not None else None
2260
2261
return self.prepare_for_model(
2262
first_ids,
2263
pair_ids=second_ids,
2264
add_special_tokens=add_special_tokens,
2265
padding=padding_strategy.value,
2266
truncation=truncation_strategy.value,
2267
max_length=max_length,
2268
stride=stride,
2269
pad_to_multiple_of=pad_to_multiple_of,
2270
return_tensors=return_tensors,
2271
prepend_batch_axis=True,
2272
return_attention_mask=return_attention_mask,
2273
return_token_type_ids=return_token_type_ids,
2274
return_overflowing_tokens=return_overflowing_tokens,
2275
return_special_tokens_mask=return_special_tokens_mask,
2276
return_length=return_length,
2277
verbose=verbose,
2278
)
2279
2280
def _batch_encode_plus(
2281
self,
2282
batch_text_or_text_pairs: Union[
2283
List[TextInput],
2284
List[TextInputPair],
2285
List[PreTokenizedInput],
2286
List[PreTokenizedInputPair],
2287
List[EncodedInput],
2288
List[EncodedInputPair],
2289
],
2290
add_special_tokens: bool = True,
2291
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
2292
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
2293
max_length: Optional[int] = None,
2294
stride: int = 0,
2295
is_split_into_words: bool = False,
2296
pad_to_multiple_of: Optional[int] = None,
2297
return_tensors: Optional[Union[str, TensorType]] = None,
2298
return_token_type_ids: Optional[bool] = None,
2299
return_attention_mask: Optional[bool] = None,
2300
return_overflowing_tokens: bool = False,
2301
return_special_tokens_mask: bool = False,
2302
return_offsets_mapping: bool = False,
2303
return_length: bool = False,
2304
verbose: bool = True,
2305
**kwargs
2306
) -> BatchEncoding:
2307
def get_input_ids(text):
2308
if isinstance(text, str):
2309
tokens = self.tokenize(text, **kwargs)
2310
return self.convert_tokens_to_ids(tokens)
2311
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
2312
if is_split_into_words:
2313
tokens = list(
2314
itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
2315
)
2316
return self.convert_tokens_to_ids(tokens)
2317
else:
2318
return self.convert_tokens_to_ids(text)
2319
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
2320
return text
2321
else:
2322
raise ValueError(
2323
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
2324
)
2325
2326
if return_offsets_mapping:
2327
raise NotImplementedError(
2328
"return_offset_mapping is not available when using Python tokenizers."
2329
"To use this feature, change your tokenizer to one deriving from "
2330
"transformers.PreTrainedTokenizerFast."
2331
)
2332
2333
input_ids = []
2334
for ids_or_pair_ids in batch_text_or_text_pairs:
2335
if not isinstance(ids_or_pair_ids, (list, tuple)):
2336
ids, pair_ids = ids_or_pair_ids, None
2337
elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
2338
ids, pair_ids = ids_or_pair_ids, None
2339
else:
2340
ids, pair_ids = ids_or_pair_ids
2341
2342
first_ids = get_input_ids(ids)
2343
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
2344
input_ids.append((first_ids, second_ids))
2345
2346
batch_outputs = self._batch_prepare_for_model(
2347
input_ids,
2348
add_special_tokens=add_special_tokens,
2349
padding_strategy=padding_strategy,
2350
truncation_strategy=truncation_strategy,
2351
max_length=max_length,
2352
stride=stride,
2353
pad_to_multiple_of=pad_to_multiple_of,
2354
return_attention_mask=return_attention_mask,
2355
return_token_type_ids=return_token_type_ids,
2356
return_overflowing_tokens=return_overflowing_tokens,
2357
return_special_tokens_mask=return_special_tokens_mask,
2358
return_length=return_length,
2359
return_tensors=return_tensors,
2360
verbose=verbose,
2361
)
2362
2363
return BatchEncoding(batch_outputs)
2364
2365
def _batch_prepare_for_model(
2366
self,
2367
batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
2368
add_special_tokens: bool = True,
2369
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
2370
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
2371
max_length: Optional[int] = None,
2372
stride: int = 0,
2373
pad_to_multiple_of: Optional[int] = None,
2374
return_tensors: Optional[str] = None,
2375
return_token_type_ids: Optional[bool] = None,
2376
return_attention_mask: Optional[bool] = None,
2377
return_overflowing_tokens: bool = False,
2378
return_special_tokens_mask: bool = False,
2379
return_length: bool = False,
2380
verbose: bool = True,
2381
) -> BatchEncoding:
2382
"""
2383
Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
2384
adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
2385
manages a moving window (with a user-defined stride) for overflowing tokens.
2386
Args:
2387
batch_ids_pairs: list of tokenized input ids or input ids pairs
2388
"""
2389
2390
batch_outputs = {}
2391
for first_ids, second_ids in batch_ids_pairs:
2392
outputs = self.prepare_for_model(
2393
first_ids,
2394
second_ids,
2395
add_special_tokens=add_special_tokens,
2396
padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
2397
truncation=truncation_strategy.value,
2398
max_length=max_length,
2399
stride=stride,
2400
pad_to_multiple_of=None, # we pad in batch afterward
2401
return_attention_mask=False, # we pad in batch afterward
2402
return_token_type_ids=return_token_type_ids,
2403
return_overflowing_tokens=return_overflowing_tokens,
2404
return_special_tokens_mask=return_special_tokens_mask,
2405
return_length=return_length,
2406
return_tensors=None, # We convert the whole batch to tensors at the end
2407
prepend_batch_axis=False,
2408
verbose=verbose,
2409
)
2410
2411
for key, value in outputs.items():
2412
if key not in batch_outputs:
2413
batch_outputs[key] = []
2414
batch_outputs[key].append(value)
2415
2416
batch_outputs = self.pad(
2417
batch_outputs,
2418
padding=padding_strategy.value,
2419
max_length=max_length,
2420
pad_to_multiple_of=pad_to_multiple_of,
2421
return_attention_mask=return_attention_mask,
2422
)
2423
2424
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
2425
2426
return batch_outputs
2427
2428
def prepare_for_tokenization(
2429
self, text: str, is_split_into_words: bool = False, **kwargs
2430
) -> Tuple[str, Dict[str, Any]]:
2431
"""
2432
Performs any necessary transformations before tokenization.
2433
This method should pop the arguments from kwargs and return the remaining :obj:`kwargs` as well. We test the
2434
:obj:`kwargs` at the end of the encoding process to be sure all the arguments have been used.
2435
Args:
2436
text (:obj:`str`):
2437
The text to prepare.
2438
is_split_into_words (:obj:`bool`, `optional`, defaults to :obj:`False`):
2439
Whether or not the text has been pretokenized.
2440
kwargs:
2441
Keyword arguments to use for the tokenization.
2442
Returns:
2443
:obj:`Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
2444
"""
2445
return (text, kwargs)
2446
2447
def get_special_tokens_mask(
2448
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
2449
) -> List[int]:
2450
"""
2451
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
2452
special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
2453
Args:
2454
token_ids_0 (:obj:`List[int]`):
2455
List of ids of the first sequence.
2456
token_ids_1 (:obj:`List[int]`, `optional`):
2457
List of ids of the second sequence.
2458
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
2459
Whether or not the token list is already formatted with special tokens for the model.
2460
Returns:
2461
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
2462
"""
2463
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
2464
2465
@overload
2466
def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str:
2467
...
2468
2469
@overload
2470
def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]:
2471
...
2472
2473
def convert_ids_to_tokens(
2474
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
2475
) -> Union[str, List[str]]:
2476
"""
2477
Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary and
2478
added tokens.
2479
Args:
2480
ids (:obj:`int` or :obj:`List[int]`):
2481
The token id (or token ids) to convert to tokens.
2482
skip_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
2483
Whether or not to remove special tokens in the decoding.
2484
Returns:
2485
:obj:`str` or :obj:`List[str]`: The decoded token(s).
2486
"""
2487
if isinstance(ids, int):
2488
if ids in self.added_tokens_decoder:
2489
return self.added_tokens_decoder[ids]
2490
else:
2491
return self._convert_id_to_token(ids)
2492
tokens = []
2493
for index in ids:
2494
index = int(index)
2495
if skip_special_tokens and index in self.all_special_ids:
2496
continue
2497
if index in self.added_tokens_decoder:
2498
tokens.append(self.added_tokens_decoder[index])
2499
else:
2500
tokens.append(self._convert_id_to_token(index))
2501
return tokens
2502
2503
def _convert_id_to_token(self, index: int) -> str:
2504
raise NotImplementedError
2505
2506
def convert_tokens_to_string(self, tokens: List[str]) -> str:
2507
return " ".join(tokens)
2508
2509
def _decode(
2510
self,
2511
token_ids: List[int],
2512
skip_special_tokens: bool = False,
2513
clean_up_tokenization_spaces: bool = True,
2514
spaces_between_special_tokens: bool = True,
2515
) -> str:
2516
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
2517
2518
# To avoid mixing byte-level and unicode for byte-level BPE
2519
# we need to build string separately for added tokens and byte-level tokens
2520
# cf. https://github.com/huggingface/transformers/issues/1133
2521
sub_texts = []
2522
current_sub_text = []
2523
for token in filtered_tokens:
2524
if skip_special_tokens and token in self.all_special_tokens:
2525
continue
2526
if token in self.added_tokens_encoder:
2527
if current_sub_text:
2528
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
2529
current_sub_text = []
2530
sub_texts.append(token)
2531
else:
2532
current_sub_text.append(token)
2533
if current_sub_text:
2534
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
2535
2536
if spaces_between_special_tokens:
2537
text = " ".join(sub_texts)
2538
else:
2539
text = "".join(sub_texts)
2540
2541
if clean_up_tokenization_spaces:
2542
clean_text = self.clean_up_tokenization(text)
2543
return clean_text
2544
else:
2545
return text
2546
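# Illustrative round trip (assumes a bert-base-uncased instance of the BertTokenizer defined
# below): special tokens can be skipped, and the whitespace artifacts of `convert_tokens_to_string`
# are removed by `clean_up_tokenization`.
#
#   tokenizer.decode([101, 7592, 1010, 2088, 999, 102], skip_special_tokens=True)
#   # -> "hello, world!"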
2547
2548
2549
class BertTokenizer(PreTrainedTokenizer):
2550
vocab_files_names = VOCAB_FILES_NAMES
2551
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
2552
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
2553
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
2554
2555
def __init__(
2556
self,
2557
vocab_file,
2558
do_lower_case=True,
2559
do_basic_tokenize=True,
2560
never_split=None,
2561
unk_token="[UNK]",
2562
sep_token="[SEP]",
2563
pad_token="[PAD]",
2564
cls_token="[CLS]",
2565
mask_token="[MASK]",
2566
tokenize_chinese_chars=True,
2567
strip_accents=None,
2568
**kwargs
2569
):
2570
super().__init__(
2571
do_lower_case=do_lower_case,
2572
do_basic_tokenize=do_basic_tokenize,
2573
never_split=never_split,
2574
unk_token=unk_token,
2575
sep_token=sep_token,
2576
pad_token=pad_token,
2577
cls_token=cls_token,
2578
mask_token=mask_token,
2579
tokenize_chinese_chars=tokenize_chinese_chars,
2580
strip_accents=strip_accents,
2581
**kwargs,
2582
)
2583
self.vocab = load_vocab(vocab_file)
2584
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
2585
self.do_basic_tokenize = do_basic_tokenize
2586
if do_basic_tokenize:
2587
self.basic_tokenizer = BasicTokenizer(
2588
do_lower_case=do_lower_case,
2589
never_split=never_split,
2590
tokenize_chinese_chars=tokenize_chinese_chars,
2591
strip_accents=strip_accents,
2592
)
2593
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
2594
2595
@property
2596
def do_lower_case(self):
2597
return self.basic_tokenizer.do_lower_case
2598
2599
@property
2600
def vocab_size(self):
2601
return len(self.vocab)
2602
2603
def get_vocab(self):
2604
return dict(self.vocab, **self.added_tokens_encoder)
2605
2606
def _tokenize(self, text):
2607
split_tokens = []
2608
if self.do_basic_tokenize:
2609
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
2610
2611
# If the token is part of the never_split set
2612
if token in self.basic_tokenizer.never_split:
2613
split_tokens.append(token)
2614
else:
2615
split_tokens += self.wordpiece_tokenizer.tokenize(token)
2616
else:
2617
split_tokens = self.wordpiece_tokenizer.tokenize(text)
2618
return split_tokens
2619
2620
def _convert_token_to_id(self, token):
2621
return self.vocab.get(token, self.vocab.get(self.unk_token))
2622
2623
def _convert_id_to_token(self, index):
2624
return self.ids_to_tokens.get(index, self.unk_token)
2625
2626
def convert_tokens_to_string(self, tokens):
2627
out_string = " ".join(tokens).replace(" ##", "").strip()
2628
return out_string
2629
2630
def build_inputs_with_special_tokens(
2631
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
2632
) -> List[int]:
2633
if token_ids_1 is None:
2634
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
2635
cls = [self.cls_token_id]
2636
sep = [self.sep_token_id]
2637
return cls + token_ids_0 + sep + token_ids_1 + sep
2638
2639
def get_special_tokens_mask(
2640
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
2641
) -> List[int]:
2642
if already_has_special_tokens:
2643
if token_ids_1 is not None:
2644
raise ValueError(
2645
"You should not supply a second sequence if the provided sequence of "
2646
"ids is already formatted with special tokens for the model."
2647
)
2648
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
2649
2650
if token_ids_1 is not None:
2651
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
2652
return [1] + ([0] * len(token_ids_0)) + [1]
2653
2654
def create_token_type_ids_from_sequences(
2655
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
2656
) -> List[int]:
2657
sep = [self.sep_token_id]
2658
cls = [self.cls_token_id]
2659
if token_ids_1 is None:
2660
return len(cls + token_ids_0 + sep) * [0]
2661
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
2662
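# Illustrative sketch of the BERT input layout produced by the two methods above, written with
# symbolic ids A_i / B_i for the first and second sequence:
#
#   build_inputs_with_special_tokens(A, B)      -> [CLS] A_1 ... A_n [SEP] B_1 ... B_m [SEP]
#   create_token_type_ids_from_sequences(A, B)  ->   0    0  ...  0    0    1  ...  1    1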
2663
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
2664
index = 0
2665
if os.path.isdir(save_directory):
2666
vocab_file = os.path.join(
2667
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
2668
)
2669
else:
2670
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
2671
with open(vocab_file, "w", encoding="utf-8") as writer:
2672
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
2673
if index != token_index:
2674
index = token_index
2675
writer.write(token + "\n")
2676
index += 1
2677
return (vocab_file,)
2678
2679
2680
class BasicTokenizer(object):
2681
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
2682
if never_split is None:
2683
never_split = []
2684
self.do_lower_case = do_lower_case
2685
self.never_split = set(never_split)
2686
self.tokenize_chinese_chars = tokenize_chinese_chars
2687
self.strip_accents = strip_accents
2688
2689
def tokenize(self, text, never_split=None):
2690
# union() returns a new set containing the elements of both sets.
2691
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
2692
text = self._clean_text(text)
2693
2694
# This was added on November 1st, 2018 for the multilingual and Chinese
2695
# models. This is also applied to the English models now, but it doesn't
2696
# matter since the English models were not trained on any Chinese data
2697
# and generally don't have any Chinese data in them (there are Chinese
2698
# characters in the vocabulary because Wikipedia does have some Chinese
2699
# words in the English Wikipedia).
2700
if self.tokenize_chinese_chars:
2701
text = self._tokenize_chinese_chars(text)
2702
orig_tokens = whitespace_tokenize(text)
2703
split_tokens = []
2704
for token in orig_tokens:
2705
if token not in never_split:
2706
if self.do_lower_case:
2707
token = token.lower()
2708
if self.strip_accents is not False:
2709
token = self._run_strip_accents(token)
2710
elif self.strip_accents:
2711
token = self._run_strip_accents(token)
2712
split_tokens.extend(self._run_split_on_punc(token, never_split))
2713
2714
output_tokens = whitespace_tokenize(" ".join(split_tokens))
2715
return output_tokens
2716
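# Illustrative sketch (assumes the `whitespace_tokenize` helper defined earlier in this file):
# the basic tokenizer optionally lower-cases and strips accents, splits on whitespace and
# punctuation, and leaves `never_split` entries untouched.
#
#   BasicTokenizer(do_lower_case=True).tokenize("Héllo, [UNK] World!", never_split=["[UNK]"])
#   # -> ['hello', ',', '[UNK]', 'world', '!']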
2717
def _run_strip_accents(self, text):
2718
text = unicodedata.normalize("NFD", text)
2719
output = []
2720
for char in text:
2721
cat = unicodedata.category(char)
2722
if cat == "Mn":
2723
continue
2724
output.append(char)
2725
return "".join(output)
2726
2727
def _run_split_on_punc(self, text, never_split=None):
2728
if never_split is not None and text in never_split:
2729
return [text]
2730
chars = list(text)
2731
i = 0
2732
start_new_word = True
2733
output = []
2734
while i < len(chars):
2735
char = chars[i]
2736
if _is_punctuation(char):
2737
output.append([char])
2738
start_new_word = True
2739
else:
2740
if start_new_word:
2741
output.append([])
2742
start_new_word = False
2743
output[-1].append(char)
2744
i += 1
2745
2746
return ["".join(x) for x in output]
2747
2748
def _tokenize_chinese_chars(self, text):
2749
output = []
2750
for char in text:
2751
cp = ord(char)
2752
if self._is_chinese_char(cp):
2753
output.append(" ")
2754
output.append(char)
2755
output.append(" ")
2756
else:
2757
output.append(char)
2758
return "".join(output)
2759
2760
def _is_chinese_char(self, cp):
2761
# This defines a "chinese character" as anything in the CJK Unicode block:
2762
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
2763
#
2764
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
2765
# despite its name. The modern Korean Hangul alphabet is a different block,
2766
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
2767
# space-separated words, so they are not treated specially and handled
2768
# like all of the other languages.
2769
if (
2770
(cp >= 0x4E00 and cp <= 0x9FFF)
2771
or (cp >= 0x3400 and cp <= 0x4DBF) #
2772
or (cp >= 0x20000 and cp <= 0x2A6DF) #
2773
or (cp >= 0x2A700 and cp <= 0x2B73F) #
2774
or (cp >= 0x2B740 and cp <= 0x2B81F) #
2775
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
2776
or (cp >= 0xF900 and cp <= 0xFAFF)
2777
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
2778
): #
2779
return True
2780
2781
return False
2782
2783
def _clean_text(self, text):
2784
output = []
2785
for char in text:
2786
cp = ord(char)
2787
if cp == 0 or cp == 0xFFFD or _is_control(char):
2788
continue
2789
if _is_whitespace(char):
2790
output.append(" ")
2791
else:
2792
output.append(char)
2793
return "".join(output)
2794
2795
2796
class WordpieceTokenizer(object):
2797
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
2798
self.vocab = vocab
2799
self.unk_token = unk_token
2800
self.max_input_chars_per_word = max_input_chars_per_word
2801
2802
def tokenize(self, text):
2803
output_tokens = []
2804
for token in whitespace_tokenize(text):
2805
chars = list(token)
2806
if len(chars) > self.max_input_chars_per_word:
2807
output_tokens.append(self.unk_token)
2808
continue
2809
2810
is_bad = False
2811
start = 0
2812
sub_tokens = []
2813
while start < len(chars):
2814
end = len(chars)
2815
cur_substr = None
2816
while start < end:
2817
substr = "".join(chars[start:end])
2818
if start > 0:
2819
substr = "##" + substr
2820
if substr in self.vocab:
2821
cur_substr = substr
2822
break
2823
end -= 1
2824
if cur_substr is None:
2825
is_bad = True
2826
break
2827
sub_tokens.append(cur_substr)
2828
start = end
2829
2830
if is_bad:
2831
output_tokens.append(self.unk_token)
2832
else:
2833
output_tokens.extend(sub_tokens)
2834
return output_tokens
2835
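# Illustrative sketch (the classic example from the original BERT reference code): WordPiece
# performs greedy longest-match-first tokenization against the vocab, prefixing continuation
# pieces with "##". Assuming a toy vocabulary that contains "un", "##aff" and "##able":
#
#   WordpieceTokenizer(vocab={"un": 0, "##aff": 1, "##able": 2}, unk_token="[UNK]").tokenize("unaffable")
#   # -> ['un', '##aff', '##able']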
2836