GitHub Repository: lucidrains/vit-pytorch
Path: blob/main/vit_pytorch/extractor.py
import torch
from torch import nn

# helpers

def exists(val):
    return val is not None

def identity(t):
    return t

def clone_and_detach(t):
    return t.clone().detach()

def apply_tuple_or_single(fn, val):
    if isinstance(val, tuple):
        return tuple(map(fn, val))
    return fn(val)

# extractor class

class Extractor(nn.Module):
    def __init__(
        self,
        vit,
        device = None,
        layer = None,
        layer_name = 'transformer',
        layer_save_input = False,
        return_embeddings_only = False,
        detach = True
    ):
        super().__init__()
        self.vit = vit

        self.data = None
        self.latents = None
        self.hooks = []
        self.hook_registered = False
        self.ejected = False
        self.device = device

        self.layer = layer
        self.layer_name = layer_name
        self.layer_save_input = layer_save_input # whether to save input or output of layer
        self.return_embeddings_only = return_embeddings_only

        self.detach_fn = clone_and_detach if detach else identity

    def _hook(self, _, inputs, output):
        # forward hook: capture the hooked layer's input or output as the latents
        layer_output = inputs if self.layer_save_input else output
        self.latents = apply_tuple_or_single(self.detach_fn, layer_output)

    def _register_hook(self):
        # resolve the layer to hook, either passed in directly or looked up by attribute name
        if not exists(self.layer):
            assert hasattr(self.vit, self.layer_name), 'layer whose output to take as embedding not found in vision transformer'
            layer = getattr(self.vit, self.layer_name)
        else:
            layer = self.layer

        handle = layer.register_forward_hook(self._hook)
        self.hooks.append(handle)
        self.hook_registered = True

    def eject(self):
        # remove all registered hooks and hand back the wrapped vision transformer
        self.ejected = True
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        return self.vit

    def clear(self):
        del self.latents
        self.latents = None

    def forward(
        self,
        img,
        return_embeddings_only = False
    ):
        assert not self.ejected, 'extractor has been ejected, cannot be used anymore'
        self.clear()
        if not self.hook_registered:
            self._register_hook()

        pred = self.vit(img)

        # move captured latents to the requested device (defaults to the input's device)
        target_device = self.device if exists(self.device) else img.device
        latents = apply_tuple_or_single(lambda t: t.to(target_device), self.latents)

        if return_embeddings_only or self.return_embeddings_only:
            return latents

        return pred, latents
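# --- Usage sketch (illustrative, not part of extractor.py above) ---
# A minimal example assuming the ViT class shipped in the same package;
# the hyperparameters, input shape and variable names are arbitrary placeholders.

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

# wrapping registers a forward hook on vit.transformer (the default layer_name)
wrapped = Extractor(vit)

img = torch.randn(1, 3, 256, 256)

logits, embeddings = wrapped(img)                       # classifier output plus (batch, tokens, dim) latents
embeddings = wrapped(img, return_embeddings_only = True)

vit = wrapped.eject()                                   # removes the hooks and returns the underlying ViT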