Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
jantic
GitHub Repository: jantic/deoldify
Path: blob/master/ColorizeTrainingVideo.ipynb
781 views
Kernel: Python 3

Video Model Training

NOTES:

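  • It's assumed that a pretrained generator from the ColorizeTrainingStable notebook is available at the specified path (loaded below as the `pre_gen_name` checkpoint). A pretrained critic checkpoint named `crit_name + '_0'` is also required; it is loaded in the "Pretrain Critic" step.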

  • This is "NoGAN" based training, described in the DeOldify readme.

#NOTE: This must be the first call in order to work properly! from deoldify import device from deoldify.device_id import DeviceId #choices: CPU, GPU0...GPU7 device.set(device=DeviceId.GPU0)
import gc
import os
import shutil
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from deoldify.augs import noisify
# Keep PIL last: `from PIL import Image` intentionally shadows any `Image`
# exported by the star imports above.
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

Setup

# --- Data locations -------------------------------------------------------
# Colour originals are read from `path`; their black-and-white copies are
# written into a `bandw` sub-folder of the same tree.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr, path_lr = path, path/'bandw'

# --- Run and checkpoint naming --------------------------------------------
proj_id = 'VideoModel'
gen_name = proj_id + '_gen'
crit_name = proj_id + '_crit'
pre_gen_name = gen_name + '_0'

# Folder that receives generator predictions; it is also one of the class
# folders passed to get_crit_data for the critic.
name_gen = proj_id + '_image_gen'
path_gen = path_hr/name_gen

TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

# --- Training knobs -------------------------------------------------------
nf_factor = 2
pct_start = 1e-8
xtra_tfms = [noisify(p=0.8)]  # extra noise augmentation fed to get_colorize_data
def get_data(bs:int, sz:int, keep_pct:float):
    """Return the generator DataBunch built by `get_colorize_data`.

    The black-and-white copies in `path_lr` are passed as `crappy_path` and the
    originals in `path_hr` as `good_path`. `keep_pct` selects what fraction of
    the dataset to use, and the noise transforms in `xtra_tfms` are added.
    """
    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr,
                             random_seed=None, keep_pct=keep_pct, xtra_tfms=xtra_tfms)


def get_crit_data(classes, bs, sz):
    """Return the critic DataBunch: images under `path` restricted to `classes`.

    Images are labelled by their folder name. 10% are held out for validation,
    using a fixed seed of 42. Zoom augmentation is applied, images are resized
    to `sz`, and the data is normalised with ImageNet statistics.
    """
    src = ImageList.from_folder(path, include=classes, recurse=True).split_by_rand_pct(0.1, seed=42)
    ll = src.label_from_folder(classes=classes)
    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)
            .databunch(bs=bs).normalize(imagenet_stats))
    return data


def create_training_images(fn, i):
    """Write a black-and-white copy of image `fn` under `path_lr`.

    The copy mirrors the file's location relative to `path_hr`. The image is
    converted to 'LA' and back to 'RGB', which yields a 3-channel greyscale image.

    `i` is unused. Presumably it is accepted because fastai's `parallel` calls
    `func(item, index)` — TODO confirm.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # `convert` returns a new in-memory image, so the source file can be closed
    # right away. Without `with`, the handle stays open until garbage collection,
    # which piles up across the many files processed by each worker.
    with PIL.Image.open(fn) as src:
        img = src.convert('LA').convert('RGB')
    img.save(dest)


def save_preds(dl):
    """Run the global `learn_gen` over `dl` and save every prediction into `path_gen`.

    Each output file is named after the matching source file in
    `dl.dataset.items`. The pairing is positional, so this assumes `dl` yields
    items in dataset order (e.g. a `fix_dl`).
    """
    i = 0
    names = dl.dataset.items
    for b in dl:
        preds = learn_gen.pred_batch(batch=b, reconstruct=True)
        for o in preds:
            o.save(path_gen/names[i].name)
            i += 1


def save_gen_images():
    """Rebuild `path_gen` from scratch with generator predictions and return a preview.

    Predictions are made for 8.5% of the training data, using the global `bs`
    and `sz` for the DataBunch and `learn_gen` via `save_preds`.

    Returns:
        The first saved image, opened with PIL, so that a notebook cell ending in
        `save_gen_images()` displays it. Returns None if no image was written.
    """
    if path_gen.exists():
        shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(data_gen.fix_dl)
    # Return the preview instead of discarding it; an unreturned image inside a
    # function is never displayed by the notebook.
    saved = path_gen.ls()
    return PIL.Image.open(saved[0]) if saved else None

Create black and white training images

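Only runs if the `bandw` directory (`path_lr`) doesn't already exist. If a previous run was interrupted, delete the partially populated directory first; otherwise this step is skipped and the training set stays incomplete.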

# Generate the greyscale training set only once. Any existing `path_lr` is
# treated as complete, because only its existence is checked.
if not path_lr.exists():
    colour_files = ImageList.from_folder(path_hr).items
    parallel(create_training_images, colour_files)

Fine-tune Generator with Noise-Augmented Images

This helps the generator cope with noisy/grainy video, which is very common.
# Batch size, image size and fraction of the dataset for this fine-tuning pass.
bs, sz, keep_pct = 8, 192, 0.25

data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
gen_pre_writer = partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre')
learn_gen.callback_fns.append(gen_pre_writer)
# Start from the pretrained generator checkpoint.
learn_gen = learn_gen.load(pre_gen_name, with_opt=False)

learn_gen.unfreeze()
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8, 5e-5))

# Overwrites the checkpoint loaded above. The GAN cycle below loads this same
# name (gen_name + '_0') as its starting generator. Re-running this cell
# therefore fine-tunes on top of an already fine-tuned model.
learn_gen.save(pre_gen_name)

Repeatable GAN Cycle

NOTE

Best results so far have come from running the cells below only once; repeated runs introduce glitches that are visible in the video.

def _checkpoint_name(base, num):
    "Return the checkpoint name for `base` at cycle `num`, i.e. '<base>_<num>'."
    return base + '_' + str(num)


# This cycle reads checkpoint `old_checkpoint_num` and writes the next one.
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1

gen_old_checkpoint_name = _checkpoint_name(gen_name, old_checkpoint_num)
gen_new_checkpoint_name = _checkpoint_name(gen_name, checkpoint_num)
crit_old_checkpoint_name = _checkpoint_name(crit_name, old_checkpoint_num)
crit_new_checkpoint_name = _checkpoint_name(crit_name, checkpoint_num)

Save Generated Images

# save_gen_images reads these globals when it builds its DataBunch.
bs, sz = 8, 192

# Load the generator from this cycle's starting checkpoint, then regenerate
# the folder of predicted images used by the critic.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
learn_gen = learn_gen.load(gen_old_checkpoint_name, with_opt=False)
save_gen_images()

Pretrain Critic

bs, sz = 16, 192

# Drop the generator so its memory can be reclaimed before building the critic.
learn_gen = None
gc.collect()

# Critic data: the generated-image folder plus the 'test' folder, each used as a class.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)

learn_critic = colorize_crit_learner(data=data_crit, nf=256)
learn_critic = learn_critic.load(crit_old_checkpoint_name, with_opt=False)
critic_pre_writer = partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre')
learn_critic.callback_fns.append(critic_pre_writer)

learn_critic.fit_one_cycle(4, 1e-4)
learn_critic.save(crit_new_checkpoint_name)

GAN

# Release every learner left over from earlier cells before building the GAN,
# so gc can reclaim their models. Two names are easy to miss:
#   - the critic from "Pretrain Critic" is bound to `learn_critic`, not
#     `learn_crit`, so clearing only `learn_crit` kept it alive;
#   - a GANLearner from a previous run of this cell is still bound to `learn`.
learn_crit = None
learn_critic = None
learn_gen = None
learn = None
gc.collect()

lr = 5e-6
sz = 192
bs = 5

data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)

switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner',
                                  visual_iters=100, stats_iters=10, loss_iters=1))
# Saves the generator under gen_new_checkpoint_name every 100 iterations, so a
# checkpoint from just before glitches appear can be picked afterwards.
learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))

Instructions:

Find the checkpoint saved just before glitches start to appear. So far this point has been reached after iterating through about 1.4% of the data with a learning rate of 1e-5, and about 2.2% of the data with 5e-6.

# Only 3% of the dataset is used: the usable checkpoints are reached within the
# first few percent of an epoch.
gan_data = get_data(sz=sz, bs=bs, keep_pct=0.03)
learn.data = gan_data

# Per fastai's Learner.freeze_to, -1 leaves only the generator's last layer
# group trainable.
learn_gen.freeze_to(-1)
learn.fit(1, lr)