Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/collab.py
781 views
1
"Module support for Collaborative Filtering"
2
from .tabular import *
3
from . import tabular
4
5
# Public API: everything re-exported from `tabular`, plus the collaborative-filtering names defined here.
__all__ = [
    *tabular.__all__,
    'EmbeddingDotBias', 'EmbeddingNN', 'collab_learner',
    'CollabDataBunch', 'CollabLine', 'CollabList', 'CollabLearner',
]
7
8
class CollabProcessor(TabularProcessor):
    "Subclass `TabularProcessor` for `process_one`."
    def process_one(self, item):
        "Process `item` with the parent `process_one`, then rewrap the result as a `CollabLine`."
        line = super().process_one(item)
        cats, conts, classes, names = line.cats, line.conts, line.classes, line.names
        return CollabLine(cats, conts, classes, names)
13
14
class CollabLine(TabularLine):
    """Base item for collaborative filtering, subclasses `TabularLine`.

    After the parent constructor runs, `data` is replaced by the first two entries of the parent's
    `data[0]` (presumably the user and item codes, matching the two `cat_names` used in
    `CollabDataBunch.from_df` -- confirm against `TabularLine`).
    """
    def __init__(self, cats, conts, classes, names):
        super().__init__(cats, conts, classes, names)
        codes = self.data[0]
        first, second = codes[0], codes[1]
        self.data = [first, second]
19
20
class CollabList(TabularList):
    "Base `ItemList` for collaborative filtering, subclasses `TabularList`."
    # Items are `CollabLine`s, labels are floats, and preprocessing goes through `CollabProcessor`.
    _item_cls = CollabLine
    _label_cls = FloatList
    _processor = CollabProcessor

    def reconstruct(self, t:Tensor):
        "Rebuild a `CollabLine` from the categorical tensor `t`, with an empty continuous part."
        cats, conts = tensor(t), tensor([])
        return CollabLine(cats, conts, self.classes, self.col_names)
