Book a Demo!
CoCalc Logo Icon
Store Features Docs Share Support News About Policies Sign Up Sign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/fastai/torch_core.py
781 views
1
"Utility functions to help deal with tensors"
2
from .imports.torch import *
3
from .core import *
4
from collections import OrderedDict
5
from torch.nn.parallel import DistributedDataParallel
6
7
# Type aliases used in fastai annotations; `fastai_types` maps them back to printable names.
AffineMatrix = Tensor
BoolOrTensor = Union[bool,Tensor]
FloatOrTensor = Union[float,Tensor]
IntOrTensor = Union[int,Tensor]
ItemsList = Collection[Union[Tensor,ItemBase,'ItemsList',float,int]]
LambdaFunc = Callable[[Tensor],Tensor]
LayerFunc = Callable[[nn.Module],None]
ModuleList = Collection[nn.Module]
NPArray = np.ndarray
OptOptimizer = Optional[optim.Optimizer]
ParamList = Collection[nn.Parameter]
# Presumably a single-element (rank-0) tensor, as returned by `LossFunction` below.
# The `NewType` name must match the name it is bound to (it was 'OneEltTensor'),
# otherwise reprs/errors show a name that does not exist in this module.
Rank0Tensor = NewType('Rank0Tensor', Tensor)
SplitFunc = Callable[[nn.Module], List[nn.Module]]
SplitFuncOrIdxList = Union[Callable, Collection[ModuleList]]
TensorOrNumber = Union[Tensor,Number]
TensorOrNumList = Collection[TensorOrNumber]
TensorImage = Tensor
TensorImageSize = Tuple[int,int,int]
Tensors = Union[Tensor, Collection['Tensors']]
Weights = Dict[str,Tensor]

# Aliases built on top of the ones above.
AffineFunc = Callable[[KWArgs], AffineMatrix]
HookFunc = Callable[[nn.Module, Tensors, Tensors], Any]
LogitTensorImage = TensorImage
LossFunction = Callable[[Tensor, Tensor], Rank0Tensor]
MetricFunc = Callable[[Tensor,Tensor],TensorOrNumber]
MetricFuncList = Collection[MetricFunc]
MetricsList = Collection[TensorOrNumber]
OptLossFunc = Optional[LossFunction]
OptMetrics = Optional[MetricsList]
OptSplitFunc = Optional[SplitFunc]
PixelFunc = Callable[[TensorImage, ArgStar, KWArgs], TensorImage]

LightingFunc = Callable[[LogitTensorImage, ArgStar, KWArgs], LogitTensorImage]
41
42
# Lookup from each type alias object to its printable name.
# NOTE: several aliases are the very same object (e.g. `AffineMatrix`, `TensorImage` and `LogitTensorImage`
# are all `Tensor`), or compare equal, so they collapse onto one key and the name listed last wins.
fastai_types = dict((
    (AnnealFunc, 'AnnealFunc'), (ArgStar, 'ArgStar'), (BatchSamples, 'BatchSamples'),
    (FilePathList, 'FilePathList'), (Floats, 'Floats'), (ImgLabel, 'ImgLabel'), (ImgLabels, 'ImgLabels'),
    (KeyFunc, 'KeyFunc'),
    (KWArgs, 'KWArgs'), (ListOrItem, 'ListOrItem'), (ListRules, 'ListRules'), (ListSizes, 'ListSizes'),
    (NPArrayableList, 'NPArrayableList'), (NPArrayList, 'NPArrayList'), (NPArrayMask, 'NPArrayMask'),
    (NPImage, 'NPImage'),
    (OptDataFrame, 'OptDataFrame'), (OptListOrItem, 'OptListOrItem'), (OptRange, 'OptRange'),
    (OptStrTuple, 'OptStrTuple'),
    (OptStats, 'OptStats'), (PathOrStr, 'PathOrStr'), (PBar, 'PBar'), (Point, 'Point'), (Points, 'Points'),
    (Sizes, 'Sizes'),
    (SplitArrayList, 'SplitArrayList'), (StartOptEnd, 'StartOptEnd'), (StrList, 'StrList'), (Tokens, 'Tokens'),
    (OptStrList, 'OptStrList'), (AffineMatrix, 'AffineMatrix'), (BoolOrTensor, 'BoolOrTensor'),
    (FloatOrTensor, 'FloatOrTensor'),
    (IntOrTensor, 'IntOrTensor'), (ItemsList, 'ItemsList'), (LambdaFunc, 'LambdaFunc'),
    (LayerFunc, 'LayerFunc'), (ModuleList, 'ModuleList'), (OptOptimizer, 'OptOptimizer'), (ParamList, 'ParamList'),
    (Rank0Tensor, 'Rank0Tensor'), (SplitFunc, 'SplitFunc'), (SplitFuncOrIdxList, 'SplitFuncOrIdxList'),
    (TensorOrNumber, 'TensorOrNumber'), (TensorOrNumList, 'TensorOrNumList'), (TensorImage, 'TensorImage'),
    (TensorImageSize, 'TensorImageSize'), (Tensors, 'Tensors'), (Weights, 'Weights'), (AffineFunc, 'AffineFunc'),
    (HookFunc, 'HookFunc'), (LogitTensorImage, 'LogitTensorImage'), (LossFunction, 'LossFunction'),
    (MetricFunc, 'MetricFunc'),
    (MetricFuncList, 'MetricFuncList'), (MetricsList, 'MetricsList'), (OptLossFunc, 'OptLossFunc'),
    (OptMetrics, 'OptMetrics'),
    (OptSplitFunc, 'OptSplitFunc'), (PixelFunc, 'PixelFunc'), (LightingFunc, 'LightingFunc'),
    (IntsOrStrs, 'IntsOrStrs'),
    (PathLikeOrBinaryStream, 'PathLikeOrBinaryStream'),
))
61
62
# BatchNorm layers of every dimensionality.
bn_types = (nn.BatchNorm1d,
            nn.BatchNorm2d,
            nn.BatchNorm3d)
