Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/tabular/data.py
781 views
1
"Data loading pipeline for structured data support. Loads from pandas DataFrame"
2
from ..torch_core import *
3
from .transform import *
4
from ..basic_data import *
5
from ..data_block import *
6
from ..basic_train import *
7
from .models import *
8
from pandas.api.types import is_numeric_dtype, is_categorical_dtype
9
10
__all__ = ['TabularDataBunch', 'TabularLine', 'TabularList', 'TabularProcessor', 'tabular_learner']
11
12
OptTabTfms = Optional[Collection[TabularProc]]
13
14
#def emb_sz_rule(n_cat:int)->int: return min(50, (n_cat//2)+1)
15
def emb_sz_rule(n_cat:int)->int:
    "Rule-of-thumb embedding width for a categorical variable with `n_cat` levels, capped at 600."
    # Width grows sub-linearly with the cardinality: 1.6 * n_cat^0.56, rounded to the nearest integer.
    suggested = round(1.6 * n_cat**0.56)
    return min(600, suggested)
16
17
def def_emb_sz(classes, n, sz_dict=None):
    "Return `(number of classes, embedding size)` for column `n`; a size listed in `sz_dict` wins over the default rule."
    overrides = ifnone(sz_dict, {})
    n_cat = len(classes[n])
    # Rule-of-thumb size, only used when `overrides` has no entry for `n`.
    default_sz = int(emb_sz_rule(n_cat))
    return n_cat, overrides.get(n, default_sz)
23
24
class TabularLine(ItemBase):
    "A single tabular row: categorical codes `cats`, continuous values `conts`, plus the lookup tables to print them."
    def __init__(self, cats, conts, classes, names):
        self.cats = cats
        self.conts = conts
        self.classes = classes
        self.names = names
        # `data` carries both parts as tensors, categorical first.
        self.data = [tensor(cats), tensor(conts)]

    def __str__(self):
        # `names` lists the categorical columns first, then the continuous ones.
        n_cats = len(self.cats)
        parts = [f"{n} {(self.classes[n][c])}; " for c,n in zip(self.cats, self.names[:n_cats])]
        parts += [f'{n} {c:.4f}; ' for c,n in zip(self.conts, self.names[n_cats:])]
        return ''.join(parts)
