Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/examples/vision/ipynb/edsr.ipynb
3236 views
{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "# Enhanced Deep Residual Networks for single-image super-resolution\n", "\n", "**Author:** Gitesh Chawda<br>\n", "**Date created:** 2022/04/07<br>\n", "**Last modified:** 2024/08/27<br>\n", "**Description:** Training an EDSR model on the DIV2K Dataset." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Introduction\n", "\n", "In this example, we implement\n", "[Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)](https://arxiv.org/abs/1707.02921)\n", "by Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee.\n", "\n", "The EDSR architecture is based on the SRResNet architecture and consists of multiple\n", "residual blocks. It uses constant scaling layers instead of batch normalization layers to\n", "produce consistent results (input and output have similar distributions, thus\n", "normalizing intermediate features may not be desirable). Instead of using a L2 loss (mean squared error),\n", "the authors employed an L1 loss (mean absolute error), which performs better empirically.\n", "\n", "Our implementation only includes 16 residual blocks with 64 channels.\n", "\n", "Alternatively, as shown in the Keras example\n", "[Image Super-Resolution using an Efficient Sub-Pixel CNN](https://keras.io/examples/vision/super_resolution_sub_pixel/#image-superresolution-using-an-efficient-subpixel-cnn),\n", "you can do super-resolution using an ESPCN Model. According to the survey paper, EDSR is one of the top-five\n", "best-performing super-resolution methods based on PSNR scores. However, it has more\n", "parameters and requires more computational power than other approaches.\n", "It has a PSNR value (\u224834db) that is slightly higher than ESPCN (\u224832db).\n", "As per the survey paper, EDSR performs better than ESPCN.\n", "\n", "Paper:\n", "[A comprehensive review of deep learning based single image super-resolution](https://arxiv.org/abs/2102.09351)\n", "\n", "Comparison Graph:\n", "<img src=\"https://dfzljdn9uc3pi.cloudfront.net/2021/cs-621/1/fig-11-2x.jpg\" width=\"500\" />" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "import os\n", "\n", "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", "\n", "import numpy as np\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds\n", "import matplotlib.pyplot as plt\n", "\n", "import keras\n", "from keras import layers\n", "from keras import ops\n", "\n", "AUTOTUNE = tf.data.AUTOTUNE" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Download the training dataset\n", "\n", "We use the DIV2K Dataset, a prominent single-image super-resolution dataset with 1,000\n", "images of scenes with various sorts of degradations,\n", "divided into 800 images for training, 100 images for validation, and 100\n", "images for testing. We use 4x bicubic downsampled images as our \"low quality\" reference." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "# Download DIV2K from TF Datasets\n", "# Using bicubic 4x degradation type\n", "div2k_data = tfds.image.Div2k(config=\"bicubic_x4\")\n", "div2k_data.download_and_prepare()\n", "\n", "# Taking train data from div2k_data object\n", "train = div2k_data.as_dataset(split=\"train\", as_supervised=True)\n", "train_cache = train.cache()\n", "# Validation data\n", "val = div2k_data.as_dataset(split=\"validation\", as_supervised=True)\n", "val_cache = val.cache()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Flip, crop and resize images" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "\n", "def flip_left_right(lowres_img, highres_img):\n", " \"\"\"Flips Images to left and right.\"\"\"\n", "\n", " # Outputs random values from a uniform distribution in between 0 to 1\n", " rn = keras.random.uniform(shape=(), maxval=1)\n", " # If rn is less than 0.5 it returns original lowres_img and highres_img\n", " # If rn is greater than 0.5 it returns flipped image\n", " return ops.cond(\n", " rn < 0.5,\n", " lambda: (lowres_img, highres_img),\n", " lambda: (\n", " ops.flip(lowres_img),\n", " ops.flip(highres_img),\n", " ),\n", " )\n", "\n", "\n", "def random_rotate(lowres_img, highres_img):\n", " \"\"\"Rotates Images by 90 degrees.\"\"\"\n", "\n", " # Outputs random values from uniform distribution in between 0 to 4\n", " rn = ops.cast(\n", " keras.random.uniform(shape=(), maxval=4, dtype=\"float32\"), dtype=\"int32\"\n", " )\n", " # Here rn signifies number of times the image(s) are rotated by 90 degrees\n", " return tf.image.rot90(lowres_img, rn), tf.image.rot90(highres_img, rn)\n", "\n", "\n", "def random_crop(lowres_img, highres_img, hr_crop_size=96, scale=4):\n", " \"\"\"Crop images.\n", "\n", " low resolution images: 24x24\n", " high resolution images: 96x96\n", " \"\"\"\n", " lowres_crop_size = hr_crop_size // scale # 96//4=24\n", " lowres_img_shape = ops.shape(lowres_img)[:2] # (height,width)\n", "\n", " lowres_width = ops.cast(\n", " keras.random.uniform(\n", " shape=(), maxval=lowres_img_shape[1] - lowres_crop_size + 1, dtype=\"float32\"\n", " ),\n", " dtype=\"int32\",\n", " )\n", " lowres_height = ops.cast(\n", " keras.random.uniform(\n", " shape=(), maxval=lowres_img_shape[0] - lowres_crop_size + 1, dtype=\"float32\"\n", " ),\n", " dtype=\"int32\",\n", " )\n", "\n", " highres_width = lowres_width * scale\n", " highres_height = lowres_height * scale\n", "\n", " lowres_img_cropped = lowres_img[\n", " lowres_height : lowres_height + lowres_crop_size,\n", " lowres_width : lowres_width + lowres_crop_size,\n", " ] # 24x24\n", " highres_img_cropped = highres_img[\n", " highres_height : highres_height + hr_crop_size,\n", " highres_width : highres_width + hr_crop_size,\n", " ] # 96x96\n", "\n", " return lowres_img_cropped, highres_img_cropped\n", "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Prepare a `tf.data.Dataset` object\n", "\n", "We augment the training data with random horizontal flips and 90 rotations.\n", "\n", "As low resolution images, we use 24x24 RGB input patches." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "\n", "def dataset_object(dataset_cache, training=True):\n", " ds = dataset_cache\n", " ds = ds.map(\n", " lambda lowres, highres: random_crop(lowres, highres, scale=4),\n", " num_parallel_calls=AUTOTUNE,\n", " )\n", "\n", " if training:\n", " ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)\n", " ds = ds.map(flip_left_right, num_parallel_calls=AUTOTUNE)\n", " # Batching Data\n", " ds = ds.batch(16)\n", "\n", " if training:\n", " # Repeating Data, so that cardinality if dataset becomes infinte\n", " ds = ds.repeat()\n", " # prefetching allows later images to be prepared while the current image is being processed\n", " ds = ds.prefetch(buffer_size=AUTOTUNE)\n", " return ds\n", "\n", "\n", "train_ds = dataset_object(train_cache, training=True)\n", "val_ds = dataset_object(val_cache, training=False)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Visualize the data\n", "\n", "Let's visualize a few sample images:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "lowres, highres = next(iter(train_ds))\n", "\n", "# High Resolution Images\n", "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(highres[i].numpy().astype(\"uint8\"))\n", " plt.title(highres[i].shape)\n", " plt.axis(\"off\")\n", "\n", "# Low Resolution Images\n", "plt.figure(figsize=(10, 10))\n", "for i in range(9):\n", " ax = plt.subplot(3, 3, i + 1)\n", " plt.imshow(lowres[i].numpy().astype(\"uint8\"))\n", " plt.title(lowres[i].shape)\n", " plt.axis(\"off\")\n", "\n", "\n", "def PSNR(super_resolution, high_resolution):\n", " \"\"\"Compute the peak signal-to-noise ratio, measures quality of image.\"\"\"\n", " # Max value of pixel is 255\n", " psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]\n", " return psnr_value\n", "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Build the model\n", "\n", "In the paper, the authors train three models: EDSR, MDSR, and a baseline model. In this code example,\n", "we only train the baseline model.\n", "\n", "### Comparison with model with three residual blocks\n", "\n", "The residual block design of EDSR differs from that of ResNet. Batch normalization\n", "layers have been removed (together with the final ReLU activation): since batch normalization\n", "layers normalize the features, they hurt output value range flexibility.\n", "It is thus better to remove them. Further, it also helps reduce the\n", "amount of GPU RAM required by the model, since the batch normalization layers consume the same amount of\n", "memory as the preceding convolutional layers.\n", "\n", "<img src=\"https://miro.medium.com/max/1050/1*EPviXGqlGWotVtV2gqVvNg.png\" width=\"500\" />" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "\n", "class EDSRModel(keras.Model):\n", " def train_step(self, data):\n", " # Unpack the data. Its structure depends on your model and\n", " # on what you pass to `fit()`.\n", " x, y = data\n", "\n", " with tf.GradientTape() as tape:\n", " y_pred = self(x, training=True) # Forward pass\n", " # Compute the loss value\n", " # (the loss function is configured in `compile()`)\n", " loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)\n", "\n", " # Compute gradients\n", " trainable_vars = self.trainable_variables\n", " gradients = tape.gradient(loss, trainable_vars)\n", " # Update weights\n", " self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n", " # Update metrics (includes the metric that tracks the loss)\n", " self.compiled_metrics.update_state(y, y_pred)\n", " # Return a dict mapping metric names to current value\n", " return {m.name: m.result() for m in self.metrics}\n", "\n", " def predict_step(self, x):\n", " # Adding dummy dimension using tf.expand_dims and converting to float32 using tf.cast\n", " x = ops.cast(tf.expand_dims(x, axis=0), dtype=\"float32\")\n", " # Passing low resolution image to model\n", " super_resolution_img = self(x, training=False)\n", " # Clips the tensor from min(0) to max(255)\n", " super_resolution_img = ops.clip(super_resolution_img, 0, 255)\n", " # Rounds the values of a tensor to the nearest integer\n", " super_resolution_img = ops.round(super_resolution_img)\n", " # Removes dimensions of size 1 from the shape of a tensor and converting to uint8\n", " super_resolution_img = ops.squeeze(\n", " ops.cast(super_resolution_img, dtype=\"uint8\"), axis=0\n", " )\n", " return super_resolution_img\n", "\n", "\n", "# Residual Block\n", "def ResBlock(inputs):\n", " x = layers.Conv2D(64, 3, padding=\"same\", activation=\"relu\")(inputs)\n", " x = layers.Conv2D(64, 3, padding=\"same\")(x)\n", " x = layers.Add()([inputs, x])\n", " return x\n", "\n", "\n", "# Upsampling Block\n", "def Upsampling(inputs, factor=2, **kwargs):\n", " x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(inputs)\n", " x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n", " x = layers.Conv2D(64 * (factor**2), 3, padding=\"same\", **kwargs)(x)\n", " x = layers.Lambda(lambda x: tf.nn.depth_to_space(x, block_size=factor))(x)\n", " return x\n", "\n", "\n", "def make_model(num_filters, num_of_residual_blocks):\n", " # Flexible Inputs to input_layer\n", " input_layer = layers.Input(shape=(None, None, 3))\n", " # Scaling Pixel Values\n", " x = layers.Rescaling(scale=1.0 / 255)(input_layer)\n", " x = x_new = layers.Conv2D(num_filters, 3, padding=\"same\")(x)\n", "\n", " # 16 residual blocks\n", " for _ in range(num_of_residual_blocks):\n", " x_new = ResBlock(x_new)\n", "\n", " x_new = layers.Conv2D(num_filters, 3, padding=\"same\")(x_new)\n", " x = layers.Add()([x, x_new])\n", "\n", " x = Upsampling(x)\n", " x = layers.Conv2D(3, 3, padding=\"same\")(x)\n", "\n", " output_layer = layers.Rescaling(scale=255)(x)\n", " return EDSRModel(input_layer, output_layer)\n", "\n", "\n", "model = make_model(num_filters=64, num_of_residual_blocks=16)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "# Using adam optimizer with initial learning rate as 1e-4, changing learning rate after 5000 steps to 5e-5\n", "optim_edsr = keras.optimizers.Adam(\n", " learning_rate=keras.optimizers.schedules.PiecewiseConstantDecay(\n", " boundaries=[5000], values=[1e-4, 5e-5]\n", " )\n", ")\n", "# Compiling model with loss as mean absolute error(L1 Loss) and metric as psnr\n", "model.compile(optimizer=optim_edsr, loss=\"mae\", metrics=[PSNR])\n", "# Training for more epochs will improve results\n", "model.fit(train_ds, epochs=100, steps_per_epoch=200, validation_data=val_ds)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Run inference on new images and plot the results" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [ "\n", "def plot_results(lowres, preds):\n", " \"\"\"\n", " Displays low resolution image and super resolution image\n", " \"\"\"\n", " plt.figure(figsize=(24, 14))\n", " plt.subplot(132), plt.imshow(lowres), plt.title(\"Low resolution\")\n", " plt.subplot(133), plt.imshow(preds), plt.title(\"Prediction\")\n", " plt.show()\n", "\n", "\n", "for lowres, highres in val.take(10):\n", " lowres = tf.image.random_crop(lowres, (150, 150, 3))\n", " preds = model.predict_step(lowres)\n", " plot_results(lowres, preds)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Final remarks\n", "\n", "In this example, we implemented the EDSR model (Enhanced Deep Residual Networks for Single Image\n", "Super-Resolution). You could improve the model accuracy by training the model for more epochs, as well as\n", "training the model with a wider variety of inputs with mixed downgrading factors, so as to\n", "be able to handle a greater range of real-world images.\n", "\n", "You could also improve on the given baseline EDSR model by implementing EDSR+,\n", "or MDSR( Multi-Scale super-resolution) and MDSR+,\n", "which were proposed in the same paper.\n", ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "edsr", "private_outputs": false, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 0 }