# Layers that own a `bias` parameter: linear, then conv and transposed conv for 1d/2d/3d.
bias_types = ((nn.Linear,)
              + (nn.Conv1d, nn.Conv2d, nn.Conv3d)
              + (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d))
64
def is_pool_type(l:Callable):
    "Return a (truthy) regex match if the class name of `l` ends in `Pool1d`/`Pool2d`/`Pool3d`, else None."
    cls_name = l.__class__.__name__
    return re.search(r'Pool[123]d$', cls_name)
65
# Layers whose parameters go into the no-weight-decay group (see `split_no_wd_params`).
no_wd_types = (*bn_types, nn.LayerNorm)
# Default device: the GPU when CUDA is available, otherwise the CPU.
defaults.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Plain `optim.Adam` with betas (0.9, 0.99). NOTE(review): despite the name, decoupled weight decay is
# presumably applied by the optimizer wrapper elsewhere — confirm.
AdamW = partial(optim.Adam, betas=(0.9,0.99))
68
69
#Monkey-patch `torch.cuda.set_device` so that it updates `defaults.device`
_old_torch_cuda_set_device = torch.cuda.set_device
def _new_torch_cuda_set_device(device):
    "Call the original `torch.cuda.set_device`, then mirror the selection into `defaults.device`."
    _old_torch_cuda_set_device(device)
    # Normalise the accepted argument forms to a `torch.device`: an int index, a string such as 'cuda:1'
    # (previously stored as a raw str), or an existing torch.device.
    if isinstance(device, int): defaults.device = torch.device('cuda', device)
    elif isinstance(device, str): defaults.device = torch.device(device)
    else: defaults.device = device
torch.cuda.set_device = _new_torch_cuda_set_device
75
76
def tensor(x:Any, *rest)->Tensor:
    "Build a tensor like `torch.as_tensor`, but also accept lists and several elements passed positionally."
    if rest: x = (x, *rest)
    # XXX: Pytorch bug in dataloader using num_workers>0; TODO: create repro and report
    if is_listy(x) and len(x) == 0: return tensor(0)
    if is_listy(x): res = torch.tensor(x)
    else: res = as_tensor(x)
    if res.dtype is not torch.int32: return res
    # int32 inputs are widened to int64
    warn('Tensor is int32: upgrading to int64; for better performance use int64 input')
    return res.long()
86
87
class Module(nn.Module, metaclass=PrePostInitMeta):
    "Same as `nn.Module`, but no need for subclasses to call `super().__init__`"
    def __pre_init__(self):
        # Presumably invoked by `PrePostInitMeta` before the subclass `__init__` (TODO confirm in core),
        # so the `nn.Module` machinery is set up even when subclasses skip `super().__init__()`.
        super().__init__()

    def __init__(self):
        # Nothing to do: initialisation of `nn.Module` happens in `__pre_init__`.
        pass
