"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
from ..basic_train import Learner
from ..basic_data import DatasetType, DataBunch
from ..vision import Image
from ..vision.gan import GANLearner
from ..callbacks import LearnerCallback
from ..core import *
from ..torch_core import *
from threading import Thread, Event
from time import sleep
from queue import Queue
import statistics
import torchvision.utils as vutils
from abc import ABC, abstractmethod
# This is an optional dependency in fastai. It must be installed separately.
try: from tensorboardX import SummaryWriter
except ImportError: print("To use this tracker, please run 'pip install tensorboardX'. You must also have TensorBoard running to see results.")

__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']

#---Example usage (applies to any of the callbacks)---
# proj_id = 'Colorize'
# tboard_path = Path('data/tensorboard/' + proj_id)
# learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=tboard_path, name='GanLearner'))
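#
#---A further minimal sketch for a plain Learner (illustrative only; assumes `learn` is an existing fastai Learner)---
# learn.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=tboard_path, name='Learner'))
# learn.fit_one_cycle(1)
# TensorBoard itself must be launched separately (e.g. `tensorboard --logdir data/tensorboard`) to view the results.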

class LearnerTensorboardWriter(LearnerCallback):
    "Broadly useful callback for Learners that writes to Tensorboard. Writes model histograms, losses/metrics, and gradient stats."
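    # loss_iters, hist_iters and stats_iters are write intervals, measured in training iterations (batches);
    # see on_batch_end and on_backward_end below.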
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
        super().__init__(learn=learn)
        self.base_dir,self.name,self.loss_iters,self.hist_iters,self.stats_iters = base_dir,name,loss_iters,hist_iters,stats_iters
        log_dir = base_dir/name
        self.tbwriter = SummaryWriter(str(log_dir))
        self.hist_writer = HistogramTBWriter()
        self.stats_writer = ModelStatsTBWriter()
        #self.graph_writer = GraphTBWriter()
        self.data = None
        self.metrics_root = '/metrics/'
        self._update_batches_if_needed()

    def _get_new_batch(self, ds_type:DatasetType)->Collection[Tensor]:
        "Retrieves a new batch of DatasetType, and detaches it."
        return self.learn.data.one_batch(ds_type=ds_type, detach=True, denorm=False, cpu=False)

    def _update_batches_if_needed(self)->None:
        "The one_batch function is extremely slow with large datasets, so cache the result here as an optimization."
        if self.learn.data.valid_dl is None: return  # Running learning rate finder, so return
        update_batches = self.data is not self.learn.data
        if not update_batches: return
        self.data = self.learn.data
        self.trn_batch = self._get_new_batch(ds_type=DatasetType.Train)
        self.val_batch = self._get_new_batch(ds_type=DatasetType.Valid)

    def _write_model_stats(self, iteration:int)->None:
        "Writes gradient statistics to Tensorboard."
        self.stats_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
        "Writes training loss to Tensorboard."
        scalar_value = to_np(last_loss)
        tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_weight_histograms(self, iteration:int)->None:
        "Writes model weight histograms to Tensorboard."
        self.hist_writer.write(model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)

    def _write_scalar(self, name:str, scalar_value, iteration:int)->None:
        "Writes a single scalar value to Tensorboard."
        tag = self.metrics_root + name
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    #TODO: Relying on a specific hardcoded start_idx here isn't great. Is there a better solution?
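    # For reference (typical fastai v1 behavior, not guaranteed): recorder.names is usually
    # ['epoch', 'train_loss', 'valid_loss', <metric names>...] and last_metrics is [valid_loss, <metric values>...],
    # which is why start_idx defaults to 2.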
    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2)->None:
        "Writes training metrics to Tensorboard."
        recorder = self.learn.recorder
        for i, name in enumerate(recorder.names[start_idx:]):
            if last_metrics is None or len(last_metrics) < i+1: return
            scalar_value = last_metrics[i]
            self._write_scalar(name=name, scalar_value=scalar_value, iteration=iteration)

    def on_train_begin(self, **kwargs: Any) -> None:
        #self.graph_writer.write(model=self.learn.model, tbwriter=self.tbwriter,
        #                        input_to_model=next(iter(self.learn.data.dl(DatasetType.Single)))[0])
        return

    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate batch-end data to Tensorboard."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.loss_iters == 0: self._write_training_loss(iteration=iteration, last_loss=last_loss)
        if iteration % self.hist_iters == 0: self._write_weight_histograms(iteration=iteration)

    # Do work requiring gradient info here, because gradients get zeroed out afterwards in the training loop.
    def on_backward_end(self, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate backward-end data to Tensorboard."
        if iteration == 0: return
        self._update_batches_if_needed()
        if iteration % self.stats_iters == 0: self._write_model_stats(iteration=iteration)

    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate epoch-end data to Tensorboard."
        self._write_metrics(iteration=iteration, last_metrics=last_metrics)

# TODO: We're overriding almost everything here. Seems like a good idea to question that ("is a" vs "has a")
class GANTensorboardWriter(LearnerTensorboardWriter):
    "Callback for GANLearners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes."
    def __init__(self, learn:GANLearner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
                 stats_iters:int=100, visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters, stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()
        self.gen_stats_updated = True
        self.crit_stats_updated = True

    def _write_weight_histograms(self, iteration:int)->None:
        "Writes model weight histograms to Tensorboard."
        generator, critic = self.learn.gan_trainer.generator, self.learn.gan_trainer.critic
        self.hist_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
        self.hist_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')

    def _write_gen_model_stats(self, iteration:int)->None:
        "Writes gradient statistics for the generator to Tensorboard."
        generator = self.learn.gan_trainer.generator
        self.stats_writer.write(model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
        self.gen_stats_updated = True

    def _write_critic_model_stats(self, iteration:int)->None:
        "Writes gradient statistics for the critic to Tensorboard."
        critic = self.learn.gan_trainer.critic
        self.stats_writer.write(model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
        self.crit_stats_updated = True

    def _write_model_stats(self, iteration:int)->None:
        "Writes gradient statistics to Tensorboard."
        # We don't want to write stats when the model is not being iterated on and hence has zeroed-out gradients.
        gen_mode = self.learn.gan_trainer.gen_mode
        if gen_mode and not self.gen_stats_updated: self._write_gen_model_stats(iteration=iteration)
        if not gen_mode and not self.crit_stats_updated: self._write_critic_model_stats(iteration=iteration)

    def _write_training_loss(self, iteration:int, last_loss:Tensor)->None:
        "Writes training loss to Tensorboard."
        recorder = self.learn.gan_trainer.recorder
        if len(recorder.losses) == 0: return
        scalar_value = to_np((recorder.losses[-1:])[0])
        tag = self.metrics_root + 'train_loss'
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)

    def _write_images(self, iteration:int)->None:
        "Writes model generated, original and real images to Tensorboard."
        trainer = self.learn.gan_trainer
        #TODO: Switching gen_mode temporarily seems a bit hacky here. Certainly not a good side-effect. Is there a better way?
        gen_mode = trainer.gen_mode
        try:
            trainer.switch(gen_mode=True)
            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
                                   iteration=iteration, tbwriter=self.tbwriter)
        finally: trainer.switch(gen_mode=gen_mode)

    def on_batch_end(self, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate batch-end data to Tensorboard."
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0: self._write_images(iteration=iteration)

    def on_backward_end(self, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate backward-end data to Tensorboard."
        if iteration == 0: return
        self._update_batches_if_needed()
        #TODO: This could perhaps be implemented as queues of requests instead but that seemed like overkill.
        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
        if iteration % self.stats_iters == 0: self.gen_stats_updated, self.crit_stats_updated = False, False
        if not (self.gen_stats_updated and self.crit_stats_updated): self._write_model_stats(iteration=iteration)

class ImageGenTensorboardWriter(LearnerTensorboardWriter):
    "Callback for non-GAN image generating Learners that writes to Tensorboard. Extends LearnerTensorboardWriter and adds output image writes."
    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100,
                 visual_iters:int=100):
        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
                         stats_iters=stats_iters)
        self.visual_iters = visual_iters
        self.img_gen_vis = ImageTBWriter()

    def _write_images(self, iteration:int)->None:
        "Writes model generated, original and real images to Tensorboard."
        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch, iteration=iteration,
                               tbwriter=self.tbwriter)

    def on_batch_end(self, iteration:int, **kwargs)->None:
        "Callback function that writes the appropriate batch-end data to Tensorboard."
        super().on_batch_end(iteration=iteration, **kwargs)
        if iteration == 0: return
        if iteration % self.visual_iters == 0:
            self._write_images(iteration=iteration)

class TBWriteRequest(ABC):
    "A request object for Tensorboard writes. Useful for queuing up and executing asynchronous writes."
    def __init__(self, tbwriter: SummaryWriter, iteration:int):
        super().__init__()
        self.tbwriter = tbwriter
        self.iteration = iteration

    @abstractmethod
    def write(self)->None: pass

# SummaryWriter writes tend to block quite a bit. This gets around that and greatly boosts performance.
# Not all tensorboard writes are using this - just the ones that take a long time. Note that the
# SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard,
# so writes done outside of this async loop should be fine.
class AsyncTBWriter():
    "Processes TBWriteRequest objects asynchronously on a background thread, so slow Tensorboard writes don't block training."
    def __init__(self):
        super().__init__()
        self.stop_request = Event()
        self.queue = Queue()
        self.thread = Thread(target=self._queue_processor, daemon=True)
        self.thread.start()

    def request_write(self, request: TBWriteRequest)->None:
        "Queues up an asynchronous write request to Tensorboard."
        if self.stop_request.is_set(): return
        self.queue.put(request)

    def _queue_processor(self)->None:
        "Processes queued up write requests asynchronously to Tensorboard."
        while not self.stop_request.is_set():
            while not self.queue.empty():
                if self.stop_request.is_set(): return
                request = self.queue.get()
                request.write()
            sleep(0.2)

    # Provided to stop the thread explicitly or via context management (with statement), but the thread should end
    # on its own upon program exit, since it is a daemon. So using this is probably unnecessary.
    def close(self)->None:
        "Stops the asynchronous request queue processing thread."
        self.stop_request.set()
        self.thread.join()

    # Nothing to do, thread already started. Could start the thread here to enforce use of a context manager
    # (but that sounds like a pain and a bit unwieldy and unnecessary for actual usage).
    def __enter__(self): pass

    def __exit__(self, exc_type, exc_value, traceback): self.close()

asyncTBWriter = AsyncTBWriter()
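
#---Example usage of the async writer (illustrative only; assumes `learn`, `iteration` and `tbwriter` exist in scope)---
# This mirrors how the writer classes below queue their work:
# request = HistogramTBRequest(model=learn.model, iteration=iteration, tbwriter=tbwriter, name='model')
# asyncTBWriter.request_write(request)  # returns immediately; the daemon thread performs the slow write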

class ModelImageSet():
    "Convenience object that holds the original, real (target) and generated versions of a single image fed to a model."
    @staticmethod
    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
        "Factory method to convert a batch of model images to a list of ModelImageSet."
        image_sets = []
        x,y = batch[0],batch[1]
        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
        for orig_px, real_px, gen in zip(x,y,preds):
            orig, real = Image(px=orig_px), Image(px=real_px)
            image_set = ModelImageSet(orig=orig, real=real, gen=gen)
            image_sets.append(image_set)
        return image_sets

    def __init__(self, orig:Image, real:Image, gen:Image): self.orig, self.real, self.gen = orig, real, gen

class HistogramTBRequest(TBWriteRequest):
    "Request object for model histogram writes to Tensorboard."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]
        self.name = name

    def _write_histogram(self, param_name:str, values)->None:
        "Writes single model histogram to Tensorboard."
        tag = self.name + '/weights/' + param_name
        self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)

    def write(self)->None:
        "Writes model histograms to Tensorboard."
        for param_name, values in self.params: self._write_histogram(param_name=param_name, values=values)

# If this isn't done async then this is sloooooow
class HistogramTBWriter():
    "Writes model histograms to Tensorboard."
    def __init__(self): super().__init__()

    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model')->None:
        "Writes model histograms to Tensorboard."
        request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)

class ModelStatsTBRequest(TBWriteRequest):
    "Request object for model gradient statistics writes to Tensorboard."
    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]
        self.name = name

    def _add_gradient_scalar(self, name:str, scalar_value)->None:
        "Writes a single scalar value for a gradient statistic to Tensorboard."
        tag = self.name + '/gradients/' + name
        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=self.iteration)

    def _write_avg_norm(self, norms:[])->None:
        "Writes the average norm of the gradients to Tensorboard."
        avg_norm = sum(norms)/len(self.gradients)
        self._add_gradient_scalar('avg_norm', scalar_value=avg_norm)

    def _write_median_norm(self, norms:[])->None:
        "Writes the median norm of the gradients to Tensorboard."
        median_norm = statistics.median(norms)
        self._add_gradient_scalar('median_norm', scalar_value=median_norm)

    def _write_max_norm(self, norms:[])->None:
        "Writes the maximum norm of the gradients to Tensorboard."
        max_norm = max(norms)
        self._add_gradient_scalar('max_norm', scalar_value=max_norm)

    def _write_min_norm(self, norms:[])->None:
        "Writes the minimum norm of the gradients to Tensorboard."
        min_norm = min(norms)
        self._add_gradient_scalar('min_norm', scalar_value=min_norm)

    def _write_num_zeros(self)->None:
        "Writes the number of zeroes in the gradients to Tensorboard."
        gradient_nps = [to_np(x.data) for x in self.gradients]
        num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
        self._add_gradient_scalar('num_zeros', scalar_value=num_zeros)

    def _write_avg_gradient(self)->None:
        "Writes the average of the gradients to Tensorboard."
        avg_gradient = sum(x.data.mean() for x in self.gradients)/len(self.gradients)
        self._add_gradient_scalar('avg_gradient', scalar_value=avg_gradient)

    def _write_median_gradient(self)->None:
        "Writes the median of the gradients to Tensorboard."
        median_gradient = statistics.median(x.data.median() for x in self.gradients)
        self._add_gradient_scalar('median_gradient', scalar_value=median_gradient)

    def _write_max_gradient(self)->None:
        "Writes the maximum of the gradients to Tensorboard."
        max_gradient = max(x.data.max() for x in self.gradients)
        self._add_gradient_scalar('max_gradient', scalar_value=max_gradient)

    def _write_min_gradient(self)->None:
        "Writes the minimum of the gradients to Tensorboard."
        min_gradient = min(x.data.min() for x in self.gradients)
        self._add_gradient_scalar('min_gradient', scalar_value=min_gradient)

    def write(self)->None:
        "Writes model gradient statistics to Tensorboard."
        if len(self.gradients) == 0: return
        norms = [x.data.norm() for x in self.gradients]
        self._write_avg_norm(norms=norms)
        self._write_median_norm(norms=norms)
        self._write_max_norm(norms=norms)
        self._write_min_norm(norms=norms)
        self._write_num_zeros()
        self._write_avg_gradient()
        self._write_median_gradient()
        self._write_max_gradient()
        self._write_min_gradient()

class ModelStatsTBWriter():
    "Writes model gradient statistics to Tensorboard."
    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats')->None:
        "Writes model gradient statistics to Tensorboard."
        request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
        asyncTBWriter.request_write(request)

class ImageTBRequest(TBWriteRequest):
    "Request object for model image output writes to Tensorboard."
    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
        super().__init__(tbwriter=tbwriter, iteration=iteration)
        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
        self.ds_type = ds_type

    def _write_images(self, name:str, images:[Tensor])->None:
        "Writes a list of images as tensors to Tensorboard."
        tag = self.ds_type.name + ' ' + name
        self.tbwriter.add_image(tag=tag, img_tensor=vutils.make_grid(images, normalize=True), global_step=self.iteration)

    def _get_image_tensors(self)->([Tensor], [Tensor], [Tensor]):
        "Gets lists of image tensors from the ModelImageSet objects, as a tuple of original, generated and real (target) images."
        orig_images, gen_images, real_images = [], [], []
        for image_set in self.image_sets:
            orig_images.append(image_set.orig.px)
            gen_images.append(image_set.gen.px)
            real_images.append(image_set.real.px)
        return orig_images, gen_images, real_images

    def write(self)->None:
        "Writes original, generated and real (target) images to Tensorboard."
        orig_images, gen_images, real_images = self._get_image_tensors()
        self._write_images(name='orig images', images=orig_images)
        self._write_images(name='gen images', images=gen_images)
        self._write_images(name='real images', images=real_images)

# If this isn't done async then this is noticeably slower
class ImageTBWriter():
    "Writes model image output to Tensorboard."
    def __init__(self): super().__init__()

    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter)->None:
        "Writes training and validation batch images to Tensorboard."
        self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Valid)
        self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration, tbwriter=tbwriter, ds_type=DatasetType.Train)

    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType)->None:
        "Writes batch images of the specified DatasetType to Tensorboard."
        request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
        asyncTBWriter.request_write(request)

class GraphTBRequest(TBWriteRequest):
    "Request object for model graph writes to Tensorboard."
    def __init__(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor):
        super().__init__(tbwriter=tbwriter, iteration=0)
        self.model,self.input_to_model = model,input_to_model

    def write(self)->None:
        "Writes a single model graph to Tensorboard."
        self.tbwriter.add_graph(model=self.model, input_to_model=self.input_to_model)

class GraphTBWriter():
    "Writes model network graph to Tensorboard."
    def write(self, model:nn.Module, tbwriter:SummaryWriter, input_to_model:torch.Tensor)->None:
        "Writes model graph to Tensorboard."
        request = GraphTBRequest(model=model, tbwriter=tbwriter, input_to_model=input_to_model)
        asyncTBWriter.request_write(request)
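
#---Example usage of GraphTBWriter (illustrative only; mirrors the commented-out call in on_train_begin above,
#   and assumes `learn` and `tboard_path` from the earlier example)---
# graph_writer = GraphTBWriter()
# graph_writer.write(model=learn.model, tbwriter=SummaryWriter(str(tboard_path/'graph')),
#                    input_to_model=next(iter(learn.data.dl(DatasetType.Single)))[0])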