Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/text/data.py
781 views
1
"NLP data loading pipeline. Supports csv, folders, and preprocessed data."
2
from ..torch_core import *
3
from .transform import *
4
from ..basic_data import *
5
from ..data_block import *
6
from ..layers import *
7
from ..callback import Callback
8
9
# Public API of this module.
__all__ = [
    'LanguageModelPreLoader', 'SortSampler', 'SortishSampler', 'TextList', 'pad_collate', 'TextDataBunch',
    'TextLMDataBunch', 'TextClasDataBunch', 'Text', 'open_text', 'TokenizeProcessor', 'NumericalizeProcessor',
    'OpenFileProcessor', 'LMLabelList', 'LMTextList', 'SPProcessor',
]

# Enumeration with members DF, TOK and IDS (valued 1, 2, 3).
TextMtd = IntEnum('TextMtd', 'DF TOK IDS')

# Default file suffixes accepted by `TextList.from_folder`.
text_extensions = {'.txt'}
15
16
class LanguageModelPreLoader(Callback):
    "Transforms the tokens in `dataset` to a stream of contiguous batches for language modelling."

    class CircularIndex():
        "Handles shuffle, direction of indexing, wraps around to head tail in the ragged array as needed"
        def __init__(self, length:int, forward:bool): self.idx, self.forward = np.arange(length), forward
        def __getitem__(self, i):
            # `i` wraps modulo the length; when not `forward`, positions are read from the end of `idx`.
            return self.idx[ i%len(self.idx) if self.forward else len(self.idx)-1-i%len(self.idx)]
        def __len__(self) -> int: return len(self.idx)
        def shuffle(self): np.random.shuffle(self.idx)

    def __init__(self, dataset:LabelList, lengths:Collection[int]=None, bs:int=32, bptt:int=70, backwards:bool=False,
                 shuffle:bool=False):
        self.dataset,self.bs,self.bptt,self.shuffle,self.backwards,self.lengths = dataset,bs,bptt,shuffle,backwards,lengths
        # The number of rows is multiplied by `num_distrib()` (treated as 1 when it is falsy).
        self.bs *= num_distrib() or 1
        self.totalToks,self.ite_len,self.idx = int(0),None,None

    def __len__(self):
        "Number of items per epoch: `bs` rows times the number of `bptt` windows needed to cover all tokens."
        if self.ite_len is None:
            if self.lengths is None: self.lengths = np.array([len(item) for item in self.dataset.x.items])
            self.totalToks = self.lengths.sum()
            # `self.item` is looked up on `dataset` (see `__getattr__`); when it is set only one item is served.
            self.ite_len = self.bs*int( math.ceil( self.totalToks/(self.bptt*self.bs) )) if self.item is None else 1
        return self.ite_len

    def __getattr__(self,k:str)->Any:
        "Delegate unknown attributes to `dataset`."
        # Without this guard, any lookup made before `dataset` is set (e.g. `__setstate__` while unpickling)
        # recurses forever through `self.dataset` -> `__getattr__('dataset')`.
        if k == 'dataset': raise AttributeError(k)
        return getattr(self.dataset, k)

    def allocate_buffers(self):
        "Create the ragged array that will be filled when we ask for items."
        if self.ite_len is None: len(self)
        self.idx = LanguageModelPreLoader.CircularIndex(len(self.dataset.x.items), not self.backwards)
        # Each row holds bptt+1 tokens; `batch_x` and `batch_y` are views of it shifted by one token.
        self.batch = np.zeros((self.bs, self.bptt+1), dtype=np.int64)
        self.batch_x, self.batch_y = self.batch[:,0:self.bptt], self.batch[:,1:self.bptt+1]
        #ro: index of the text we're at inside our datasets for the various batches
        self.ro = np.zeros(self.bs, dtype=np.int64)
        #ri: index of the token we're at inside our current text for the various batches
        # (`np.int` was removed in NumPy 1.24; int64 matches `ro`.)
        self.ri = np.zeros(self.bs, dtype=np.int64)

    def on_epoch_begin(self, **kwargs):
        "(Re)allocate buffers or reshuffle, then compute each row's starting text (`ro`) and token offset (`ri`)."
        if self.idx is None or len(self.idx) != len(self.dataset.x.items): self.allocate_buffers()
        elif self.shuffle: self.idx.shuffle()
        self.idx.forward = not self.backwards

        # Row i starts at token position int(step*i) of the texts concatenated in `idx` order.
        step = self.totalToks / self.bs
        ln_rag, countTokens, i_rag = 0, 0, -1
        for i in range(0,self.bs):
            #Compute the initial values for ro and ri
            while ln_rag + countTokens <= int(step * i):
                countTokens += ln_rag
                i_rag += 1
                ln_rag = self.lengths[self.idx[i_rag]]
            self.ro[i] = i_rag
            self.ri[i] = ( ln_rag - int(step * i - countTokens) ) if self.backwards else int(step * i - countTokens)

    #Training dl gets on_epoch_begin called, val_dl, on_epoch_end
    def on_epoch_end(self, **kwargs): self.on_epoch_begin()

    def __getitem__(self, k:int):
        "Fill row `k % bs` with its next window of tokens and return the (input, target) views of that row."
        j = k % self.bs
        if self.item is not None: return self.dataset[0]
        if self.idx is None: self.on_epoch_begin()
        self.ro[j],self.ri[j] = self.fill_row(not self.backwards, self.dataset.x.items, self.idx, self.batch[j],
                                              self.ro[j], self.ri[j], overlap=1, lengths=self.lengths)
        return self.batch_x[j], self.batch_y[j]

    def fill_row(self, forward, items, idx, row, ro, ri, overlap,lengths):
        "Fill the row with tokens from the ragged array. --OBS-- overlap != 1 has not been implemented"
        ibuf = n = 0
        ro -= 1
        while ibuf < row.size:
            ro += 1
            ix = idx[ro]
            rag = items[ix]
            if forward:
                # Continue from `ri` in the first text, then start subsequent texts at their beginning.
                ri = 0 if ibuf else ri
                n = min(lengths[ix] - ri, row.size - ibuf)
                row[ibuf:ibuf+n] = rag[ri:ri+n]
            else:
                # Backwards: read tokens right-to-left, starting subsequent texts at their end.
                ri = lengths[ix] if ibuf else ri
                n = min(ri, row.size - ibuf)
                row[ibuf:ibuf+n] = rag[ri-n:ri][::-1]
            ibuf += n
        # Step back `overlap` tokens so the last target becomes the next window's first input.
        return ro, ri + ((n-overlap) if forward else -(n-overlap))
