Contact Us!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
Avatar for stephanie's main branch.

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

| Download

"Guiding Future STEM Leaders through Innovative Research Training" ~ thinkingbeyond.education

Views: 1103
Image: ubuntu2204
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Vision Transformer Code: https://colab.research.google.com/drive/1UD_bybmnndzhi-tvuaqvHNHP5RWIq5zY?usp=sharing"
      ],
      "metadata": {
        "id": "GBAOoWEIDh-O"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8WGhXYYdqa--"
      },
      "outputs": [],
      "source": [
        "from tensorflow import lite\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random, os\n",
        "import shutil\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.image import imread\n",
        "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "from tensorflow.keras.metrics import categorical_accuracy\n",
        "from sklearn.model_selection import train_test_split\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import kagglehub\n",
        "\n",
        "# Download latest version\n",
        "path = kagglehub.dataset_download(\"sovitrath/diabetic-retinopathy-224x224-2019-data\")\n",
        "\n",
        "print(\"Path to dataset files:\", path)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0RcpxR65T4EV",
        "outputId": "3fab5ed6-d269-41c3-a70e-5a97dc459ede"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/datasets/download/sovitrath/diabetic-retinopathy-224x224-2019-data?dataset_version_number=4...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "100%|██████████| 238M/238M [00:02<00:00, 95.5MB/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Extracting files...\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Path to dataset files: /root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Add an additional column, mapping to the type\n",
        "df = pd.read_csv(r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/train.csv')\n",
        "\n",
        "diagnosis_dict_binary = {\n",
        "    0: 'No_DR',\n",
        "    1: 'DR',\n",
        "    2: 'DR',\n",
        "    3: 'DR',\n",
        "    4: 'DR'\n",
        "}\n",
        "\n",
        "diagnosis_dict = {\n",
        "    0: 'No_DR',\n",
        "    1: 'Mild',\n",
        "    2: 'Moderate',\n",
        "    3: 'Severe',\n",
        "    4: 'Proliferate_DR',\n",
        "}\n",
        "\n",
        "\n",
        "df['binary_type'] =  df['diagnosis'].map(diagnosis_dict_binary.get)\n",
        "df['type'] = df['diagnosis'].map(diagnosis_dict.get)\n",
        "df.head()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "TVjQOD_DknKd",
        "outputId": "3e77f7ea-3701-4fa5-dca7-66051f2dc433"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "        id_code  diagnosis binary_type            type\n",
              "0  000c1434d8d7          2          DR        Moderate\n",
              "1  001639a390f0          4          DR  Proliferate_DR\n",
              "2  0024cdab0c1e          1          DR            Mild\n",
              "3  002c21358ce6          0       No_DR           No_DR\n",
              "4  005b95c28852          0       No_DR           No_DR"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-08c6fac5-e6d0-4ab6-a065-4304bcf8ada5\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>id_code</th>\n",
              "      <th>diagnosis</th>\n",
              "      <th>binary_type</th>\n",
              "      <th>type</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>000c1434d8d7</td>\n",
              "      <td>2</td>\n",
              "      <td>DR</td>\n",
              "      <td>Moderate</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>001639a390f0</td>\n",
              "      <td>4</td>\n",
              "      <td>DR</td>\n",
              "      <td>Proliferate_DR</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0024cdab0c1e</td>\n",
              "      <td>1</td>\n",
              "      <td>DR</td>\n",
              "      <td>Mild</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>002c21358ce6</td>\n",
              "      <td>0</td>\n",
              "      <td>No_DR</td>\n",
              "      <td>No_DR</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>005b95c28852</td>\n",
              "      <td>0</td>\n",
              "      <td>No_DR</td>\n",
              "      <td>No_DR</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-08c6fac5-e6d0-4ab6-a065-4304bcf8ada5')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-08c6fac5-e6d0-4ab6-a065-4304bcf8ada5 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-08c6fac5-e6d0-4ab6-a065-4304bcf8ada5');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "<div id=\"df-7174c461-e092-4f1c-ae3d-9983f0ef18e5\">\n",
              "  <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-7174c461-e092-4f1c-ae3d-9983f0ef18e5')\"\n",
              "            title=\"Suggest charts\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "  </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "  <script>\n",
              "    async function quickchart(key) {\n",
              "      const quickchartButtonEl =\n",
              "        document.querySelector('#' + key + ' button');\n",
              "      quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "      quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "      try {\n",
              "        const charts = await google.colab.kernel.invokeFunction(\n",
              "            'suggestCharts', [key], {});\n",
              "      } catch (error) {\n",
              "        console.error('Error during call to suggestCharts:', error);\n",
              "      }\n",
              "      quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "      quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "    }\n",
              "    (() => {\n",
              "      let quickchartButtonEl =\n",
              "        document.querySelector('#df-7174c461-e092-4f1c-ae3d-9983f0ef18e5 button');\n",
              "      quickchartButtonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "    })();\n",
              "  </script>\n",
              "</div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "df",
              "summary": "{\n  \"name\": \"df\",\n  \"rows\": 3662,\n  \"fields\": [\n    {\n      \"column\": \"id_code\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 3662,\n        \"samples\": [\n          \"90960ddf4d14\",\n          \"4e0656629d02\",\n          \"3b018e8b7303\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"diagnosis\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 1,\n        \"min\": 0,\n        \"max\": 4,\n        \"num_unique_values\": 5,\n        \"samples\": [\n          4,\n          3,\n          1\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"binary_type\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 2,\n        \"samples\": [\n          \"No_DR\",\n          \"DR\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"type\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 5,\n        \"samples\": [\n          \"Proliferate_DR\",\n          \"Severe\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_intermediate, val = train_test_split(df, test_size = 0.15, stratify = df['type'])\n",
        "train, test = train_test_split(train_intermediate, test_size = 0.15 / (1 - 0.15), stratify = train_intermediate['type'])\n",
        "\n",
        "#print number\n",
        "print(train['type'].value_counts(), '\\n')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PwW1qeynVi3Z",
        "outputId": "9dd1e339-35cd-4e25-c65c-966f29f12fae"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "type\n",
            "No_DR             1263\n",
            "Moderate           699\n",
            "Mild               258\n",
            "Proliferate_DR     207\n",
            "Severe             135\n",
            "Name: count, dtype: int64 \n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#create directories for train, val and test. update it everytime the code runs with shutil.rmtree\n",
        "base_dir = '.\\dataset'\n",
        "\n",
        "\n",
        "train_dir = os.path.join(base_dir, 'train')\n",
        "val_dir = os.path.join(base_dir, 'val')\n",
        "test_dir = os.path.join(base_dir, 'test')\n",
        "\n",
        "if os.path.exists(base_dir):\n",
        "    shutil.rmtree(base_dir)\n",
        "\n",
        "#make directories for train, val and test\n",
        "os.makedirs(train_dir, exist_ok=True)\n",
        "os.makedirs(val_dir, exist_ok=True)\n",
        "os.makedirs(test_dir, exist_ok=True)\n",
        "#ensure the directories have the same columns\n",
        "valid_types = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR']\n",
        "df = df[df['type'].isin(valid_types)]\n",
        "assert all(train['type'].isin(valid_types))\n",
        "assert all(val['type'].isin(valid_types))\n",
        "assert all(test['type'].isin(valid_types))"
      ],
      "metadata": {
        "id": "2yChIiyNVkKY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#copy the files from the sourcefile to train, validation and test directory respectively\n",
        "src_dir = r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/colored_images'\n",
        "for index, row in train.iterrows():\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(train_dir, diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok = True)\n",
        "    if os.path.exists(srcfile):\n",
        "      shutil.copy(srcfile, dstfile)\n",
        "for index, row in val.iterrows():\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(val_dir, diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok=True)\n",
        "    if os.path.exists(srcfile):\n",
        "        shutil.copy(srcfile, dstfile)\n",
        "\n",
        "for index, row in test.iterrows():\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(test_dir, diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok=True)\n",
        "    if os.path.exists(srcfile):\n",
        "        shutil.copy(srcfile, dstfile)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "for subdir in [train_dir, val_dir, test_dir]:\n",
        "    print(f\"\\nContents of {subdir}:\")\n",
        "    for root, dirs, files in os.walk(subdir):\n",
        "        print(f\"{root}: {len(files)} files\")\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "id": "RnSpw1uQbj5O",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b5d264a6-60c5-4d27-de41-98b301aa82d8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/Moderate: 699 files\n",
            ".\\dataset/train/Mild: 258 files\n",
            ".\\dataset/train/Severe: 135 files\n",
            ".\\dataset/train/Proliferate_DR: 207 files\n",
            ".\\dataset/train/No_DR: 1263 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/Moderate: 150 files\n",
            ".\\dataset/val/Mild: 56 files\n",
            ".\\dataset/val/Severe: 29 files\n",
            ".\\dataset/val/Proliferate_DR: 44 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/Moderate: 150 files\n",
            ".\\dataset/test/Mild: 56 files\n",
            ".\\dataset/test/Severe: 29 files\n",
            ".\\dataset/test/Proliferate_DR: 44 files\n",
            ".\\dataset/test/No_DR: 271 files\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "train_path = train_dir\n",
        "val_path = val_dir\n",
        "test_path = test_dir\n",
        "\n",
        "train_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(train_path, target_size=(224,224), shuffle = True)\n",
        "val_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(val_path, target_size=(224,224), shuffle = True)\n",
        "test_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(test_path, target_size=(224,224), shuffle = False)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qOXxFZg-fLsL",
        "outputId": "138309c1-cf63-499d-9c81-de0d7eafa7bc"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Found 2562 images belonging to 5 classes.\n",
            "Found 550 images belonging to 5 classes.\n",
            "Found 550 images belonging to 5 classes.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#cnn model\n",
        "model = tf.keras.Sequential([\n",
        "    layers.Conv2D(8, (3,3), padding=\"valid\", input_shape=(224,224,3), activation = 'relu'),\n",
        "    layers.MaxPooling2D(pool_size=(2,2)),\n",
        "    layers.BatchNormalization(),\n",
        "\n",
        "    layers.Conv2D(16, (3,3), padding=\"valid\", activation = 'relu'),\n",
        "    layers.MaxPooling2D(pool_size=(2,2)),\n",
        "    layers.BatchNormalization(),\n",
        "\n",
        "    layers.Conv2D(32, (4,4), padding=\"valid\", activation = 'relu'),\n",
        "    layers.MaxPooling2D(pool_size=(2,2)),\n",
        "    layers.BatchNormalization(),\n",
        "    layers.Flatten(),\n",
        "    layers.Dense(32, activation = 'relu'),\n",
        "    layers.Dropout(0.15),\n",
        "    layers.Dense(5, activation = 'softmax')\n",
        "])\n",
        "\n",
        "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
        "              loss='categorical_crossentropy',\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "\n",
        "history = model.fit(train_batches,\n",
        "                    epochs=50,\n",
        "                    validation_data=val_batches)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YM5WCYAziRiy",
        "outputId": "f97c448d-db1c-4a63-bfc4-d8661f49a40a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
            "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/10\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
            "  self._warn_if_super_not_called()\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m101s\u001b[0m 1s/step - accuracy: 0.4029 - loss: 1.5315 - val_accuracy: 0.5673 - val_loss: 1.5060\n",
            "Epoch 2/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m145s\u001b[0m 1s/step - accuracy: 0.6735 - loss: 0.9081 - val_accuracy: 0.6727 - val_loss: 1.3945\n",
            "Epoch 3/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m139s\u001b[0m 1s/step - accuracy: 0.7026 - loss: 0.8166 - val_accuracy: 0.6945 - val_loss: 1.2591\n",
            "Epoch 4/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 1s/step - accuracy: 0.7062 - loss: 0.8137 - val_accuracy: 0.6855 - val_loss: 1.1358\n",
            "Epoch 5/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m144s\u001b[0m 1s/step - accuracy: 0.7130 - loss: 0.7965 - val_accuracy: 0.7127 - val_loss: 0.9882\n",
            "Epoch 6/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m139s\u001b[0m 1s/step - accuracy: 0.7532 - loss: 0.7223 - val_accuracy: 0.7327 - val_loss: 0.8859\n",
            "Epoch 7/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m144s\u001b[0m 1s/step - accuracy: 0.7584 - loss: 0.7100 - val_accuracy: 0.7382 - val_loss: 0.8226\n",
            "Epoch 8/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m94s\u001b[0m 1s/step - accuracy: 0.7351 - loss: 0.7226 - val_accuracy: 0.7327 - val_loss: 0.7908\n",
            "Epoch 9/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.7378 - loss: 0.7089 - val_accuracy: 0.7545 - val_loss: 0.7670\n",
            "Epoch 10/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m140s\u001b[0m 1s/step - accuracy: 0.7330 - loss: 0.7138 - val_accuracy: 0.7545 - val_loss: 0.7665\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score\n",
        "\n",
        "# Custom Callback to compute metrics after each epoch\n",
        "class MetricsCallback(tf.keras.callbacks.Callback):\n",
        "    \"\"\"Record class-averaged sensitivity, specificity and F1 after every epoch.\n",
        "\n",
        "    At the end of each epoch the model predicts on `validation_batches`, a\n",
        "    multi-class confusion matrix is built against `validation_batches.labels`,\n",
        "    and the unweighted mean of each per-class metric is appended to\n",
        "    `self.sensitivity`, `self.specificity` and `self.f1_scores`.\n",
        "\n",
        "    Args:\n",
        "        validation_batches: Batch iterator accepted by `model.predict` that is\n",
        "            expected to expose integer class ids as `.labels`. It has to yield\n",
        "            samples in the same order as `.labels` (i.e. not shuffled),\n",
        "            otherwise predictions cannot be paired with their ground truth.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, validation_batches):\n",
        "        # Let the Keras base class set up its own state before adding ours\n",
        "        super().__init__()\n",
        "        self.val_batches = validation_batches\n",
        "        self.sensitivity = []\n",
        "        self.specificity = []\n",
        "        self.f1_scores = []\n",
        "\n",
        "        # NOTE(review): assumes the iterator exposes a `shuffle` attribute when it\n",
        "        # can shuffle; a missing attribute is treated as \"not shuffled\"\n",
        "        if getattr(validation_batches, 'shuffle', False):\n",
        "            import warnings\n",
        "            warnings.warn(\n",
        "                'validation_batches has shuffle=True; predictions will not '\n",
        "                'line up with validation_batches.labels'\n",
        "            )\n",
        "\n",
        "    def on_epoch_end(self, epoch, logs=None):\n",
        "        \"\"\"Predict on the validation batches and store this epoch's averages.\"\"\"\n",
        "        # verbose=0 keeps this extra prediction pass from printing its own\n",
        "        # progress bar under every epoch line\n",
        "        y_pred_probs = self.model.predict(self.val_batches, verbose=0)  # Probabilities\n",
        "        y_pred = np.argmax(y_pred_probs, axis=1)  # Predicted classes\n",
        "        y_true = np.asarray(self.val_batches.labels)  # True labels\n",
        "\n",
        "        # Pin the label set to every class the model can output so the confusion\n",
        "        # matrix (and the class average below) keeps the same shape each epoch\n",
        "        class_ids = np.arange(y_pred_probs.shape[1])\n",
        "        cm = confusion_matrix(y_true, y_pred, labels=class_ids)\n",
        "        tp = np.diag(cm)  # True positives for each class\n",
        "        fp = cm.sum(axis=0) - tp  # False positives for each class\n",
        "        fn = cm.sum(axis=1) - tp  # False negatives for each class\n",
        "        tn = cm.sum() - (tp + fp + fn)  # True negatives for each class\n",
        "\n",
        "        # A tiny epsilon in every denominator avoids division by zero\n",
        "        eps = np.finfo(float).eps\n",
        "        recall = tp / (tp + fn + eps)  # Sensitivity (Recall)\n",
        "        specificity = tn / (tn + fp + eps)\n",
        "        precision = tp / (tp + fp + eps)\n",
        "        f1 = 2 * (precision * recall) / (precision + recall + eps)\n",
        "\n",
        "        # Store averages across classes\n",
        "        self.sensitivity.append(np.mean(recall))\n",
        "        self.specificity.append(np.mean(specificity))\n",
        "        self.f1_scores.append(np.mean(f1))\n",
        "\n",
        "# Per-epoch metric recorder bound to the validation iterator\n",
        "metrics_callback = MetricsCallback(validation_batches=val_batches)\n",
        "\n",
        "# Run 50 epochs; the callback collects its metrics at the end of each one\n",
        "history = model.fit(\n",
        "    train_batches,\n",
        "    validation_data=val_batches,\n",
        "    callbacks=[metrics_callback],\n",
        "    epochs=50,\n",
        ")\n",
        "\n",
        "# Figure 1: training and validation accuracy per epoch\n",
        "acc_fig, acc_ax = plt.subplots()\n",
        "for history_key, curve_label in [('accuracy', 'Train Accuracy'),\n",
        "                                 ('val_accuracy', 'Validation Accuracy')]:\n",
        "    acc_ax.plot(history.history[history_key], label=curve_label)\n",
        "acc_ax.set_xlabel('Epochs')\n",
        "acc_ax.set_ylabel('Accuracy')\n",
        "acc_ax.set_title('Accuracy vs Epochs')\n",
        "acc_ax.legend()\n",
        "acc_fig.savefig('cnn acc vs epoch.png')\n",
        "plt.show()\n",
        "\n",
        "# Figure 2: the callback's sensitivity, specificity and F1, epochs numbered from 1\n",
        "epochs = range(1, len(metrics_callback.sensitivity) + 1)\n",
        "metric_fig, metric_ax = plt.subplots()\n",
        "for curve_label, curve_values in [('Sensitivity', metrics_callback.sensitivity),\n",
        "                                  ('Specificity', metrics_callback.specificity),\n",
        "                                  ('F1 Score', metrics_callback.f1_scores)]:\n",
        "    metric_ax.plot(epochs, curve_values, label=curve_label)\n",
        "metric_ax.set_xlabel('Epochs')\n",
        "metric_ax.set_ylabel('Metrics')\n",
        "metric_ax.set_title('Sensitivity, Specificity, and F1-Score vs Epochs')\n",
        "metric_ax.legend()\n",
        "metric_fig.savefig('cnn spec f1 vs epoch.png')\n",
        "plt.show()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "Qj9X5SIpKRFg",
        "outputId": "c196f1c2-626f-4974-84c9-0b742bf74033"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 440ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 1s/step - accuracy: 0.8073 - loss: 0.5413 - val_accuracy: 0.7327 - val_loss: 0.7587\n",
            "Epoch 2/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 418ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m138s\u001b[0m 1s/step - accuracy: 0.8078 - loss: 0.5459 - val_accuracy: 0.7400 - val_loss: 0.7507\n",
            "Epoch 3/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 327ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 1s/step - accuracy: 0.8205 - loss: 0.5216 - val_accuracy: 0.7327 - val_loss: 0.7526\n",
            "Epoch 4/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 320ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 1s/step - accuracy: 0.8220 - loss: 0.5207 - val_accuracy: 0.7327 - val_loss: 0.7552\n",
            "Epoch 5/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 332ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 1s/step - accuracy: 0.8213 - loss: 0.5000 - val_accuracy: 0.7364 - val_loss: 0.7481\n",
            "Epoch 6/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 316ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 1s/step - accuracy: 0.8227 - loss: 0.5007 - val_accuracy: 0.7400 - val_loss: 0.7574\n",
            "Epoch 7/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 429ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m101s\u001b[0m 1s/step - accuracy: 0.8283 - loss: 0.4953 - val_accuracy: 0.7364 - val_loss: 0.7528\n",
            "Epoch 8/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 324ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m144s\u001b[0m 1s/step - accuracy: 0.8353 - loss: 0.4914 - val_accuracy: 0.7418 - val_loss: 0.7406\n",
            "Epoch 9/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 326ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m104s\u001b[0m 1s/step - accuracy: 0.8286 - loss: 0.4643 - val_accuracy: 0.7382 - val_loss: 0.7535\n",
            "Epoch 10/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 322ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m107s\u001b[0m 1s/step - accuracy: 0.8268 - loss: 0.4771 - val_accuracy: 0.7436 - val_loss: 0.7535\n",
            "Epoch 11/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 432ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m101s\u001b[0m 1s/step - accuracy: 0.8484 - loss: 0.4625 - val_accuracy: 0.7382 - val_loss: 0.7537\n",
            "Epoch 12/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 331ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m144s\u001b[0m 1s/step - accuracy: 0.8515 - loss: 0.4411 - val_accuracy: 0.7345 - val_loss: 0.7652\n",
            "Epoch 13/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 322ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 1s/step - accuracy: 0.8459 - loss: 0.4426 - val_accuracy: 0.7364 - val_loss: 0.7536\n",
            "Epoch 14/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 326ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m111s\u001b[0m 1s/step - accuracy: 0.8416 - loss: 0.4430 - val_accuracy: 0.7345 - val_loss: 0.7559\n",
            "Epoch 15/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 324ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m138s\u001b[0m 1s/step - accuracy: 0.8490 - loss: 0.4353 - val_accuracy: 0.7455 - val_loss: 0.7509\n",
            "Epoch 16/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 436ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 1s/step - accuracy: 0.8572 - loss: 0.4196 - val_accuracy: 0.7400 - val_loss: 0.7536\n",
            "Epoch 17/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 337ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m145s\u001b[0m 1s/step - accuracy: 0.8584 - loss: 0.4279 - val_accuracy: 0.7509 - val_loss: 0.7486\n",
            "Epoch 18/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 380ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m109s\u001b[0m 1s/step - accuracy: 0.8700 - loss: 0.4032 - val_accuracy: 0.7491 - val_loss: 0.7484\n",
            "Epoch 19/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 317ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 1s/step - accuracy: 0.8678 - loss: 0.4065 - val_accuracy: 0.7436 - val_loss: 0.7501\n",
            "Epoch 20/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 329ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 1s/step - accuracy: 0.8708 - loss: 0.3996 - val_accuracy: 0.7473 - val_loss: 0.7510\n",
            "Epoch 21/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 433ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m103s\u001b[0m 1s/step - accuracy: 0.8808 - loss: 0.3935 - val_accuracy: 0.7382 - val_loss: 0.7557\n",
            "Epoch 22/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 434ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 1s/step - accuracy: 0.8671 - loss: 0.4118 - val_accuracy: 0.7400 - val_loss: 0.7632\n",
            "Epoch 23/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 429ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 1s/step - accuracy: 0.8693 - loss: 0.3885 - val_accuracy: 0.7527 - val_loss: 0.7595\n",
            "Epoch 24/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 314ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m141s\u001b[0m 1s/step - accuracy: 0.8671 - loss: 0.3895 - val_accuracy: 0.7418 - val_loss: 0.7526\n",
            "Epoch 25/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 433ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m147s\u001b[0m 1s/step - accuracy: 0.8796 - loss: 0.3842 - val_accuracy: 0.7364 - val_loss: 0.7678\n",
            "Epoch 26/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 429ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m139s\u001b[0m 1s/step - accuracy: 0.8792 - loss: 0.3600 - val_accuracy: 0.7491 - val_loss: 0.7655\n",
            "Epoch 27/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 431ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m101s\u001b[0m 1s/step - accuracy: 0.8941 - loss: 0.3391 - val_accuracy: 0.7400 - val_loss: 0.7691\n",
            "Epoch 28/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 321ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m145s\u001b[0m 1s/step - accuracy: 0.8877 - loss: 0.3532 - val_accuracy: 0.7418 - val_loss: 0.7658\n",
            "Epoch 29/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 325ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 1s/step - accuracy: 0.8948 - loss: 0.3400 - val_accuracy: 0.7436 - val_loss: 0.7761\n",
            "Epoch 30/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 324ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m147s\u001b[0m 1s/step - accuracy: 0.8936 - loss: 0.3344 - val_accuracy: 0.7527 - val_loss: 0.7652\n",
            "Epoch 31/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 320ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m137s\u001b[0m 1s/step - accuracy: 0.8954 - loss: 0.3314 - val_accuracy: 0.7327 - val_loss: 0.7747\n",
            "Epoch 32/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 370ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m108s\u001b[0m 1s/step - accuracy: 0.9006 - loss: 0.3306 - val_accuracy: 0.7436 - val_loss: 0.7662\n",
            "Epoch 33/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 433ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m137s\u001b[0m 1s/step - accuracy: 0.8943 - loss: 0.3268 - val_accuracy: 0.7436 - val_loss: 0.7721\n",
            "Epoch 34/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 314ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m142s\u001b[0m 1s/step - accuracy: 0.9086 - loss: 0.3117 - val_accuracy: 0.7436 - val_loss: 0.7863\n",
            "Epoch 35/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 315ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m100s\u001b[0m 1s/step - accuracy: 0.9065 - loss: 0.2979 - val_accuracy: 0.7418 - val_loss: 0.7767\n",
            "Epoch 36/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 334ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m106s\u001b[0m 1s/step - accuracy: 0.9149 - loss: 0.3005 - val_accuracy: 0.7455 - val_loss: 0.7790\n",
            "Epoch 37/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 431ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m136s\u001b[0m 1s/step - accuracy: 0.9074 - loss: 0.2953 - val_accuracy: 0.7364 - val_loss: 0.7779\n",
            "Epoch 38/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 320ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.9103 - loss: 0.3055 - val_accuracy: 0.7509 - val_loss: 0.7785\n",
            "Epoch 39/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m9s\u001b[0m 491ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m105s\u001b[0m 1s/step - accuracy: 0.9185 - loss: 0.2735 - val_accuracy: 0.7418 - val_loss: 0.7780\n",
            "Epoch 40/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 422ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m137s\u001b[0m 1s/step - accuracy: 0.9141 - loss: 0.2989 - val_accuracy: 0.7364 - val_loss: 0.7859\n",
            "Epoch 41/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m9s\u001b[0m 506ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 1s/step - accuracy: 0.9229 - loss: 0.2660 - val_accuracy: 0.7382 - val_loss: 0.7951\n",
            "Epoch 42/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 429ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m100s\u001b[0m 1s/step - accuracy: 0.9228 - loss: 0.2747 - val_accuracy: 0.7491 - val_loss: 0.7982\n",
            "Epoch 43/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 422ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m100s\u001b[0m 1s/step - accuracy: 0.9129 - loss: 0.2866 - val_accuracy: 0.7345 - val_loss: 0.7970\n",
            "Epoch 44/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 321ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m145s\u001b[0m 1s/step - accuracy: 0.9183 - loss: 0.2825 - val_accuracy: 0.7382 - val_loss: 0.7967\n",
            "Epoch 45/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 321ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m100s\u001b[0m 1s/step - accuracy: 0.9215 - loss: 0.2690 - val_accuracy: 0.7436 - val_loss: 0.7990\n",
            "Epoch 46/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 317ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m102s\u001b[0m 1s/step - accuracy: 0.9264 - loss: 0.2446 - val_accuracy: 0.7364 - val_loss: 0.7988\n",
            "Epoch 47/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 395ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.9116 - loss: 0.2727 - val_accuracy: 0.7436 - val_loss: 0.7939\n",
            "Epoch 48/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 317ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m140s\u001b[0m 1s/step - accuracy: 0.9241 - loss: 0.2640 - val_accuracy: 0.7509 - val_loss: 0.7995\n",
            "Epoch 49/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 537ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.9366 - loss: 0.2348 - val_accuracy: 0.7545 - val_loss: 0.7973\n",
            "Epoch 50/50\n",
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 396ms/step\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m140s\u001b[0m 1s/step - accuracy: 0.9305 - loss: 0.2520 - val_accuracy: 0.7455 - val_loss: 0.7954\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# `files` is Colab's download helper. Import it in this cell so the cell also\n",
        "# works on a fresh kernel; re-importing is harmless if it was imported before.\n",
        "from google.colab import files\n",
        "\n",
        "# Download the two figures saved by the training cell\n",
        "files.download('cnn acc vs epoch.png')\n",
        "files.download('cnn spec f1 vs epoch.png')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "fmRa8IcOTLkI",
        "outputId": "af8ef24b-3d4e-4361-ed42-65f41e2f6cbc"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_09c4daec-a244-4b57-9fe7-25703f2fb2bb\", \"cnn acc vs epoch.png\", 39993)"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_d4f6269d-e3e7-4295-8b54-975e06234979\", \"cnn spec f1 vs epoch.png\", 38896)"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Write the trained model to disk under a named path\n",
        "model_path = '64x3-CNN.keras'\n",
        "model.save(model_path)"
      ],
      "metadata": {
        "id": "Dio44z4w62QL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Score the model on the test iterator, then unpack the two returned values\n",
        "test_metrics = model.evaluate(test_batches)\n",
        "test_loss, test_accuracy = test_metrics\n",
        "\n",
        "# Report accuracy as a percentage with two decimals\n",
        "print(f\"Test Accuracy: {test_accuracy * 100:.2f}%\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "44sSxmGdSrAH",
        "outputId": "5b045e9a-46a5-47b5-a54b-b94f021a158f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m6s\u001b[0m 329ms/step - accuracy: 0.7009 - loss: 0.8493\n",
            "Test Accuracy: 73.82%\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from sklearn.metrics import roc_curve, auc\n",
        "from sklearn.preprocessing import label_binarize\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import warnings\n",
        "\n",
        "# Ground-truth class ids as stored on the test iterator\n",
        "y_test = np.asarray(test_batches.labels)\n",
        "\n",
        "# Predictions can only be paired with `.labels` if the iterator is not shuffled.\n",
        "# NOTE(review): assumes the iterator exposes a `shuffle` attribute when it can\n",
        "# shuffle; a missing attribute is treated as \"not shuffled\".\n",
        "if getattr(test_batches, 'shuffle', False):\n",
        "    warnings.warn(\n",
        "        'test_batches has shuffle=True; predictions will not '\n",
        "        'line up with test_batches.labels'\n",
        "    )\n",
        "\n",
        "# Class probabilities, one column per class\n",
        "y_pred_probs = model.predict(test_batches)\n",
        "\n",
        "# One-vs-rest targets. The class list comes from the width of the model output\n",
        "# rather than a hard-coded [0, 1, 2, 3, 4], so it follows the model.\n",
        "y_test_bin = label_binarize(y_test, classes=list(range(y_pred_probs.shape[1])))\n",
        "n_classes = y_test_bin.shape[1]  # Number of classes\n",
        "\n",
        "# Fail early if labels and predictions disagree on sample or class count\n",
        "# (label_binarize returns a single column when there are only two classes)\n",
        "assert y_test_bin.shape == y_pred_probs.shape, (\n",
        "    f'label matrix {y_test_bin.shape} does not match predictions {y_pred_probs.shape}'\n",
        ")\n",
        "\n",
        "# Compute ROC curve and AUC for each class (class i against all others)\n",
        "fpr = {}\n",
        "tpr = {}\n",
        "roc_auc = {}\n",
        "\n",
        "for i in range(n_classes):\n",
        "    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_pred_probs[:, i])\n",
        "    roc_auc[i] = auc(fpr[i], tpr[i])\n",
        "\n",
        "# Plot all ROC curves; the palette cycles so more than five classes still plot\n",
        "plt.figure()\n",
        "colors = ['blue', 'green', 'red', 'cyan', 'magenta']\n",
        "for i in range(n_classes):\n",
        "    plt.plot(fpr[i], tpr[i], color=colors[i % len(colors)], lw=2,\n",
        "             label=f'Class {i} (AUC = {roc_auc[i]:.2f})')\n",
        "\n",
        "# Add random guess line\n",
        "plt.plot([0, 1], [0, 1], color='gray', linestyle='--')\n",
        "plt.xlabel('False Positive Rate')\n",
        "plt.ylabel('True Positive Rate')\n",
        "plt.title('Receiver Operating Characteristic for Multi-Class')\n",
        "plt.legend(loc='lower right')\n",
        "plt.savefig('cnn roc curve.png')\n",
        "plt.show()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 544
        },
        "id": "8vxZfaCOJgiY",
        "outputId": "73ce4818-40e5-4d5c-cadd-66a1a44ac9a0"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
            "  self._warn_if_super_not_called()\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[1m18/18\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 435ms/step\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 1 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import files\n",
        "files.download('cnn roc curve.png')\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "R0fHuQEqRp3J",
        "outputId": "a5e08691-f8e4-4ad9-ed04-acfb81450d3f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_49361e34-304f-4086-939c-a50de81b72cf\", \"cnn roc curve.png\", 48906)"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#generate confusion matrix\n",
        "from sklearn.metrics import classification_report\n",
        "y_pred = np.argmax(y_pred_probs, axis=1)  # Predicted class labels\n",
        "print(classification_report(y_test, y_pred, target_names=test_batches.class_indices.keys()))\n",
        "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
        "cm = confusion_matrix(y_test, y_pred)\n",
        "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=test_batches.class_indices.keys())\n",
        "disp.plot(cmap=\"Blues\")\n",
        "plt.savefig('cnn confusion matrix.png')\n",
        "plt.show()\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 658
        },
        "id": "bvpypyMW1zA1",
        "outputId": "144aab14-2ed0-4b0d-f508-f3455e5334b6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "                precision    recall  f1-score   support\n",
            "\n",
            "          Mild       0.47      0.50      0.49        56\n",
            "      Moderate       0.60      0.73      0.66       150\n",
            "         No_DR       0.89      0.97      0.93       271\n",
            "Proliferate_DR       0.33      0.02      0.04        44\n",
            "        Severe       0.60      0.21      0.31        29\n",
            "\n",
            "      accuracy                           0.74       550\n",
            "     macro avg       0.58      0.49      0.49       550\n",
            "  weighted avg       0.71      0.74      0.71       550\n",
            "\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 640x480 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import files\n",
        "files.download('cnn confusion matrix.png')\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "VlBDK0727jAr",
        "outputId": "81f9a05b-cd68-4080-e3bc-79a6aa236a27"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_b059b73c-48e9-4e03-a601-5e77d7d71fdf\", \"cnn confusion matrix.png\", 27332)"
            ]
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#evaluate training and validation accuracy on a iteratively halving dataset\n",
        "from math import ceil, pow\n",
        "src_dir = r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/colored_images'\n",
        "for i in range(5):\n",
        "  total_files = len(train)\n",
        "  denom = pow(2, i)\n",
        "  print(denom)\n",
        "  files_to_copy = ceil(total_files / denom)  # Half the files, rounded up if odd\n",
        "  print(files_to_copy)\n",
        "\n",
        "  if os.path.exists(train_dir):\n",
        "    shutil.rmtree(train_dir)\n",
        "  os.makedirs(train_dir, exist_ok=True)\n",
        "\n",
        "\n",
        "  # Counter to track the number of files copied\n",
        "  copied_files = 0\n",
        "\n",
        "  for index, row in train.iterrows():\n",
        "    if copied_files >= files_to_copy:\n",
        "        break  # Stop when half of the files are copied\n",
        "\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(train_dir, binary_diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok=True)\n",
        "\n",
        "    if os.path.exists(srcfile):\n",
        "        shutil.copy(srcfile, dstfile)\n",
        "        copied_files += 1  # Increment the counter\n",
        "\n",
        "  for index, row in val.iterrows():\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(val_dir, binary_diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok=True)\n",
        "    if os.path.exists(srcfile):\n",
        "        shutil.copy(srcfile, dstfile)\n",
        "\n",
        "  for index, row in test.iterrows():\n",
        "    diagnosis = row['type']\n",
        "    binary_diagnosis = row['binary_type']\n",
        "    id_code = row['id_code'] + \".png\"\n",
        "    srcfile = os.path.join(src_dir, diagnosis, id_code)\n",
        "    dstfile = os.path.join(test_dir, binary_diagnosis)\n",
        "    os.makedirs(dstfile, exist_ok=True)\n",
        "    if os.path.exists(srcfile):\n",
        "        shutil.copy(srcfile, dstfile)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "  for subdir in [train_dir, val_dir, test_dir]:\n",
        "    print(f\"\\nContents of {subdir}:\")\n",
        "    for root, dirs, files in os.walk(subdir):\n",
        "        print(f\"{root}: {len(files)} files\")\n",
        "\n",
        "  train_path = train_dir\n",
        "  val_path = val_dir\n",
        "  test_path = test_dir\n",
        "\n",
        "  train_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(train_path, target_size=(224,224), shuffle = True)\n",
        "  val_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(val_path, target_size=(224,224), shuffle = True)\n",
        "  test_batches = ImageDataGenerator(rescale = 1./255).flow_from_directory(test_path, target_size=(224,224), shuffle = False)\n",
        "\n",
        "  #cnn model\n",
        "  model = tf.keras.Sequential([\n",
        "      layers.Conv2D(8, (3,3), padding=\"valid\", input_shape=(224,224,3), activation = 'relu'),\n",
        "      layers.MaxPooling2D(pool_size=(2,2)),\n",
        "      layers.BatchNormalization(),\n",
        "\n",
        "      layers.Conv2D(16, (3,3), padding=\"valid\", activation = 'relu'),\n",
        "      layers.MaxPooling2D(pool_size=(2,2)),\n",
        "      layers.BatchNormalization(),\n",
        "\n",
        "      layers.Conv2D(32, (4,4), padding=\"valid\", activation = 'relu'),\n",
        "      layers.MaxPooling2D(pool_size=(2,2)),\n",
        "      layers.BatchNormalization(),\n",
        "      layers.Flatten(),\n",
        "      layers.Dense(32, activation = 'relu'),\n",
        "      layers.Dropout(0.15),\n",
        "      layers.Dense(2, activation = 'softmax')\n",
        "  ])\n",
        "\n",
        "  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),\n",
        "                loss='categorical_crossentropy',\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "\n",
        "  history = model.fit(train_batches,\n",
        "                      epochs=10,\n",
        "                      validation_data=val_batches)\n",
        "\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "U39_KR12Ot-t",
        "outputId": "002f0a49-d8a9-4b67-9804-58a475f6271b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "1.0\n",
            "2562\n",
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/No_DR: 1263 files\n",
            ".\\dataset/train/DR: 1299 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            ".\\dataset/val/DR: 279 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/No_DR: 271 files\n",
            ".\\dataset/test/DR: 279 files\n",
            "Found 2562 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
            "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Epoch 1/10\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored.\n",
            "  self._warn_if_super_not_called()\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m95s\u001b[0m 1s/step - accuracy: 0.7844 - loss: 0.4764 - val_accuracy: 0.5073 - val_loss: 0.8788\n",
            "Epoch 2/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m138s\u001b[0m 1s/step - accuracy: 0.9100 - loss: 0.2328 - val_accuracy: 0.5073 - val_loss: 1.0356\n",
            "Epoch 3/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m87s\u001b[0m 1s/step - accuracy: 0.9173 - loss: 0.2120 - val_accuracy: 0.5073 - val_loss: 1.1552\n",
            "Epoch 4/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m144s\u001b[0m 1s/step - accuracy: 0.9285 - loss: 0.1944 - val_accuracy: 0.5109 - val_loss: 0.9722\n",
            "Epoch 5/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m87s\u001b[0m 1s/step - accuracy: 0.9311 - loss: 0.1797 - val_accuracy: 0.5491 - val_loss: 0.6548\n",
            "Epoch 6/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.9346 - loss: 0.1593 - val_accuracy: 0.7582 - val_loss: 0.4198\n",
            "Epoch 7/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m86s\u001b[0m 1s/step - accuracy: 0.9482 - loss: 0.1526 - val_accuracy: 0.9164 - val_loss: 0.2454\n",
            "Epoch 8/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m89s\u001b[0m 1s/step - accuracy: 0.9395 - loss: 0.1521 - val_accuracy: 0.9236 - val_loss: 0.1954\n",
            "Epoch 9/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m87s\u001b[0m 1s/step - accuracy: 0.9404 - loss: 0.1574 - val_accuracy: 0.9309 - val_loss: 0.1698\n",
            "Epoch 10/10\n",
            "\u001b[1m81/81\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m143s\u001b[0m 1s/step - accuracy: 0.9452 - loss: 0.1475 - val_accuracy: 0.9327 - val_loss: 0.1669\n",
            "2.0\n",
            "1281\n",
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/No_DR: 664 files\n",
            ".\\dataset/train/DR: 617 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            ".\\dataset/val/DR: 279 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/No_DR: 271 files\n",
            ".\\dataset/test/DR: 279 files\n",
            "Found 1281 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Epoch 1/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m52s\u001b[0m 1s/step - accuracy: 0.7730 - loss: 0.4892 - val_accuracy: 0.5073 - val_loss: 0.6800\n",
            "Epoch 2/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 1s/step - accuracy: 0.8863 - loss: 0.2912 - val_accuracy: 0.5073 - val_loss: 0.6994\n",
            "Epoch 3/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m51s\u001b[0m 1s/step - accuracy: 0.8942 - loss: 0.2338 - val_accuracy: 0.5073 - val_loss: 0.7084\n",
            "Epoch 4/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 1s/step - accuracy: 0.9099 - loss: 0.2347 - val_accuracy: 0.5073 - val_loss: 0.7110\n",
            "Epoch 5/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m47s\u001b[0m 1s/step - accuracy: 0.9302 - loss: 0.2254 - val_accuracy: 0.5127 - val_loss: 0.6887\n",
            "Epoch 6/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m52s\u001b[0m 1s/step - accuracy: 0.9352 - loss: 0.1886 - val_accuracy: 0.5218 - val_loss: 0.6459\n",
            "Epoch 7/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m79s\u001b[0m 1s/step - accuracy: 0.9431 - loss: 0.1787 - val_accuracy: 0.5327 - val_loss: 0.6414\n",
            "Epoch 8/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 1s/step - accuracy: 0.9385 - loss: 0.1534 - val_accuracy: 0.5982 - val_loss: 0.5673\n",
            "Epoch 9/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m82s\u001b[0m 1s/step - accuracy: 0.9442 - loss: 0.1668 - val_accuracy: 0.7182 - val_loss: 0.4592\n",
            "Epoch 10/10\n",
            "\u001b[1m41/41\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m49s\u001b[0m 1s/step - accuracy: 0.9510 - loss: 0.1555 - val_accuracy: 0.7745 - val_loss: 0.4030\n",
            "4.0\n",
            "641\n",
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/No_DR: 335 files\n",
            ".\\dataset/train/DR: 306 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            ".\\dataset/val/DR: 279 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/No_DR: 271 files\n",
            ".\\dataset/test/DR: 279 files\n",
            "Found 641 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Epoch 1/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 1s/step - accuracy: 0.6456 - loss: 0.6968 - val_accuracy: 0.6873 - val_loss: 0.6828\n",
            "Epoch 2/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.8668 - loss: 0.3267 - val_accuracy: 0.7309 - val_loss: 0.6757\n",
            "Epoch 3/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.8835 - loss: 0.2830 - val_accuracy: 0.5982 - val_loss: 0.6689\n",
            "Epoch 4/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 1s/step - accuracy: 0.8973 - loss: 0.2757 - val_accuracy: 0.5145 - val_loss: 0.6677\n",
            "Epoch 5/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m41s\u001b[0m 1s/step - accuracy: 0.9076 - loss: 0.2304 - val_accuracy: 0.5145 - val_loss: 0.6660\n",
            "Epoch 6/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.9293 - loss: 0.1915 - val_accuracy: 0.5145 - val_loss: 0.6673\n",
            "Epoch 7/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.9245 - loss: 0.1830 - val_accuracy: 0.5127 - val_loss: 0.6822\n",
            "Epoch 8/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.9402 - loss: 0.1654 - val_accuracy: 0.5436 - val_loss: 0.6704\n",
            "Epoch 9/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m27s\u001b[0m 1s/step - accuracy: 0.9411 - loss: 0.1661 - val_accuracy: 0.5455 - val_loss: 0.6825\n",
            "Epoch 10/10\n",
            "\u001b[1m21/21\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 1s/step - accuracy: 0.9319 - loss: 0.1554 - val_accuracy: 0.5618 - val_loss: 0.6682\n",
            "8.0\n",
            "321\n",
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/No_DR: 168 files\n",
            ".\\dataset/train/DR: 153 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            ".\\dataset/val/DR: 279 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/No_DR: 271 files\n",
            ".\\dataset/test/DR: 279 files\n",
            "Found 321 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Epoch 1/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m19s\u001b[0m 1s/step - accuracy: 0.6672 - loss: 0.6967 - val_accuracy: 0.5073 - val_loss: 0.6918\n",
            "Epoch 2/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 2s/step - accuracy: 0.8493 - loss: 0.3656 - val_accuracy: 0.5073 - val_loss: 0.6956\n",
            "Epoch 3/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1s/step - accuracy: 0.8947 - loss: 0.3088 - val_accuracy: 0.5073 - val_loss: 0.7002\n",
            "Epoch 4/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1s/step - accuracy: 0.8497 - loss: 0.3160 - val_accuracy: 0.5073 - val_loss: 0.7053\n",
            "Epoch 5/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 1s/step - accuracy: 0.8942 - loss: 0.2618 - val_accuracy: 0.5073 - val_loss: 0.7103\n",
            "Epoch 6/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1s/step - accuracy: 0.9385 - loss: 0.2016 - val_accuracy: 0.5073 - val_loss: 0.7169\n",
            "Epoch 7/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 2s/step - accuracy: 0.9371 - loss: 0.1914 - val_accuracy: 0.5073 - val_loss: 0.7235\n",
            "Epoch 8/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m17s\u001b[0m 1s/step - accuracy: 0.9385 - loss: 0.1765 - val_accuracy: 0.5073 - val_loss: 0.7315\n",
            "Epoch 9/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1s/step - accuracy: 0.9635 - loss: 0.1654 - val_accuracy: 0.5073 - val_loss: 0.7362\n",
            "Epoch 10/10\n",
            "\u001b[1m11/11\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m16s\u001b[0m 1s/step - accuracy: 0.9161 - loss: 0.1862 - val_accuracy: 0.5073 - val_loss: 0.7401\n",
            "16.0\n",
            "161\n",
            "\n",
            "Contents of .\\dataset/train:\n",
            ".\\dataset/train: 0 files\n",
            ".\\dataset/train/No_DR: 89 files\n",
            ".\\dataset/train/DR: 72 files\n",
            "\n",
            "Contents of .\\dataset/val:\n",
            ".\\dataset/val: 0 files\n",
            ".\\dataset/val/No_DR: 271 files\n",
            ".\\dataset/val/DR: 279 files\n",
            "\n",
            "Contents of .\\dataset/test:\n",
            ".\\dataset/test: 0 files\n",
            ".\\dataset/test/No_DR: 271 files\n",
            ".\\dataset/test/DR: 279 files\n",
            "Found 161 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Found 550 images belonging to 2 classes.\n",
            "Epoch 1/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m15s\u001b[0m 2s/step - accuracy: 0.6488 - loss: 0.7023 - val_accuracy: 0.5691 - val_loss: 0.6915\n",
            "Epoch 2/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 2s/step - accuracy: 0.7992 - loss: 0.4822 - val_accuracy: 0.5364 - val_loss: 0.6905\n",
            "Epoch 3/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 2s/step - accuracy: 0.8545 - loss: 0.3240 - val_accuracy: 0.5073 - val_loss: 0.6910\n",
            "Epoch 4/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2s/step - accuracy: 0.8747 - loss: 0.3692 - val_accuracy: 0.5073 - val_loss: 0.6940\n",
            "Epoch 5/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 2s/step - accuracy: 0.9120 - loss: 0.2714 - val_accuracy: 0.5073 - val_loss: 0.6983\n",
            "Epoch 6/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m21s\u001b[0m 2s/step - accuracy: 0.9072 - loss: 0.2942 - val_accuracy: 0.5073 - val_loss: 0.7036\n",
            "Epoch 7/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2s/step - accuracy: 0.8756 - loss: 0.2688 - val_accuracy: 0.5073 - val_loss: 0.7078\n",
            "Epoch 8/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2s/step - accuracy: 0.9226 - loss: 0.2576 - val_accuracy: 0.5073 - val_loss: 0.7089\n",
            "Epoch 9/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m20s\u001b[0m 2s/step - accuracy: 0.9151 - loss: 0.2522 - val_accuracy: 0.5073 - val_loss: 0.7091\n",
            "Epoch 10/10\n",
            "\u001b[1m6/6\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m12s\u001b[0m 2s/step - accuracy: 0.9447 - loss: 0.2030 - val_accuracy: 0.5073 - val_loss: 0.7103\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Accuracy values (in percent) per database size.\n",
        "# Every accuracy list is aligned index-by-index with database_size.\n",
        "database_size = [2562, 1281, 641, 321, 161]\n",
        "cnn_train_acc = [94.52, 95.10, 93.19, 91.61, 94.47]\n",
        "cnn_val_acc = [93.27, 77.45, 56.18, 50.73, 50.73]\n",
        "vit_train_acc = [91.67, 88.21, 70.20, 54.34, 62.67]\n",
        "vit_val_acc = [92.73, 94, 69.27, 63.09, 63.09]\n",
        "\n",
        "# One (name, train, validation) entry per model; each entry becomes one panel.\n",
        "results = [\n",
        "    ('CNN', cnn_train_acc, cnn_val_acc),\n",
        "    ('ViT', vit_train_acc, vit_val_acc),\n",
        "]\n",
        "\n",
        "figure, axis = plt.subplots(1, 2, figsize=(10, 5))\n",
        "# Same y-ticks (50 to 100 in steps of 5) on both panels so they are directly comparable.\n",
        "# np is imported in the notebook's first code cell.\n",
        "y_ticks = np.arange(50, 101, 5)\n",
        "\n",
        "# Both panels are drawn by the same loop so their styling cannot drift apart.\n",
        "for panel, (model_name, train_acc, val_acc) in zip(axis, results):\n",
        "    panel.plot(database_size, train_acc, label=f'{model_name} Train Accuracy')\n",
        "    panel.plot(database_size, val_acc, label=f'{model_name} Validation Accuracy')\n",
        "    panel.set_xlabel('Database Size')\n",
        "    panel.set_ylabel('Accuracy')\n",
        "    panel.set_title(f'{model_name} Accuracy vs Database Size')\n",
        "    panel.set_yticks(y_ticks)\n",
        "    panel.invert_xaxis()  # largest database on the left, shrinking to the right\n",
        "    panel.legend()\n",
        "\n",
        "# Apply the layout BEFORE saving: savefig writes the figure as it is at call time,\n",
        "# so saving first would put the un-adjusted layout into the PNG.\n",
        "figure.tight_layout()\n",
        "figure.savefig('cnn vs vit.png')\n",
        "plt.show()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 507
        },
        "id": "geABaxbazdIg",
        "outputId": "74682125-a198-477d-9aa0-1b11e69e5ffa"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x500 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Offer the saved figure as a browser download when running on Google Colab.\n",
        "try:\n",
        "    from google.colab import files\n",
        "except ImportError:\n",
        "    # google.colab only exists inside Colab; elsewhere skip the download\n",
        "    # instead of failing the cell (and with it a full Run All).\n",
        "    print('google.colab is not available; skipping browser download of cnn vs vit.png')\n",
        "else:\n",
        "    files.download('cnn vs vit.png')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "-TU7NzwR-x5a",
        "outputId": "e134c1f0-344a-43dd-8dd0-30bd052967ee"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ],
            "application/javascript": [
              "download(\"download_3547fb09-3ca6-46a8-9622-6b968af0982d\", \"cnn vs vit.png\", 56896)"
            ]
          },
          "metadata": {}
        }
      ]
    }
  ]
}