37
38
class TabularProcessor(PreProcessor):
    "Regroup the `procs` in one `PreProcessor`."
    def __init__(self, ds:ItemBase=None, procs=None):
        # Explicit `procs` take precedence; otherwise reuse the ones carried by `ds`, when one is given.
        procs = ifnone(procs, ds.procs if ds is not None else None)
        self.procs = listify(procs)

    @staticmethod
    def _cat_codes(df, cat_names):
        "Stack the category codes of `cat_names` column-wise as int64, shifted by one so index 0 matches the leading '#na#' class."
        return np.stack([col.cat.codes.values for _,col in df[cat_names].items()], 1).astype(np.int64) + 1

    @staticmethod
    def _cont_values(df, cont_names):
        "Stack the `cont_names` columns column-wise as float32."
        return np.stack([col.astype('float32').values for _,col in df[cont_names].items()], 1)

    def process_one(self, item):
        "Run the procs in test mode on a single `item` and return it as a `TabularLine` (built with `classes=None`)."
        # The row is duplicated into a two-row frame; only the first row is kept in the end.
        # NOTE(review): presumably the duplication keeps the procs working on a real 2D frame -- confirm.
        df = pd.DataFrame([item,item])
        for proc in self.procs: proc(df, test=True)
        codes = self._cat_codes(df, self.cat_names) if len(self.cat_names) != 0 else [[]]
        conts = self._cont_values(df, self.cont_names) if len(self.cont_names) != 0 else [[]]
        col_names = list(df[self.cat_names].columns.values) + list(df[self.cont_names].columns.values)
        return TabularLine(codes[0], conts[0], None, col_names)

    def process(self, ds):
        "Apply the procs to `ds.inner_df`, then store codes, continuous values, classes and column names on `ds`."
        if ds.inner_df is None:
            # Nothing to transform: hand over the metadata recorded by an earlier call to `process`.
            ds.classes,ds.cat_names,ds.cont_names = self.classes,self.cat_names,self.cont_names
            ds.col_names = self.cat_names + self.cont_names
            ds.preprocessed = True
            return
        for i,proc in enumerate(self.procs):
            # Already-instantiated procs are replayed in test mode.
            if isinstance(proc, TabularProc): proc(ds.inner_df, test=True)
            else:
                #cat and cont names may have been changed by transform (like Fill_NA)
                proc = proc(ds.cat_names, ds.cont_names)
                proc(ds.inner_df)
                ds.cat_names,ds.cont_names = proc.cat_names,proc.cont_names
                # Keep the fitted instance so later datasets go through the test-mode branch above.
                self.procs[i] = proc
                self.cat_names,self.cont_names = ds.cat_names,ds.cont_names
        # With no proc class to instantiate (e.g. empty `procs`) the names were never recorded, which
        # made `process_one` and the `inner_df is None` branch fail with AttributeError.
        if not hasattr(self, 'cat_names'):
            self.cat_names,self.cont_names = ds.cat_names,ds.cont_names
        if len(ds.cat_names) != 0:
            ds.codes = self._cat_codes(ds.inner_df, ds.cat_names)
            # '#na#' is prepended so that class index 0 is the slot for the shifted "missing" code.
            self.classes = ds.classes = OrderedDict({n:np.concatenate([['#na#'],c.cat.categories.values])
                                                     for n,c in ds.inner_df[ds.cat_names].items()})
            cat_cols = list(ds.inner_df[ds.cat_names].columns.values)
        else: ds.codes,ds.classes,self.classes,cat_cols = None,None,None,[]
        if len(ds.cont_names) != 0:
            ds.conts = self._cont_values(ds.inner_df, ds.cont_names)
            cont_cols = list(ds.inner_df[ds.cont_names].columns.values)
        else: ds.conts,cont_cols = None,[]
        ds.col_names = cat_cols + cont_cols
        ds.preprocessed = True
84
85
class TabularDataBunch(DataBunch):
    "Create a `DataBunch` suitable for tabular data."
    @classmethod
    def from_df(cls, path, df:DataFrame, dep_var:str, valid_idx:Collection[int], procs:OptTabTfms=None,
                cat_names:OptStrList=None, cont_names:OptStrList=None, classes:Collection=None,
                test_df=None, bs:int=64, val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
                device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False)->DataBunch:
        """Create a `DataBunch` from `df` and `valid_idx` with `dep_var`.

        The rows listed in `valid_idx` form the validation set and `dep_var` is the label column
        (labelled with `classes` when given). `cat_names` defaults to no categorical column, and
        `cont_names` defaults to every column of `df` that is neither in `cat_names` nor `dep_var`.
        `test_df`, when given, is added as a test set using the training processor. `bs`, `val_bs`,
        `num_workers`, `dl_tfms`, `device`, `collate_fn` and `no_check` are forwarded to `databunch`.
        """
        # Work on a copy: the list is handed to the item lists and must not alias the caller's.
        cat_names = ifnone(cat_names, []).copy()
        if cont_names is None:
            # Follow the dataframe's own column order. The previous `list(set(...))` produced an order
            # that could change from one interpreter run to the next (string hash randomisation).
            skipped = set(cat_names) | {dep_var}
            cont_names = [c for c in dict.fromkeys(df.columns) if c not in skipped]
        procs = listify(procs)
        src = (TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)
                          .split_by_idx(valid_idx))
        src = src.label_from_df(cols=dep_var) if classes is None else src.label_from_df(cols=dep_var, classes=classes)
        if test_df is not None: src.add_test(TabularList.from_df(test_df, cat_names=cat_names, cont_names=cont_names,
                                                                 processor = src.train.x.processor))
        # `dl_tfms` used to be accepted by this method but silently dropped; pass it through.
        return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms,
                             device=device, collate_fn=collate_fn, no_check=no_check)