98
99
class SortSampler(Sampler):
    "Iterate over the indices of `data_source` from the largest to the smallest `key` value (e.g. text length)."

    def __init__(self, data_source:NPArrayList, key:KeyFunc):
        self.data_source = data_source
        self.key = key

    def __len__(self) -> int:
        return len(self.data_source)

    def __iter__(self):
        # Deterministic order: descending by `key`.
        ordered = sorted(range_of(self.data_source), key=self.key, reverse=True)
        return iter(ordered)
106
107
class SortishSampler(Sampler):
    "Go through the text data by order of length with a bit of randomness."

    def __init__(self, data_source:NPArrayList, key:KeyFunc, bs:int):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self) -> int: return len(self.data_source)

    def __iter__(self):
        "Shuffle, sort by `key` inside chunks of `50*bs`, split into batches of `bs`; largest batch first, rest shuffled."
        if len(self.data_source) == 0: return iter([])
        idxs = np.random.permutation(len(self.data_source))
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
        sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
        sz = self.bs
        ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]
        max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
        ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0]     # then make sure it goes first.
        # Shuffle the remaining chunks by permuting their positions. Calling `np.random.permutation` on the list
        # of chunks itself builds an array from it, which fails when the last chunk is shorter (NumPy >= 1.24).
        rest = [ck_idx[i] for i in np.random.permutation(len(ck_idx) - 1) + 1]
        sort_idx = np.concatenate([ck_idx[0]] + rest)
        return iter(sort_idx)
127
128
def pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:
    """Collate `(tokens, label)` samples into a `pad_idx`-padded batch of ids and a tensor of labels.

    Padding goes at the start when `pad_first`, else at the end. With `backwards`, each row's token
    order is flipped (and the padding side is swapped first so it ends up on the requested side)."""
    samples = to_data(samples)
    max_len = max([len(s[0]) for s in samples])
    res = torch.zeros(len(samples), max_len).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        n = len(s[0])
        # An empty sequence is all padding; `res[i,-0:]` would select the whole row and fail to broadcast.
        if n == 0: continue
        if pad_first: res[i,-n:] = LongTensor(s[0])
        else:         res[i,:n]  = LongTensor(s[0])
    if backwards: res = res.flip(1)
    return res, tensor(np.array([s[1] for s in samples]))