91
92
def np_address(x:np.ndarray)->int:
    "Return the memory address of the data buffer of array `x`."
    # `__array_interface__['data']` is a (pointer, read_only) pair.
    data_ptr, _read_only = x.__array_interface__['data']
    return data_ptr
95
96
def to_detach(b:Tensors, cpu:bool=True):
    "Recursively detach the tensors found in `b`, also moving them to the CPU when `cpu=True`."
    def _detach_one(item, cpu=True):
        # non-tensor leaves pass through untouched
        if isinstance(item, Tensor):
            item = item.detach()
            return item.cpu() if cpu else item
        return item
    return recurse(_detach_one, b, cpu=cpu)
103
104
def to_data(b:ItemsList):
    "Recursively replace every `ItemBase` in `b` with its wrapped `.data`; other leaves are kept."
    def _unwrap(item): return item.data if isinstance(item, ItemBase) else item
    return recurse(_unwrap, b)
107
108
def to_cpu(b:ItemsList):
    "Recursively move every tensor found in `b` to the CPU; other leaves are returned as-is."
    def _move(item): return item.cpu() if isinstance(item, Tensor) else item
    return recurse(_move, b)
111
112
def to_half(b:Collection[Tensor])->Collection[Tensor]:
    "Recursively map the floating-point tensors in `b` to FP16; all other tensors keep their dtype."
    # Only floating tensors are converted. The old deny-list (int64/int32/int16) also turned bool/uint8
    # masks and int8 tensors into halves, breaking later indexing with them.
    return recurse(lambda x: x.half() if torch.is_floating_point(x) else x, b)
115
116
def to_float(b:Collection[Tensor])->Collection[Tensor]:
    "Recursively map the floating-point tensors in `b` to FP32; all other tensors keep their dtype."
    # Same dtype filter as `to_half`: bool/uint8 masks and integer tensors must not become floats.
    return recurse(lambda x: x.float() if torch.is_floating_point(x) else x, b)
119
120
def to_device(b:Tensors, device:Optional[torch.device]=None):
    "Recursively put `b` on `device`; when `device` is None, `defaults.device` is used."
    # `device` now defaults to None so the `defaults.device` fallback is reachable without passing None explicitly.
    device = ifnone(device, defaults.device)
    return recurse(lambda x: x.to(device, non_blocking=True), b)
124
125
def data_collate(batch:ItemsList)->Tensor:
    "Unwrap the items of `batch` with `to_data`, then collate them with PyTorch's `default_collate`."
    raw_items = to_data(batch)
    return torch.utils.data.dataloader.default_collate(raw_items)
128
129
def requires_grad(m:nn.Module, b:Optional[bool]=None)->Optional[bool]:
    """Query or set `requires_grad` on the params of `m`.

    With `b=None`, return the flag of the first param (None if `m` has no params).
    Otherwise set the flag on every param to `b` and return None.
    """
    params = list(m.parameters())
    if not params: return None
    if b is not None:
        for param in params: param.requires_grad = b
        return None
    return params[0].requires_grad
135
136
def trainable_params(m:nn.Module)->ParamList:
    "Return the list of params of `m` (recursively, via `m.parameters()`) that have `requires_grad` set."
    # A list rather than a lazy `filter` object: it matches the `ParamList` annotation and can be
    # iterated more than once.
    return [p for p in m.parameters() if p.requires_grad]
140
141
def children(m:nn.Module)->ModuleList:
    "Return the direct child modules of `m` as a list."
    return [*m.children()]
144
145
def num_children(m:nn.Module)->int:
    "Count the direct child modules of `m`."
    return sum(1 for _ in m.children())
148
149
def range_children(m:nn.Module)->Iterator[int]:
    "Return a `range` over the indices of the direct children of `m`."
    n_kids = len(list(m.children()))
    return range(n_kids)
152
153
class ParameterModule(Module):
    "Wrap a lone parameter `p` in a module whose forward is the identity."
    def __init__(self, p:nn.Parameter):
        # assigning an `nn.Parameter` to a module attribute registers it as a parameter of this module
        self.val = p

    def forward(self, x):
        # identity: the module only exists to hold `val`
        return x
