GitHub Repository: prophesier/diff-svc
Path: blob/main/utils/text_encoder.py
import re

import six
from six.moves import range  # pylint: disable=redefined-builtin

PAD = "<pad>"
EOS = "<EOS>"
UNK = "<UNK>"
SEG = "|"
RESERVED_TOKENS = [PAD, EOS, UNK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2

if six.PY2:
    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
    # Cover every reserved token: the original list omitted UNK, so
    # ByteTextEncoder.decode would raise IndexError for UNK_ID.
    RESERVED_TOKENS_BYTES = [bytes(t, "ascii") for t in RESERVED_TOKENS]

# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_ESCAPE_CHARS = set(u"\\_u;0123456789")


def strip_ids(ids, ids_to_strip):
    """Strip ids_to_strip from the end of ids."""
    ids = list(ids)
    while ids and ids[-1] in ids_to_strip:
        ids.pop()
    return ids
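
# A quick sanity sketch (hypothetical values, not part of the original file):
#     strip_ids([5, 6, EOS_ID, PAD_ID], [PAD_ID, EOS_ID]) -> [5, 6]
# Only trailing ids are removed; interior reserved ids are left untouched.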


class TextEncoder(object):
    """Base class for converting between int ids and human-readable strings."""

    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        return self._num_reserved_ids

    def encode(self, s):
        """Transform a human-readable string into a sequence of int ids.

        The ids should be in the range [num_reserved_ids, vocab_size). Ids
        [0, num_reserved_ids) are reserved.

        EOS is not appended.

        Args:
            s: human-readable string to be converted.

        Returns:
            ids: list of integers
        """
        return [int(w) + self._num_reserved_ids for w in s.split()]

    def decode(self, ids, strip_extraneous=False):
        """Transform a sequence of int ids into a human-readable string.

        EOS is not expected in ids.

        Args:
            ids: list of integers to be converted.
            strip_extraneous: bool, whether to strip off extraneous tokens
                (EOS and PAD).

        Returns:
            s: human-readable string.
        """
        if strip_extraneous:
            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Transform a sequence of int ids into their string versions.

        This method supports transforming individual input/output ids to their
        string versions so that sequence to/from text conversions can be
        visualized in a human-readable format.

        Args:
            ids: list of integers to be converted.

        Returns:
            strs: list of human-readable strings.
        """
        decoded_ids = []
        for id_ in ids:
            if 0 <= id_ < self._num_reserved_ids:
                decoded_ids.append(RESERVED_TOKENS[int(id_)])
            else:
                decoded_ids.append(id_ - self._num_reserved_ids)
        return [str(d) for d in decoded_ids]

    @property
    def vocab_size(self):
        raise NotImplementedError()
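
# Hypothetical round trip with the base encoder: ids are shifted past the
# reserved range, so with NUM_RESERVED_TOKENS == 3:
#     TextEncoder().encode("1 2")        -> [4, 5]
#     TextEncoder().decode([4, 5])       -> "1 2"
#     TextEncoder().decode([PAD_ID, 4])  -> "<pad> 1"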


class ByteTextEncoder(TextEncoder):
    """Encodes each byte to an id. For 8-bit strings only."""

    def encode(self, s):
        numres = self._num_reserved_ids
        if six.PY2:
            if isinstance(s, unicode):
                s = s.encode("utf-8")
            return [ord(c) + numres for c in s]
        # Python3: explicitly convert to UTF-8
        return [c + numres for c in s.encode("utf-8")]

    def decode(self, ids, strip_extraneous=False):
        if strip_extraneous:
            ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
        numres = self._num_reserved_ids
        decoded_ids = []
        int2byte = six.int2byte
        for id_ in ids:
            if 0 <= id_ < numres:
                decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
            else:
                decoded_ids.append(int2byte(id_ - numres))
        if six.PY2:
            return "".join(decoded_ids)
        # Python3: join byte arrays and then decode string
        return b"".join(decoded_ids).decode("utf-8", "replace")

    def decode_list(self, ids):
        numres = self._num_reserved_ids
        decoded_ids = []
        int2byte = six.int2byte
        for id_ in ids:
            if 0 <= id_ < numres:
                decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
            else:
                decoded_ids.append(int2byte(id_ - numres))
        # Unlike decode(), return the per-id byte strings without joining.
        return decoded_ids

    @property
    def vocab_size(self):
        return 2**8 + self._num_reserved_ids
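
# Hypothetical byte-level round trip (ord("a") == 97, shifted by the 3
# reserved ids):
#     ByteTextEncoder().encode("ab")       -> [100, 101]
#     ByteTextEncoder().decode([100, 101]) -> "ab"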


class ByteTextEncoderWithEos(ByteTextEncoder):
    """Encodes each byte to an id and appends the EOS token."""

    def encode(self, s):
        return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID]
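
# E.g. (hypothetical): ByteTextEncoderWithEos().encode("ab") -> [100, 101, EOS_ID]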


class TokenTextEncoder(TextEncoder):
    """Encoder based on a user-supplied vocabulary (file or list)."""

    def __init__(self,
                 vocab_filename,
                 reverse=False,
                 vocab_list=None,
                 replace_oov=None,
                 num_reserved_ids=NUM_RESERVED_TOKENS):
        """Initialize from a file or list, one token per line.

        Handling of reserved tokens works as follows:
        - When initializing from a list, we add reserved tokens to the vocab.
        - When initializing from a file, we do not add reserved tokens to the
          vocab.
        - When saving vocab files, we save reserved tokens to the file.

        Args:
            vocab_filename: If not None, the full filename to read vocab from.
                If this is not None, then vocab_list should be None.
            reverse: Boolean indicating if tokens should be reversed during
                encoding and decoding.
            vocab_list: If not None, a list of elements of the vocabulary. If
                this is not None, then vocab_filename should be None.
            replace_oov: If not None, every out-of-vocabulary token seen when
                encoding will be replaced by this string (which must be in
                vocab).
            num_reserved_ids: Number of IDs to save for reserved tokens like
                <EOS>.
        """
        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if vocab_filename:
            self._init_vocab_from_file(vocab_filename)
        else:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        self.pad_index = self._token_to_id[PAD]
        self.eos_index = self._token_to_id[EOS]
        self.unk_index = self._token_to_id[UNK]
        # SEG ("|") is optional in the vocab; fall back to EOS when absent.
        self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index

    def encode(self, s):
        """Converts a space-separated string of tokens to a list of ids."""
        tokens = s.strip().split()
        if self._replace_oov is not None:
            tokens = [t if t in self._token_to_id else self._replace_oov
                      for t in tokens]
        ret = [self._token_to_id[tok] for tok in tokens]
        return ret[::-1] if self._reverse else ret

    def decode(self, ids, strip_eos=False, strip_padding=False):
        if strip_padding and self.pad() in list(ids):
            pad_pos = list(ids).index(self.pad())
            ids = ids[:pad_pos]
        if strip_eos and self.eos() in list(ids):
            eos_pos = list(ids).index(self.eos())
            ids = ids[:eos_pos]
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        seq = reversed(ids) if self._reverse else ids
        return [self._safe_id_to_token(i) for i in seq]

    @property
    def vocab_size(self):
        return len(self._id_to_token)

    def __len__(self):
        return self.vocab_size

    def _safe_id_to_token(self, idx):
        return self._id_to_token.get(idx, "ID_%d" % idx)

    def _init_vocab_from_file(self, filename):
        """Load vocab from a file.

        Args:
            filename: The file to load vocabulary from.
        """
        with open(filename) as f:
            tokens = [token.strip() for token in f.readlines()]

        def token_gen():
            for token in tokens:
                yield token

        self._init_vocab(token_gen(), add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initialize tokens from a list of tokens.

        It is ok if reserved tokens appear in the vocab list. They will be
        removed. The set of tokens in vocab_list should be unique.

        Args:
            vocab_list: A list of tokens.
        """
        def token_gen():
            for token in vocab_list:
                if token not in RESERVED_TOKENS:
                    yield token

        self._init_vocab(token_gen())

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Initialize vocabulary with tokens from token_generator."""

        self._id_to_token = {}
        non_reserved_start_index = 0

        if add_reserved_tokens:
            self._id_to_token.update(enumerate(RESERVED_TOKENS))
            non_reserved_start_index = len(RESERVED_TOKENS)

        self._id_to_token.update(
            enumerate(token_generator, start=non_reserved_start_index))

        # _token_to_id is the reverse of _id_to_token
        self._token_to_id = dict((v, k)
                                 for k, v in six.iteritems(self._id_to_token))
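
    # Resulting layout for a list-initialized vocab (hypothetical example):
    #     _id_to_token == {0: "<pad>", 1: "<EOS>", 2: "<UNK>",
    #                      3: "hello", 4: "world"}
    # _token_to_id inverts this mapping, so duplicate tokens would collide.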

    def pad(self):
        return self.pad_index

    def eos(self):
        return self.eos_index

    def unk(self):
        return self.unk_index

    def seg(self):
        return self.seg_index

    def store_to_file(self, filename):
        """Write vocab file to disk.

        Vocab files have one token per line. The file ends in a newline.
        Reserved tokens are written to the vocab file as well.

        Args:
            filename: Full path of the file to store the vocab to.
        """
        with open(filename, "w") as f:
            for i in range(len(self._id_to_token)):
                f.write(self._id_to_token[i] + "\n")

    def sil_phonemes(self):
        """Return vocab tokens that do not start with a letter, e.g. silence
        or punctuation phonemes such as "<pad>" or "|"."""
        return [p for p in self._id_to_token.values() if not p[0].isalpha()]
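

if __name__ == "__main__":
    # Minimal smoke test; the vocab below is hypothetical and not part of the
    # original module.
    enc = TokenTextEncoder(None, vocab_list=["hello", "world"])
    assert enc.encode("hello world") == [3, 4]
    assert enc.decode([3, 4]) == "hello world"
    assert enc.decode_list([enc.unk()]) == [UNK]
    byte_enc = ByteTextEncoderWithEos()
    ids = byte_enc.encode("ab")
    assert ids == [100, 101, EOS_ID]
    assert byte_enc.decode(ids[:-1]) == "ab"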