139
140
def _get_processor(tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,
                   min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
    "Build the default text pipeline: a `TokenizeProcessor` followed by a `NumericalizeProcessor`."
    tokenize = TokenizeProcessor(tokenizer=tokenizer, chunksize=chunksize, mark_fields=mark_fields,
                                 include_bos=include_bos, include_eos=include_eos)
    numericalize = NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)
    return [tokenize, numericalize]
145
146
class TextDataBunch(DataBunch):
    "General class to get a `DataBunch` for NLP. Subclassed by `TextLMDataBunch` and `TextClasDataBunch`."

    @classmethod
    def from_ids(cls, path:PathOrStr, vocab:Vocab, train_ids:Collection[Collection[int]], valid_ids:Collection[Collection[int]],
                 test_ids:Collection[Collection[int]]=None, train_lbls:Collection[Union[int,float]]=None,
                 valid_lbls:Collection[Union[int,float]]=None, classes:Collection[Any]=None,
                 processor:PreProcessor=None, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from ids, labels and a `vocab`. `kwargs` are passed to the dataloader creation."
        # The ids are already numericalized, so the lists are built without any processor.
        src = ItemLists(path, TextList(train_ids, vocab, path=path, processor=[]),
                        TextList(valid_ids, vocab, path=path, processor=[]))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(train_lbls, valid_lbls, classes=classes, processor=[])
        # Labels that are not one-dimensional are flagged as one-hot encoded.
        if not is1d(train_lbls): src.train.y.one_hot,src.valid.y.one_hot = True,True
        # The test set is given the first training label (presumably as a placeholder).
        if test_ids is not None: src.add_test(TextList(test_ids, vocab, path=path), label=train_lbls[0])
        # Processors are attached only after the lists were built, so the ids above are not re-processed;
        # presumably they serve items added later (e.g. at inference) -- TODO confirm.
        src.valid.x.processor = ifnone(processor, [TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)])
        if classes is not None: src.valid.y.processor = ifnone(processor, [CategoryProcessor(src.valid.y)])
        return src.databunch(**kwargs)

    @classmethod
    def load(cls, path:PathOrStr, cache_name:PathOrStr='tmp', processor:PreProcessor=None, **kwargs):
        "Load a `TextDataBunch` from `path/cache_name`. `kwargs` are passed to the dataloader creation."
        warn("""This method is deprecated and only kept to load data serialized in v1.0.43 or earlier.
Use `load_data` for data saved with v1.0.44 or later.""", DeprecationWarning)
        cache_path = Path(path)/cache_name
        # NOTE: unpickling (pickle.load, np.load with allow_pickle) can execute arbitrary code -- only load
        # caches from a trusted source.
        with open(cache_path/'itos.pkl','rb') as f: vocab = Vocab(pickle.load(f))
        # Token ids are ragged, so they were saved as object arrays; NumPy >= 1.16.3 refuses to load those
        # unless `allow_pickle=True`.
        train_ids = np.load(cache_path/'train_ids.npy', allow_pickle=True)
        train_lbls = np.load(cache_path/'train_lbl.npy', allow_pickle=True)
        valid_ids = np.load(cache_path/'valid_ids.npy', allow_pickle=True)
        valid_lbls = np.load(cache_path/'valid_lbl.npy', allow_pickle=True)
        test_ids = (np.load(cache_path/'test_ids.npy', allow_pickle=True)
                    if os.path.isfile(cache_path/'test_ids.npy') else None)
        classes = loadtxt_str(cache_path/'classes.txt') if os.path.isfile(cache_path/'classes.txt') else None
        return cls.from_ids(path, vocab, train_ids, valid_ids, test_ids, train_lbls, valid_lbls, classes, processor, **kwargs)

    @classmethod  # TODO: test
    def from_tokens(cls, path:PathOrStr, trn_tok:Collection[Collection[str]], trn_lbls:Collection[Union[int,float]],
                    val_tok:Collection[Collection[str]], val_lbls:Collection[Union[int,float]], vocab:Vocab=None,
                    tst_tok:Collection[Collection[str]]=None, classes:Collection[Any]=None, max_vocab:int=60000, min_freq:int=3,
                    **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from tokens and labels. `kwargs` are passed to the dataloader creation."
        # Tokens are already split, so only numericalization is needed.
        processor = NumericalizeProcessor(vocab=vocab, max_vocab=max_vocab, min_freq=min_freq)
        src = ItemLists(path, TextList(trn_tok, path=path, processor=processor),
                        TextList(val_tok, path=path, processor=processor))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_lists(trn_lbls, val_lbls, classes=classes)
        if tst_tok is not None: src.add_test(TextList(tst_tok, path=path))
        return src.databunch(**kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, train_df:DataFrame, valid_df:DataFrame, test_df:Optional[DataFrame]=None,
                tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, text_cols:IntsOrStrs=1,
                label_cols:IntsOrStrs=0, label_delim:str=None, chunksize:int=10000, max_vocab:int=60000,
                min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from DataFrames. `kwargs` are passed to the dataloader creation."
        processor = _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,
                                   min_freq=min_freq, mark_fields=mark_fields,
                                   include_bos=include_bos, include_eos=include_eos)
        # Several label columns without explicit classes: the column names become the classes.
        if classes is None and is_listy(label_cols) and len(label_cols) > 1: classes = label_cols
        src = ItemLists(path, TextList.from_df(train_df, path, cols=text_cols, processor=processor),
                        TextList.from_df(valid_df, path, cols=text_cols, processor=processor))
        if cls==TextLMDataBunch: src = src.label_for_lm()
        else:
            if label_delim is not None: src = src.label_from_df(cols=label_cols, classes=classes, label_delim=label_delim)
            else: src = src.label_from_df(cols=label_cols, classes=classes)
        if test_df is not None: src.add_test(TextList.from_df(test_df, path, cols=text_cols))
        return src.databunch(**kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name, valid_pct:float=0.2, test:Optional[str]=None,
                 tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, delimiter:str=None, header='infer',
                 text_cols:IntsOrStrs=1, label_cols:IntsOrStrs=0, label_delim:str=None,
                 chunksize:int=10000, max_vocab:int=60000, min_freq:int=2,
                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs) -> DataBunch:
        "Create a `TextDataBunch` from texts in csv files. `kwargs` are passed to the dataloader creation."
        df = pd.read_csv(Path(path)/csv_name, header=header, delimiter=delimiter)
        df = df.iloc[np.random.permutation(len(df))]
        # The `+ 1` keeps at least one row in the validation set.
        cut = int(valid_pct * len(df)) + 1
        train_df, valid_df = df[cut:], df[:cut]
        test_df = None if test is None else pd.read_csv(Path(path)/test, header=header, delimiter=delimiter)
        return cls.from_df(path, train_df, valid_df, test_df, tokenizer=tokenizer, vocab=vocab, classes=classes, text_cols=text_cols,
                           label_cols=label_cols, label_delim=label_delim, chunksize=chunksize, max_vocab=max_vocab,
                           min_freq=min_freq, mark_fields=mark_fields,
                           include_bos=include_bos, include_eos=include_eos, **kwargs)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None,
                    classes:Collection[Any]=None, tokenizer:Tokenizer=None, vocab:Vocab=None, chunksize:int=10000, max_vocab:int=60000,
                    min_freq:int=2, mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False, **kwargs):
        "Create a `TextDataBunch` from text files in folders."
        path = Path(path).absolute()
        # Files are opened and read first, then tokenized and numericalized.
        processor = [OpenFileProcessor()] + _get_processor(tokenizer=tokenizer, vocab=vocab, chunksize=chunksize, max_vocab=max_vocab,
                                                           min_freq=min_freq, mark_fields=mark_fields, include_bos=include_bos, include_eos=include_eos)
        src = (TextList.from_folder(path, processor=processor)
                       .split_by_folder(train=train, valid=valid))
        src = src.label_for_lm() if cls==TextLMDataBunch else src.label_from_folder(classes=classes)
        if test is not None: src.add_test_folder(path/test)
        return src.databunch(**kwargs)