157
158
def children_and_parameters(m:nn.Module):
    "Return the children of `m`, followed by a `ParameterModule` for each direct parameter not owned by a child."
    kids = list(m.children())
    # ids of every parameter owned (recursively) by a child. A set gives O(1) membership tests instead
    # of scanning a list built with quadratic `sum(..., [])` for each parameter of `m`.
    child_param_ids = {id(p) for c in kids for p in c.parameters()}
    for p in m.parameters():
        if id(p) not in child_param_ids: kids.append(ParameterModule(p))
    return kids
165
166
def flatten_model(m:nn.Module):
    "Return the leaf layers of `m` in order; direct params of a module appear as `ParameterModule` leaves."
    if not num_children(m): return [m]
    leaves = []
    for sub in children_and_parameters(m): leaves.extend(flatten_model(sub))
    return leaves
172
173
#flatten_model = lambda m: sum(map(flatten_model,children_and_parameters(m)),[]) if num_children(m) else [m]
174
175
def first_layer(m:nn.Module)->nn.Module:
    "Return the first leaf layer of `m`, in the order given by `flatten_model`."
    layers = flatten_model(m)
    return layers[0]
178
179
def last_layer(m:nn.Module)->nn.Module:
    "Return the last leaf layer of `m`, in the order given by `flatten_model`."
    layers = flatten_model(m)
    return layers[len(layers) - 1]
182
183
def split_model_idx(model:nn.Module, idxs:Collection[int])->ModuleList:
    """Split the flattened layers of `model` into `nn.Sequential` groups, cutting at each index in `idxs`.

    0 and `len(layers)` are added as boundaries when missing.
    """
    layers = flatten_model(model)
    # Work on a copy. Previously `append` mutated the caller's list whenever it already started with 0,
    # and empty input or tuples crashed.
    idxs = list(idxs)
    if not idxs or idxs[0] != 0: idxs.insert(0, 0)
    if idxs[-1] != len(layers): idxs.append(len(layers))
    return [nn.Sequential(*layers[i:j]) for i,j in zip(idxs[:-1],idxs[1:])]
189
190
def split_model(model:nn.Module=None, splits:Collection[Union[nn.Module,ModuleList]]=None):
    """Split `model` into layer groups.

    If `splits` holds modules, cut the flattened `model` at the first layer of each one.
    Otherwise each element of `splits` is a list of modules, wrapped into its own `nn.Sequential`.
    """
    splits = listify(splits)
    if not isinstance(splits[0], nn.Module):
        return [nn.Sequential(*group) for group in splits]
    flat = flatten_model(model)
    cut_points = [flat.index(first_layer(s)) for s in splits]
    return split_model_idx(model, cut_points)