103
104
class TabularList(ItemList):
    "Basic `ItemList` for tabular data."
    _item_cls=TabularLine
    _processor=TabularProcessor
    _bunch=TabularDataBunch
    def __init__(self, items:Iterator, cat_names:OptStrList=None, cont_names:OptStrList=None,
                 procs=None, **kwargs)->'TabularList':
        super().__init__(range_of(items), **kwargs)
        #dataframe is in inner_df, items is just a range of index
        if cat_names is None: cat_names = []
        if cont_names is None: cont_names = []
        self.cat_names,self.cont_names,self.procs = cat_names,cont_names,procs
        # Register the extra attributes in `copy_new` alongside the inherited ones.
        self.copy_new += ['cat_names', 'cont_names', 'procs']
        self.preprocessed = False

    @classmethod
    def from_df(cls, df:DataFrame, cat_names:OptStrList=None, cont_names:OptStrList=None, procs=None, **kwargs)->'ItemList':
        "Build the list over a copy of `df` (passed as `inner_df`); `kwargs` are forwarded to the constructor."
        return cls(items=range(len(df)), cat_names=cat_names, cont_names=cont_names, procs=procs, inner_df=df.copy(), **kwargs)

    def get(self, o):
        "Return item `o`: the raw row (or index) before preprocessing, a `TabularLine` afterwards."
        if not self.preprocessed: return self.inner_df.iloc[o] if hasattr(self, 'inner_df') else self.items[o]
        codes = [] if self.codes is None else self.codes[o]
        conts = [] if self.conts is None else self.conts[o]
        return self._item_cls(codes, conts, self.classes, self.col_names)

    def get_emb_szs(self, sz_dict=None):
        "Return the default embedding sizes suitable for this data or takes the ones in `sz_dict`."
        return [def_emb_sz(self.classes, n, sz_dict) for n in self.cat_names]

    def reconstruct(self, t:Tensor):
        "Rebuild an item from `t`, where `t[0]` holds the categorical part and `t[1]` the continuous part."
        return self._item_cls(t[0], t[1], self.classes, self.col_names)

    @staticmethod
    def _show_rows(xs, extras, names):
        "Display one HTML table row per item of `xs`, followed by the matching tuple of `extras`, under `names`."
        from IPython.display import display, HTML
        rows = []
        for x,extra in zip(xs, extras):
            # A tensor whose `size()` is empty has nothing to iterate over.
            cats = x.cats if len(x.cats.size()) > 0 else []
            conts = x.conts if len(x.conts.size()) > 0 else []
            row = [str(x.classes[n][c]) for c,n in zip(cats, x.names[:len(cats)])]
            row += [f'{c:.4f}' for c in conts] + list(extra)
            rows.append(row)
        rows = np.array(rows)
        df = pd.DataFrame({n:rows[:,i] for i,n in enumerate(names)}, columns=names)
        # `None` means "no truncation" (pandas >= 1.0); the former `-1` is rejected by pandas 2.
        with pd.option_context('display.max_colwidth', None):
            display(HTML(df.to_html(index=False)))

    def show_xys(self, xs, ys)->None:
        "Show the `xs` (inputs) and `ys` (targets)."
        self._show_rows(xs, ((y,) for y in ys), xs[0].names + ['target'])

    def show_xyzs(self, xs, ys, zs):
        "Show `xs` (inputs), `ys` (targets) and `zs` (predictions)."
        self._show_rows(xs, zip(ys, zs), xs[0].names + ['target', 'prediction'])
170
171
def tabular_learner(data:DataBunch, layers:Collection[int], emb_szs:Dict[str,int]=None, metrics=None,
        ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, **learn_kwargs):
    "Build a `TabularModel` sized for `data` from the model params and wrap it in a `Learner` with `metrics`."
    # User-supplied sizes override the defaults computed by `data.get_emb_szs`.
    sz_overrides = ifnone(emb_szs, {})
    emb_szs = data.get_emb_szs(sz_overrides)
    n_cont = len(data.cont_names)
    model_kwargs = dict(out_sz=data.c, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range, use_bn=use_bn)
    model = TabularModel(emb_szs, n_cont, **model_kwargs)
    return Learner(data, model, metrics=metrics, **learn_kwargs)
178
179
180