239
240
class TextLMDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training a language model."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', no_check:bool=False, bs=64, val_bs:int=None,
               num_workers:int=0, device:torch.device=None, collate_fn:Callable=data_collate,
               dl_tfms:Optional[Collection[Callable]]=None, bptt:int=70, backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Create a `TextDataBunch` in `path` from the `datasets` for language modelling. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        # Only the training preloader shuffles; the others keep text order.
        datasets = [LanguageModelPreLoader(ds, shuffle=(i==0), bs=(bs if i==0 else val_bs), bptt=bptt, backwards=backwards)
                    for i,ds in enumerate(datasets)]
        # A preloader maps item k to row `k % bs` and advances that row by one window, so each DataLoader
        # must batch exactly as many items as its preloader has rows. (Previously `val_bs` was reset to `bs`
        # here, which broke row continuity in validation whenever `val_bs != bs`.)
        # `shuffle=False` because shuffling is done by the preloader itself.
        # NOTE(review): `num_workers` is accepted but not forwarded, so loading stays in this process, where the
        # preloaders' `ro`/`ri` state lives -- confirm this is intended.
        dls = [DataLoader(d, b, shuffle=False, **dl_kwargs) for d,b in zip(datasets, (bs,val_bs,val_bs,val_bs)) if d is not None]
        return cls(*dls, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
254
255
class TextClasDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training an RNN classifier."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,
               pad_first=True, device:torch.device=None, no_check:bool=False, backwards:bool=False,
               dl_tfms:Optional[Collection[Callable]]=None, **dl_kwargs) -> DataBunch:
        "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(pad_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        # Precompute text lengths once, as done for the other datasets below: the sampler calls `key` for every
        # index each epoch, and going through `datasets[0][t]` would build (and textify) a full item per call.
        train_lengths = [len(t) for t in datasets[0].x.items]
        train_sampler = SortishSampler(datasets[0].x, key=train_lengths.__getitem__, bs=bs)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        # Validation/test sets are served longest-first without randomness.
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, dl_tfms=dl_tfms, collate_fn=collate_fn, no_check=no_check)
273
274
def open_text(fn:PathOrStr, enc='utf-8'):
    "Return the whole content of the text file `fn`, decoded with encoding `enc`."
    with open(fn, 'r', encoding=enc) as handle:
        contents = handle.read()
    return contents