198
199
def get_param_groups(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:
    "Return, for each layer group, the trainable params of its direct children concatenated in order."
    return [[p for child in group.children() for p in trainable_params(child)] for group in layer_groups]
201
202
def split_no_wd_params(layer_groups:Collection[nn.Module])->List[List[nn.Parameter]]:
    """Separate the parameters in `layer_groups` between `no_wd_types` and bias (`bias_types`) from the rest.

    Returns a flat list `[rest_0, no_wd_0, rest_1, no_wd_1, ...]`, two lists per layer group.
    """
    split_params = []
    for l in layer_groups:
        l1,l2 = [],[]
        for c in l.children():
            if isinstance(c, no_wd_types): l2 += list(trainable_params(c))
            elif isinstance(c, bias_types):
                bias = c.bias if hasattr(c, 'bias') else None
                l1 += [p for p in trainable_params(c) if not (p is bias)]
                # Keep the bias only when it is trainable, like every other param collected here; a frozen
                # bias used to be put in the no-wd group regardless.
                if bias is not None and bias.requires_grad: l2.append(bias)
            else: l1 += list(trainable_params(c))
        #Since we scan the children separately, we might get duplicates (tied weights). We need to preserve the order
        #for the optimizer load of state_dict
        l1,l2 = uniqueify(l1),uniqueify(l2)
        split_params += [l1, l2]
    return split_params
219
220
def set_bn_eval(m:nn.Module)->None:
    "Recursively put every BatchNorm layer under `m` whose first parameter is frozen into eval mode."
    for l in m.children():
        if isinstance(l, bn_types):
            # BatchNorm built with `affine=False` has no parameters: `next(l.parameters())` raised
            # StopIteration. Such a layer has nothing frozen, so its mode is left unchanged.
            first = next(l.parameters(), None)
            if first is not None and not first.requires_grad: l.eval()
        set_bn_eval(l)
226
227
def batch_to_half(b:Collection[Tensor])->Collection[Tensor]:
    "Return `[to_half(b[0]), b[1]]`: the input part of batch `b` in half precision, the second element unchanged."
    xb, yb = b[0], b[1]
    return [to_half(xb), yb]
230
231
def bn2float(module:nn.Module)->nn.Module:
    "Cast every batchnorm layer in `module` (including `module` itself) back to float32, in place; return `module`."
    for sub in module.modules():
        if isinstance(sub, torch.nn.modules.batchnorm._BatchNorm): sub.float()
    return module
236
237
def model2half(model:nn.Module)->nn.Module:
    "Cast `model` to half precision, then restore its batchnorm layers to float32."
    halved = model.half()
    return bn2float(halved)
240
241
def init_default(m:nn.Module, func:LayerFunc=nn.init.kaiming_normal_)->nn.Module:
    "Initialize `m` weights with `func` and set `bias` to 0 (nothing is done when `func` is falsy); return `m`."
    if func:
        # `weight` may exist but be None (e.g. `nn.LayerNorm(..., elementwise_affine=False)`); `func(None)` would crash.
        if getattr(m, 'weight', None) is not None: func(m.weight)
        # `hasattr(m.bias, 'data')` also skips `bias=None` layers
        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
    return m
247
248
def cond_init(m:nn.Module, init_func:LayerFunc):
    "Call `init_default(m, init_func)` unless `m` is a batchnorm layer or its first param is frozen/absent."
    if isinstance(m, bn_types): return
    if requires_grad(m): init_default(m, init_func)
251
252
def apply_leaf(m:nn.Module, f:LayerFunc):
    "Call `f` on `m` (when it is an `nn.Module`), then recursively on each of its children."
    # the children are collected before `f` runs
    pending = children(m)
    if isinstance(m, nn.Module):
        f(m)
    for sub in pending:
        apply_leaf(sub, f)
257
258
def apply_init(m, init_func:LayerFunc):
    "Initialize all non-batchnorm layers of `m`, recursively, with `init_func` (via `cond_init`)."
    def _init_one(layer): cond_init(layer, init_func=init_func)
    apply_leaf(m, _init_one)
261
262
def in_channels(m:nn.Module) -> int:
    """Return `weight.shape[1]` (the input channels) of the first layer of `m` that has a weight.

    Raises `Exception('No weight layer')` when no such layer exists.
    """
    for l in flatten_model(m):
        # skip layers whose `weight` attribute exists but is None (e.g. norm layers without affine params)
        if getattr(l, 'weight', None) is not None: return l.weight.shape[1]
    raise Exception('No weight layer')
267
268
class ModelOnCPU():
    "Context manager that moves `model` to the CPU on entry and back to its original device on exit."
    def __init__(self, model:nn.Module):
        self.model = model

    def __enter__(self):
        # remember the device of the first parameter so it can be restored in `__exit__`
        first = one_param(self.model)
        self.device = first.device
        cpu_model = self.model.cpu()
        return cpu_model

    def __exit__(self, type, value, traceback):
        # returns None, so exceptions raised inside the block propagate
        self.model = self.model.to(self.device)
276
277
class NoneReduceOnCPU():
    """Context manager that yields `loss_func` computing per-element losses, with its `weight` (if any) on the CPU.

    If `loss_func` has a `reduction` attribute it is switched to 'none' and restored on exit; otherwise
    a `partial(loss_func, reduction='none')` is returned.
    """
    def __init__(self, loss_func:LossFunction):
        self.loss_func,self.device,self.old_red = loss_func,None,None

    def __enter__(self):
        weight = getattr(self.loss_func, 'weight', None)
        if weight is not None:
            self.device = weight.device
            self.loss_func.weight = weight.cpu()
        if not hasattr(self.loss_func, 'reduction'):
            return partial(self.loss_func, reduction='none')
        self.old_red = self.loss_func.reduction
        self.loss_func.reduction = 'none'
        return self.loss_func

    def __exit__(self, type, value, traceback):
        if self.device is not None:
            self.loss_func.weight = self.loss_func.weight.to(self.device)
        if self.old_red is not None:
            self.loss_func.reduction = self.old_red
295
296
def model_type(dtype):
    "Map numpy `dtype` to a torch dtype: floating -> `torch.float32`, integer -> `torch.int64`, anything else -> None."
    if np.issubdtype(dtype, np.floating): return torch.float32
    if np.issubdtype(dtype, np.integer): return torch.int64
    return None
301
302
def np2model_tensor(a):
    "Transform numpy array `a` to a tensor, cast to the dtype `model_type` picks for it (if any)."
    target = model_type(a.dtype)
    res = as_tensor(a)
    return res.type(target) if target else res
308
309
def _pca(x, k=2):
310
"Compute PCA of `x` with `k` dimensions."
311
x = x-torch.mean(x,0)
312
U,S,V = torch.svd(x.t())
313
return torch.mm(x,U[:,:k])
314
torch.Tensor.pca = _pca
315
316
def trange_of(x):
    "Return `torch.arange(len(x))`, a tensor version of `range_of(x)`."
    n = len(x)
    return torch.arange(n)
319
320
def to_np(x):
    "Return tensor `x` as a numpy array: `.data` drops the autograd link, then the data is moved to the CPU."
    on_cpu = x.data.cpu()
    return on_cpu.numpy()
323
324
# monkey patching to allow matplotlib to plot tensors
def tensor__array__(self, dtype=None, copy=None):
    """Numpy `__array__` protocol for tensors: convert via `to_np`, optionally casting to `dtype`.

    `copy` is accepted because NumPy >= 2 passes it; `copy=True` returns a fresh array. `copy=False`/None
    copy only when needed. NOTE(review): strict NumPy 2 semantics would raise for `copy=False` when a copy
    is unavoidable — confirm whether that matters here.
    """
    res = to_np(self)
    if dtype is not None: return res.astype(dtype, copy=bool(copy))
    return res.copy() if copy else res
Tensor.__array__ = tensor__array__
# Recent PyTorch defines `Tensor.ndim` natively; only install the fallback property when it is missing.
if not hasattr(Tensor, 'ndim'): Tensor.ndim = property(lambda x: len(x.shape))
331
332
def grab_idx(x,i,batch_first:bool=True):
    "Grab item `i` along the batch dimension (0 if `batch_first` else 1) of `x` or of each tensor in list `x`, on CPU."
    def _pick(t): return t[i].cpu() if batch_first else t[:,i].cpu()
    return [_pick(o) for o in x] if is_listy(x) else _pick(x)
336
337
def logit(x:Tensor)->Tensor:
    "Return the logit of `x`, computed as `-log(1/x - 1)` after clamping `x` to [1e-7, 1-1e-7] to avoid inf."
    eps = 1e-7
    p = x.clamp(eps, 1 - eps)
    return -torch.log(1/p - 1)
341
342
def logit_(x:Tensor)->Tensor:
    "In-place logit: clamp `x` to [1e-7, 1-1e-7], overwrite it with `-log(1/x - 1)` and return `x` itself."
    eps = 1e-7
    x.clamp_(eps, 1 - eps)
    x.reciprocal_()
    x.sub_(1)
    x.log_()
    return x.neg_()
346
347
def set_all_seed(seed:int)->None:
    "Seed the numpy, torch and python `random` generators with `seed`."
    for seeder in (np.random.seed, torch.manual_seed, random.seed): seeder(seed)
352
353
def uniform(low:Number, high:Number=None, size:Optional[List[int]]=None)->FloatOrTensor:
    "Draw 1 float (python `random`) or a `FloatTensor` of shape `size`, uniformly in [`low`, `high`] (`high` defaults to `low`)."
    hi = low if high is None else high
    if size is None: return random.uniform(low, hi)
    return torch.FloatTensor(*listify(size)).uniform_(low, hi)
357
358
def log_uniform(low, high, size:Optional[List[int]]=None)->FloatOrTensor:
    "Draw 1 or shape=`size` random floats whose log is uniform between log(`low`) and log(`high`)."
    sample = uniform(log(low), log(high), size)
    if size is None: return exp(sample)
    return sample.exp_()
362
363
def rand_bool(p:float, size:Optional[List[int]]=None)->BoolOrTensor:
    "Draw 1 or shape=`size` random booleans, each `True` with probability `p`."
    draws = uniform(0, 1, size)
    return draws < p
366
367
def uniform_int(low:int, high:int, size:Optional[List[int]]=None)->IntOrTensor:
    "Generate an int, or a tensor of shape `size` of ints, uniformly between `low` and `high` (both included)."
    if size is None: return random.randint(low,high)
    # `torch.randint` needs a tuple shape and rejects a bare int; `listify` accepts both, as in `uniform`.
    return torch.randint(low, high+1, tuple(listify(size)))
370
371
def one_param(m: nn.Module)->Tensor:
    "Return the first parameter yielded by `m.parameters()`; raises StopIteration if `m` has none."
    for p in m.parameters(): return p
    raise StopIteration
374
375
def try_int(o:Any)->Any:
    "Try to convert `o` to int, default to `o` if not possible (sized objects and non-scalar arrays are returned as-is)."
    # NB: single-item rank-1 array/tensor can be converted to int, but we don't want to do this
    if isinstance(o, (np.ndarray,Tensor)): return o if o.ndim else int(o)
    if isinstance(o, collections.abc.Sized) or getattr(o,'__array_interface__',False): return o
    try: return int(o)
    # `Exception` instead of a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
    except Exception: return o
382
383
def get_model(model:nn.Module):
    "Return `model.module` when `model` is a (Distributed)DataParallel wrapper, else `model` itself."
    wrappers = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, wrappers): return model.module
    return model