25
26
class EmbeddingNN(TabularModel):
    "Subclass `TabularModel` to create a NN suitable for collaborative filtering."
    def __init__(self, emb_szs:ListSizes, layers:Collection[int]=None, ps:Collection[float]=None,
                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
        # Only embedded (categorical) inputs and a single output: n_cont=0, out_sz=1.
        super().__init__(emb_szs=emb_szs, layers=layers, ps=ps, emb_drop=emb_drop,
                         n_cont=0, out_sz=1,
                         y_range=y_range, use_bn=use_bn, bn_final=bn_final)

    def forward(self, users:LongTensor, items:LongTensor) -> Tensor:
        "Stack `users` and `items` column-wise as the categorical input; there is no continuous input."
        x_cat = torch.stack([users, items], dim=1)
        return super().forward(x_cat, None)
35
36
class EmbeddingDotBias(Module):
    "Base dot model for collaborative filtering."
    def __init__(self, n_factors:int, n_users:int, n_items:int, y_range:Tuple[float,float]=None):
        self.y_range = y_range
        # Creation order (user factors, item factors, user bias, item bias) is kept stable
        # because each `embedding` call may draw from the RNG for initialisation.
        self.u_weight = embedding(n_users, n_factors)
        self.i_weight = embedding(n_items, n_factors)
        self.u_bias = embedding(n_users, 1)
        self.i_bias = embedding(n_items, 1)

    def forward(self, users:LongTensor, items:LongTensor) -> Tensor:
        """Score each (user, item) pair.

        The score is the row-wise dot product of the user and item factor vectors plus the user
        and item biases. When `y_range` is set, the score is mapped into it via
        `sigmoid(score) * (high - low) + low`.
        """
        user_vecs = self.u_weight(users)
        item_vecs = self.i_weight(items)
        res = (user_vecs * item_vecs).sum(1) + self.u_bias(users).squeeze() + self.i_bias(items).squeeze()
        if self.y_range is not None:
            low, high = self.y_range[0], self.y_range[1]
            return torch.sigmoid(res) * (high - low) + low
        return res
49
50
class CollabDataBunch(DataBunch):
    "Base `DataBunch` for collaborative filtering."
    @classmethod
    def from_df(cls, ratings:DataFrame, valid_pct:float=0.2, user_name:Optional[str]=None, item_name:Optional[str]=None,
                rating_name:Optional[str]=None, test:DataFrame=None, seed:int=None, path:PathOrStr='.', bs:int=64,
                val_bs:int=None, num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None,
                device:torch.device=None, collate_fn:Callable=data_collate, no_check:bool=False) -> 'CollabDataBunch':
        """Create a `DataBunch` suitable for collaborative filtering from `ratings`.

        `user_name`, `item_name` and `rating_name` default to the first, second and third columns of
        `ratings`. The user and item columns are the categorical inputs (processed with `Categorify`),
        and the rating column is the label. A random `valid_pct` fraction of rows, seeded by `seed`,
        forms the validation set. If `test` is given, it is added as a test set built from the same
        user/item columns. `path`, `bs`, `val_bs`, `num_workers`, `dl_tfms`, `device`, `collate_fn`
        and `no_check` are forwarded to `databunch`.
        """
        user_name   = ifnone(user_name,   ratings.columns[0])
        item_name   = ifnone(item_name,   ratings.columns[1])
        rating_name = ifnone(rating_name, ratings.columns[2])
        # Order matters: downstream code unpacks the categorical classes as (user, item).
        cat_names = [user_name, item_name]
        src = (CollabList.from_df(ratings, cat_names=cat_names, procs=Categorify)
               .split_by_rand_pct(valid_pct=valid_pct, seed=seed)
               .label_from_df(cols=rating_name))
        if test is not None: src.add_test(CollabList.from_df(test, cat_names=cat_names))
        # `dl_tfms` is part of this method's signature, so forward it with the other DataLoader options
        # instead of silently dropping it.
        return src.databunch(path=path, bs=bs, val_bs=val_bs, num_workers=num_workers, dl_tfms=dl_tfms,
                             device=device, collate_fn=collate_fn, no_check=no_check)
67
68
class CollabLearner(Learner):
    "`Learner` suitable for collaborative filtering."
    def get_idx(self, arr:Collection, is_item:bool=True):
        """Fetch item or user (based on `is_item`) indices for all in `arr`. (Set model to `cpu` and no grad.)

        Each element of `arr` is looked up among the training set's item (or user) classes. The result is
        a 1-d int64 tensor of their positions, which is empty when `arr` is empty.

        Raises `KeyError` if an element of `arr` is not among the training classes.
        """
        m = self.model.eval().cpu()
        requires_grad(m, False)
        # Relies on `classes` holding the user column first and the item column second
        # (the `cat_names` order used by `CollabDataBunch.from_df`).
        u_class, i_class = self.data.train_ds.x.classes.values()
        classes = i_class if is_item else u_class
        c2i = {v:k for k,v in enumerate(classes)}
        try: idxs = [c2i[o] for o in arr]
        except KeyError as err:
            # Raise instead of printing and returning None: a None index only fails later, inside
            # the embedding layer, with an unrelated-looking error.
            raise KeyError(f"""You're trying to access {'an item' if is_item else 'a user'} that isn't in the training data: {err.args[0]!r}.
If it was in your original data, it may have been split such that it's only in the validation set now.""") from err
        # Explicit int64 dtype so that an empty `arr` still gives a valid (empty) embedding index.
        return torch.tensor(idxs, dtype=torch.long)

    def bias(self, arr:Collection, is_item:bool=True):
        """Bias for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)

        Requires `self.model` to have `i_bias`/`u_bias` layers (as `EmbeddingDotBias` does).
        """
        idx = self.get_idx(arr, is_item)
        m = self.model
        layer = m.i_bias if is_item else m.u_bias
        return layer(idx).squeeze()

    def weight(self, arr:Collection, is_item:bool=True):
        """Weight (embedding vectors) for item or user (based on `is_item`) for all in `arr`. (Set model to `cpu` and no grad.)

        Requires `self.model` to have `i_weight`/`u_weight` layers (as `EmbeddingDotBias` does).
        """
        idx = self.get_idx(arr, is_item)
        m = self.model
        layer = m.i_weight if is_item else m.u_weight
        return layer(idx)
95
96
def collab_learner(data, n_factors:int=None, use_nn:bool=False, emb_szs:Dict[str,int]=None, layers:Collection[int]=None,
                   ps:Collection[float]=None, emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True,
                   bn_final:bool=False, **learn_kwargs)->Learner:
    """Create a Learner for collaborative filtering on `data`.

    With `use_nn=True` the model is an `EmbeddingNN` built from `emb_szs`, `layers`, `ps`, `emb_drop`,
    `use_bn` and `bn_final`. Otherwise it is an `EmbeddingDotBias` with `n_factors` latent factors,
    sized from the training user/item classes. `y_range` is passed to either model. `learn_kwargs`
    are passed only to the `CollabLearner` constructor.
    """
    emb_szs = data.get_emb_szs(ifnone(emb_szs, {}))
    users, items = data.train_ds.x.classes.values()
    if use_nn:
        # `learn_kwargs` are Learner options; `EmbeddingNN.__init__` does not accept them, so forwarding
        # them here would raise TypeError for any learner keyword.
        model = EmbeddingNN(emb_szs=emb_szs, layers=layers, ps=ps, emb_drop=emb_drop, y_range=y_range,
                            use_bn=use_bn, bn_final=bn_final)
    else:
        # NOTE(review): `n_factors` defaults to None but is presumably required for this model.
        model = EmbeddingDotBias(n_factors, len(users), len(items), y_range=y_range)
    return CollabLearner(data, model, **learn_kwargs)
106
107
108