277
278
class Text(ItemBase):
    "Item pairing numericalized token `ids` (stored as an int64 array in `data`) with their readable `text`."
    def __init__(self, ids, text):
        self.data = np.array(ids, dtype=np.int64)
        self.text = text

    def __str__(self):
        return str(self.text)
282
283
class TokenizeProcessor(PreProcessor):
    "`PreProcessor` that joins the text fields of each item and tokenizes them with `tokenizer`."
    def __init__(self, ds:ItemList=None, tokenizer:Tokenizer=None, chunksize:int=10000,
                 mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
        # `ds` is part of the signature but not used here.
        self.tokenizer = ifnone(tokenizer, Tokenizer())
        self.chunksize = chunksize
        self.mark_fields = mark_fields
        self.include_bos = include_bos
        self.include_eos = include_eos

    def process_one(self, item):
        joined = _join_texts([item], self.mark_fields, self.include_bos, self.include_eos)
        return self.tokenizer._process_all_1(joined)[0]

    def process(self, ds):
        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)
        tokens = []
        # Tokenize `chunksize` texts at a time, showing progress.
        for start in progress_bar(range(0, len(ds), self.chunksize), leave=False):
            stop = start + self.chunksize
            tokens += self.tokenizer.process_all(ds.items[start:stop])
        ds.items = tokens
299
300
class NumericalizeProcessor(PreProcessor):
    "`PreProcessor` that maps the tokens of each item to ids with a `Vocab`."
    def __init__(self, ds:ItemList=None, vocab:Vocab=None, max_vocab:int=60000, min_freq:int=3):
        # Fall back to the vocab carried by `ds`, if any.
        fallback = ds.vocab if ds is not None else None
        self.vocab = ifnone(vocab, fallback)
        self.max_vocab = max_vocab
        self.min_freq = min_freq

    def process_one(self,item):
        return np.array(self.vocab.numericalize(item), dtype=np.int64)

    def process(self, ds):
        # Without a vocab yet, build one from the tokens of `ds`; either way, share it with `ds`.
        if self.vocab is None:
            self.vocab = Vocab.create(ds.items, self.max_vocab, self.min_freq)
        ds.vocab = self.vocab
        super().process(ds)