386
387
def flatten_check(out:Tensor, targ:Tensor) -> Tuple[Tensor,Tensor]:
    """Flatten `out` and `targ` to 1-d and return them.

    Raises AssertionError if they do not have the same number of elements.
    """
    out,targ = out.contiguous().view(-1),targ.contiguous().view(-1)
    # Explicit raise instead of `assert`: the check must survive `python -O`. The exception type is unchanged.
    if len(out) != len(targ):
        raise AssertionError(f"Expected output and target to have the same number of elements but got {len(out)} and {len(targ)}.")
    return out,targ
392
393
#Monkey-patch nn.DataParallel.reset
def _data_parallel_reset(self):
    "Forward `reset()` to the wrapped module when it defines one; otherwise do nothing."
    if not hasattr(self.module, 'reset'): return
    self.module.reset()
nn.DataParallel.reset = _data_parallel_reset
397
398
def remove_module_load(state_dict):
    """Return a new OrderedDict with the leading `module.` (added by (Distributed)DataParallel) stripped from keys."""
    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Only strip a genuine prefix: slicing 7 chars off every key corrupted keys that did not have it.
        if k.startswith(prefix): k = k[len(prefix):]
        new_state_dict[k] = v
    return new_state_dict
403
404
def num_distrib():
    "Return the number of processes in distributed training, read from `WORLD_SIZE` (0 when unset)."
    world_size = os.environ.get('WORLD_SIZE', 0)
    return int(world_size)
407
408
def rank_distrib():
    "Return the distributed rank of this process, read from `RANK` (0 when unset)."
    rank = os.environ.get('RANK', 0)
    return int(rank)
411
412
def add_metrics(last_metrics:Collection[Rank0Tensor], mets:Union[Rank0Tensor, Collection[Rank0Tensor]]):
    "Return `{'last_metrics': ...}` with `mets` appended after `last_metrics` (both listified)."
    combined = listify(last_metrics) + listify(mets)
    return {'last_metrics': combined}
416
417
def try_save(state:Dict, path:Path=None, file:PathLikeOrBinaryStream=None):
    """Save `state` with `torch.save` to `path/file` when `file` is path-like, otherwise into the stream `file`.

    An OSError raised while writing to `path/file` is re-raised as `Exception` with a hint about the path.
    """
    if not is_pathlike(file):
        # caller-owned stream: write into it, leave closing to the caller
        torch.save(state, file)
        return
    # `with` guarantees the file is flushed and closed; previously the handle was left open.
    with open(path/file, 'wb') as target:
        try: torch.save(state, target)
        except OSError as e:
            raise Exception(f"{e}\n Can't write {path/file}. Pass an absolute writable pathlib obj `fname`.") from e
422
423
def np_func(f):
    """Convert a function taking and returning numpy arrays to one taking and returning tensors.

    Positional tensor arguments are converted with `to_np`; keyword arguments are passed through unchanged.
    The result of `f` is wrapped with `tensor`.
    """
    @functools.wraps(f)
    def _inner(*args, **kwargs):
        np_args = [to_np(a) if isinstance(a, Tensor) else a for a in args]
        return tensor(f(*np_args, **kwargs))
    return _inner
430
431
432