311
312
class OpenFileProcessor(PreProcessor):
    "`PreProcessor` that opens the filenames and read the texts."
    def process(self, ds:Collection):
        "Replace every `Path` in `ds.items` by the text of that file (other items are kept as-is)."
        # `np.object` was removed in NumPy 1.24; the builtin `object` is the same dtype.
        ds.items = array([self.process_one(item) for item in ds.items], dtype=object)

    def process_one(self,item): return open_text(item) if isinstance(item, Path) else item
316
317
class TextList(ItemList):
    "Basic `ItemList` for text data."
    _bunch = TextClasDataBunch
    _processor = [TokenizeProcessor, NumericalizeProcessor]
    _is_lm = False

    def __init__(self, items:Iterator, vocab:Vocab=None, pad_idx:int=1, sep=' ', **kwargs):
        super().__init__(items, **kwargs)
        self.vocab,self.pad_idx,self.sep = vocab,pad_idx,sep
        # Presumably carried over to lists derived from this one (see `ItemList.copy_new`).
        self.copy_new += ['vocab', 'pad_idx', 'sep']

    def get(self, i):
        "Return item `i`, wrapped in a `Text` (ids + textified tokens) once a `vocab` is available."
        o = super().get(i)
        return o if self.vocab is None else Text(o, self.vocab.textify(o, self.sep))

    def label_for_lm(self, **kwargs):
        "A special labelling method for language models."
        self.__class__ = LMTextList
        kwargs['label_cls'] = LMLabelList
        return self.label_const(0, **kwargs)

    def reconstruct(self, t:Tensor):
        "Rebuild a `Text` from `t`, dropping the leading and trailing `pad_idx` tokens."
        nz = (t != self.pad_idx).nonzero()
        # A tensor made only of padding would make `.min()` fail on an empty selection.
        if len(nz) == 0: return Text(t[:0], self.vocab.textify(t[:0], self.sep))
        idx_min, idx_max = nz.min(), nz.max()
        # Use `self.sep`, as `get` does, so reconstructed text matches the displayed items.
        return Text(t[idx_min:idx_max+1], self.vocab.textify(t[idx_min:idx_max+1], self.sep))

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=text_extensions, vocab:Vocab=None,
                    processor:PreProcessor=None, **kwargs)->'TextList':
        "Get the list of files in `path` that have a text suffix. `recurse` determines if we search subfolders."
        processor = ifnone(processor, [OpenFileProcessor(), TokenizeProcessor(), NumericalizeProcessor(vocab=vocab)])
        return super().from_folder(path=path, extensions=extensions, processor=processor, **kwargs)

    def show_xys(self, xs, ys, max_len:int=70)->None:
        "Show the `xs` (inputs) and `ys` (targets). `max_len` is the maximum number of tokens displayed."
        from IPython.display import display, HTML
        names = ['idx','text'] if self._is_lm else ['text','target']
        items = []
        for i, (x,y) in enumerate(zip(xs,ys)):
            txt_x = ' '.join(x.text.split(' ')[:max_len]) if max_len is not None else x.text
            items.append([i, txt_x] if self._is_lm else [txt_x, y])
        items = np.array(items)
        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
        # `None` disables column truncation; the old `-1` spelling is deprecated since pandas 1.0
        # and no longer accepted by recent releases.
        with pd.option_context('display.max_colwidth', None):
            display(HTML(df.to_html(index=False)))

    def show_xyzs(self, xs, ys, zs, max_len:int=70):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions). `max_len` is the maximum number of tokens displayed."
        from IPython.display import display, HTML
        items,names = [],['text','target','prediction']
        for i, (x,y,z) in enumerate(zip(xs,ys,zs)):
            txt_x = ' '.join(x.text.split(' ')[:max_len]) if max_len is not None else x.text
            items.append([txt_x, y, z])
        items = np.array(items)
        df = pd.DataFrame({n:items[:,i] for i,n in enumerate(names)}, columns=names)
        # `None` disables column truncation (see `show_xys`).
        with pd.option_context('display.max_colwidth', None):
            display(HTML(df.to_html(index=False)))
374
375
class LMLabelList(EmptyLabelList):
    "Placeholder label list for language models, carrying a flattened cross-entropy `loss_func`."
    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        # The loss is attached to the labels so the data can supply it.
        self.loss_func = CrossEntropyFlat()
380
381
class LMTextList(TextList):
    "`TextList` variant for language modelling: shows items as (idx, text) and builds a `TextLMDataBunch`."
    _is_lm = True
    _bunch = TextLMDataBunch
385
386
def _join_texts(texts:Collection[str], mark_fields:bool=False, include_bos:bool=True, include_eos:bool=False):
    """Join the column(s) of `texts` into one string per row and return them as an array.

    Adds a leading `BOS` token when `include_bos`, `FLD <n>` markers before each field when `mark_fields`,
    and a trailing `EOS` token when `include_eos`."""
    arr = texts if isinstance(texts, np.ndarray) else np.array(texts)
    # A single column of texts is turned into a 2D array with one column.
    if is1d(arr): arr = arr[:,None]
    df = pd.DataFrame({col:arr[:,col] for col in range(arr.shape[1])})
    bos_tok = f'{BOS} ' if include_bos else ''
    first_prefix = f'{bos_tok}{FLD} {1} ' if mark_fields else f'{bos_tok}'
    text_col = first_prefix + df[0].astype(str)
    for col in range(1, len(df.columns)):
        field_sep = f' {FLD} {col+1} ' if mark_fields else ' '
        text_col = text_col + (field_sep + df[col].astype(str))
    if include_eos:
        text_col = text_col + f' {EOS}'
    return text_col.values
396
397
def apply_rules(text, pre_rules=None, post_rules=None):
    """Return `text` after string-level `pre_rules`, whitespace splitting, and token-level `post_rules`.

    Rules default to `defaults.text_pre_rules` / `defaults.text_post_rules`; tokens are re-joined with single spaces."""
    result = text.strip(' ')
    for rule in ifnone(pre_rules, defaults.text_pre_rules):
        result = rule(result)
    tokens = result.split()
    for rule in ifnone(post_rules, defaults.text_post_rules):
        tokens = rule(tokens)
    return ' '.join(tokens)
404
405
def get_default_size(texts, max_vocab_sz):
    "Either max_vocab_sz or one quarter of the number of unique words in `texts`, rounded up to a multiple of 8"
    # Only the number of distinct whitespace-separated words matters, so a set is enough.
    n_unique = len({word for t in texts for word in t.split()})
    res = n_unique // 4
    # Round up to the next multiple of 8, then cap at `max_vocab_sz`: the previous loop could round a value
    # just below the maximum up past it.
    res = -(-res // 8) * 8
    return min(res, max_vocab_sz)
414
415
# Language codes (all European) for which `train_sentencepiece` defaults to a character coverage of 0.99999
# rather than 0.9998.
full_char_coverage_langs = [
    "bg", "cs", "da", "de", "el", "en", "es", "et", "fi", "fr", "ga", "hr", "hu",
    "it", "lt", "lv", "mt", "nl", "pl", "pt", "ro", "sk", "sl", "sv",
]
417
418
def train_sentencepiece(texts:Collection[str], path:PathOrStr, pre_rules: ListRules=None, post_rules:ListRules=None,
                        vocab_sz:int=None, max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',
                        char_coverage=None, tmp_dir='tmp'):
    """Train a sentencepiece tokenizer on `texts` and save it in `path/tmp_dir`

    Returns the cache directory, which holds the `spm.model` / `spm.vocab` files written under the `spm` prefix.
    `pre_rules`/`post_rules` are not used here; the texts are taken as given."""
    from sentencepiece import SentencePieceTrainer
    cache_dir = Path(path)/tmp_dir
    os.makedirs(cache_dir, exist_ok=True)
    if vocab_sz is None: vocab_sz=get_default_size(texts, max_vocab_sz)
    raw_text_path = cache_dir / 'all_text.out'
    # sentencepiece reads its input corpus as UTF-8, so don't rely on the platform's default encoding.
    with open(raw_text_path, 'w', encoding='utf8') as f: f.write("\n".join(texts))
    spec_tokens = ['\u2581'+s for s in defaults.text_spec_tok]
    try:
        SentencePieceTrainer.Train(" ".join([
            f"--input={raw_text_path} --max_sentence_length={max_sentence_len}",
            f"--character_coverage={ifnone(char_coverage, 0.99999 if lang in full_char_coverage_langs else 0.9998)}",
            f"--unk_id={len(defaults.text_spec_tok)} --pad_id=-1 --bos_id=-1 --eos_id=-1",
            f"--user_defined_symbols={','.join(spec_tokens)}",
            f"--model_prefix={cache_dir/'spm'} --vocab_size={vocab_sz} --model_type={model_type}"]))
    finally:
        # Remove the temporary corpus even when training fails.
        raw_text_path.unlink()
    return cache_dir
437
438
class SPProcessor(PreProcessor):
    "`PreProcessor` that tokenizes and numericalizes with `sentencepiece`"
    def __init__(self, ds:ItemList=None, pre_rules: ListRules=None, post_rules:ListRules=None, vocab_sz:int=None,
                 max_vocab_sz:int=30000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',
                 char_coverage=None, tmp_dir='tmp', mark_fields:bool=False, include_bos:bool=True,
                 include_eos:bool=False, sp_model=None, sp_vocab=None, n_cpus:int=None):
        # Fail early with an actionable message when the optional dependency is missing.
        try: from sentencepiece import SentencePieceTrainer,SentencePieceProcessor
        except ImportError:
            raise Exception('sentencepiece module is missing: run `pip install sentencepiece`')
        self.pre_rules,self.post_rules = pre_rules,post_rules
        self.mark_fields,self.include_bos,self.include_eos = mark_fields,include_bos,include_eos
        self.sp_model,self.sp_vocab,self.n_cpus = sp_model,sp_vocab,ifnone(n_cpus,defaults.cpus)
        # Training is deferred to `process`, and only runs when no model/vocab paths were given.
        self.train_func = partial(train_sentencepiece, pre_rules=pre_rules, post_rules=post_rules, vocab_sz=vocab_sz,
                                  max_vocab_sz=max_vocab_sz, model_type=model_type, max_sentence_len=max_sentence_len, lang=lang,
                                  char_coverage=char_coverage, tmp_dir=tmp_dir)

    def process_one(self, item, join=True):
        "Encode one `item` to ids; with `join=False`, `item` is taken to be an already-joined string."
        # Previously `text` was left unbound when `join` was False.
        text = _join_texts([item], self.mark_fields, self.include_bos, self.include_eos)[0] if join else item
        text = apply_rules(text, pre_rules=self.pre_rules, post_rules=self.post_rules)
        return self._encode_batch([text])[0]

    def process(self, ds):
        "Join, apply rules, train the sentencepiece model if needed, then encode every item of `ds` to ids."
        ds.items = _join_texts(ds.items, self.mark_fields, self.include_bos, self.include_eos)
        ds.items = [apply_rules(t, pre_rules=self.pre_rules, post_rules=self.post_rules)
                    for t in progress_bar(ds.items, leave=False)]
        if self.sp_model is None or self.sp_vocab is None:
            cache_dir = self.train_func(ds.items, ds.path)
            self.sp_model,self.sp_vocab = cache_dir/'spm.model',cache_dir/'spm.vocab'
        if not getattr(self, 'vocab', False):
            # The vocab file is tab-separated with the piece in the first field; sentencepiece writes it as UTF-8.
            with open(self.sp_vocab, 'r', encoding='utf8') as f:
                self.vocab = Vocab([line.split('\t')[0] for line in f.readlines()])
        if self.n_cpus <= 1: ds.items = self._encode_batch(ds.items)
        else:
            with ProcessPoolExecutor(self.n_cpus) as e:
                encoded = [ids for batch in e.map(self._encode_batch, partition_by_cores(ds.items, self.n_cpus))
                           for ids in batch]
            # Sequences have different lengths: fill a 1D object array element by element, since `np.array` on a
            # ragged list raises in NumPy >= 1.24 (and would build a 2D array if all lengths happened to match).
            items = np.empty(len(encoded), dtype=object)
            for i, ids in enumerate(encoded): items[i] = ids
            ds.items = items
        ds.vocab = self.vocab

    def _encode_batch(self, texts):
        "Encode `texts` to arrays of ids with the model stored at `self.sp_model`."
        # The model is loaded inside this method, presumably so each worker process gets its own copy.
        from sentencepiece import SentencePieceProcessor
        tok = SentencePieceProcessor()
        tok.Load(str(self.sp_model))
        return [np.array(tok.EncodeAsIds(t)) for t in texts]

    @classmethod
    def load(cls, path:PathOrStr, tmp_dir:PathOrStr='tmp', name:str='spm'):
        "Create an `SPProcessor` using the `{name}.model` / `{name}.vocab` files in `path/tmp_dir`."
        cache_dir = Path(path)/tmp_dir
        return cls(sp_model=cache_dir/f'{name}.model', sp_vocab=cache_dir/f'{name}.vocab')
484
485