"Guiding Future STEM Leaders through Innovative Research Training" ~ thinkingbeyond.education

{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/ThinkingBeyond/BeyondAI-2024/blob/main/emeka/Cancer_Dataset_Code_Implementation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This notebook compares a Multi-Layer Perceptron (MLP) with a Kolmogorov-Arnold Network (KAN) on the Wisconsin Breast Cancer Dataset.\n",
        "Each model has a single hidden layer, and the experiment is repeated for different widths of that layer:\n",
        "a moderate width (27 neurons), a larger width (64 neurons) and a smaller width (18 neurons).\n",
        "GridSearchCV is used for both models to select the best hyperparameters for each chosen layer width.\n",
        "\n",
        "The results show smooth convergence for both models. The MLP trains much faster and reaches a lower test loss, while the KAN takes longer to train and has a higher test loss. The two losses are computed differently (MSE on predicted labels for the MLP, cross-entropy on logits for the KAN), so they are not directly comparable. Both models reach similar accuracy on the unseen test data.\n"
      ],
      "metadata": {
        "id": "FZold8A3cMXE"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FRVmujAS_5KM"
      },
      "source": [
        "## Import Necessary Libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "collapsed": true,
        "id": "BcXEEdI7_JeR",
        "outputId": "5448116e-7c9c-4f1e-8bf2-991fe96f31a5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting git+https://github.com/KindXiaoming/pykan.git\n",
            "  Cloning https://github.com/KindXiaoming/pykan.git to /tmp/pip-req-build-21vwrm42\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/KindXiaoming/pykan.git /tmp/pip-req-build-21vwrm42\n",
            "  Resolved https://github.com/KindXiaoming/pykan.git to commit f871c26d4df788ec1ba309c2c9c1803d82606b06\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Building wheels for collected packages: pykan\n",
            "  Building wheel for pykan (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for pykan: filename=pykan-0.2.8-py3-none-any.whl size=78208 sha256=2588868bf99eb288a347117698e7bd30e9aa1a54d3c4b0f7bbbeaf076602ce2e\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-ct7cab38/wheels/47/ca/5a/98124e020f3119f51c17f78738c621c140b7aa803b0feda76e\n",
            "Successfully built pykan\n",
            "Installing collected packages: pykan\n",
            "Successfully installed pykan-0.2.8\n",
            "Reading package lists... Done\n",
            "Building dependency tree... Done\n",
            "Reading state information... Done\n",
            "The following additional packages will be installed:\n",
            "  python3-pip-whl python3-setuptools-whl\n",
            "The following NEW packages will be installed:\n",
            "  python3-pip-whl python3-setuptools-whl python3.10-venv\n",
            "0 upgraded, 3 newly installed, 0 to remove and 49 not upgraded.\n",
            "Need to get 2,474 kB of archives.\n",
            "After this operation, 2,885 kB of additional disk space will be used.\n",
            "Get:1 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 python3-pip-whl all 22.0.2+dfsg-1ubuntu0.5 [1,680 kB]\n",
            "Get:2 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 python3-setuptools-whl all 59.6.0-1.2ubuntu0.22.04.2 [788 kB]\n",
            "Get:3 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 python3.10-venv amd64 3.10.12-1~22.04.7 [5,718 B]\n",
            "Fetched 2,474 kB in 3s (949 kB/s)\n",
            "Selecting previously unselected package python3-pip-whl.\n",
            "(Reading database ... 123632 files and directories currently installed.)\n",
            "Preparing to unpack .../python3-pip-whl_22.0.2+dfsg-1ubuntu0.5_all.deb ...\n",
            "Unpacking python3-pip-whl (22.0.2+dfsg-1ubuntu0.5) ...\n",
            "Selecting previously unselected package python3-setuptools-whl.\n",
            "Preparing to unpack .../python3-setuptools-whl_59.6.0-1.2ubuntu0.22.04.2_all.deb ...\n",
            "Unpacking python3-setuptools-whl (59.6.0-1.2ubuntu0.22.04.2) ...\n",
            "Selecting previously unselected package python3.10-venv.\n",
            "Preparing to unpack .../python3.10-venv_3.10.12-1~22.04.7_amd64.deb ...\n",
            "Unpacking python3.10-venv (3.10.12-1~22.04.7) ...\n",
            "Setting up python3-setuptools-whl (59.6.0-1.2ubuntu0.22.04.2) ...\n",
            "Setting up python3-pip-whl (22.0.2+dfsg-1ubuntu0.5) ...\n",
            "Setting up python3.10-venv (3.10.12-1~22.04.7) ...\n",
            "Python 3.10.12\n",
            "Requirement already satisfied: scikit-learn==1.6.0 in /usr/local/lib/python3.10/dist-packages (1.6.0)\n",
            "Requirement already satisfied: pandas==2.2.2 in /usr/local/lib/python3.10/dist-packages (2.2.2)\n",
            "Requirement already satisfied: matplotlib==3.8.0 in /usr/local/lib/python3.10/dist-packages (3.8.0)\n",
            "Requirement already satisfied: torch==2.5.1+cu121 in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n",
            "Requirement already satisfied: numpy==1.26.4 in /usr/local/lib/python3.10/dist-packages (1.26.4)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (1.13.1)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (3.5.0)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2024.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2024.2)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (1.3.1)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (4.55.3)\n",
            "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (1.4.7)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (24.2)\n",
            "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (11.0.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (3.2.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (4.12.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.1.4)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (2024.10.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.1+cu121) (1.3.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas==2.2.2) (1.17.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.1+cu121) (3.0.2)\n",
            "Collecting git+https://github.com/trevorstephens/gplearn.git\n",
            "  Cloning https://github.com/trevorstephens/gplearn.git to /tmp/pip-req-build-9czscjjq\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/trevorstephens/gplearn.git /tmp/pip-req-build-9czscjjq\n",
            "  Resolved https://github.com/trevorstephens/gplearn.git to commit 64517a85fd6d6c50f9ee9e5599f97458da278951\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: scikit-learn>=1.0.2 in /usr/local/lib/python3.10/dist-packages (from gplearn==0.5.dev0) (1.6.0)\n",
            "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from gplearn==0.5.dev0) (1.4.2)\n",
            "Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (1.26.4)\n",
            "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (1.13.1)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (3.5.0)\n",
            "Building wheels for collected packages: gplearn\n",
            "  Building wheel for gplearn (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for gplearn: filename=gplearn-0.5.dev0-py3-none-any.whl size=25731 sha256=ff7d8ca7361da0321b091ba86987be9c4edf2b1d81694d2b9c963c7105fceef6\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-bul6l3ka/wheels/80/ca/31/ed49dcfa9cebd48e1fc4e025a428c8898845195c5774669b3b\n",
            "Successfully built gplearn\n",
            "Installing collected packages: gplearn\n",
            "Successfully installed gplearn-0.5.dev0\n"
          ]
        }
      ],
      "source": [
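        "# Install pykan (the KAN implementation) from GitHub\n",
        "!pip install git+https://github.com/KindXiaoming/pykan.git\n",
        "# Each `!` command runs in its own subshell, so creating and activating a venv here\n",
        "# does not change the Python interpreter used by this notebook's kernel\n",
        "!apt-get install -y python3.10-venv\n",
        "!python3.10 -m venv .venv310\n",
        "!source .venv310/bin/activate\n",
        "!python --version  # 3.10.12\n",
        "# Pin the library versions\n",
        "!pip install scikit-learn==1.6.0 pandas==2.2.2 matplotlib==3.8.0 torch==2.5.1+cu121 numpy==1.26.4\n",
        "# Install gplearn from GitHub\n",
        "!pip install git+https://github.com/trevorstephens/gplearn.git"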
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "OtDshUPVgoba"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from kan import KAN  # pykan's KAN model class\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, classification_report\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from gplearn.genetic import SymbolicRegressor, SymbolicTransformer\n",
        "from sklearn.utils.estimator_checks import check_estimator\n",
        "import time\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "viFNCJ0n_4Gw"
      },
      "source": [
        "## Models\n",
        "The models are instantiated here with default settings: the MLP uses scikit-learn's default hyperparameters, and the KAN width is a placeholder that must be adapted to the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HpSNDbUSXR56",
        "outputId": "ad3ff95c-4b3a-4402-9e0c-177bdcf55cab"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "checkpoint directory created: ./model\n",
            "saving model version 0.0\n"
          ]
        }
      ],
      "source": [
        "# MLP model with scikit-learn's default parameters\n",
        "mlp = MLPClassifier()\n",
        "# KAN model; width=[inputs, hidden, outputs] should be adjusted to the dataset\n",
        "kan = KAN(width=[2,2,1])"
      ]
    },
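    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For a rough sense of model size, the sketch below counts the spline coefficients of a KAN with `width=[9, 27, 2]` (9 inputs, 27 hidden nodes, 2 output logits, the width used in the KAN tuning cell) and the weights and biases of an `MLPClassifier` with one hidden layer of 27 neurons. It assumes each KAN edge holds one B-spline with `grid + k` coefficients and ignores any extra per-edge parameters pykan keeps (such as scale factors), so the KAN count is a lower bound. `grid=5`, `k=2` are the values recorded by the KAN grid search.\n",
        "\n",
        "```python\n",
        "def kan_spline_coefficients(width, grid, k):\n",
        "    # One learnable univariate spline per edge between consecutive layers,\n",
        "    # each with (grid + k) B-spline coefficients (assumed layout)\n",
        "    edges = sum(a * b for a, b in zip(width[:-1], width[1:]))\n",
        "    return edges * (grid + k)\n",
        "\n",
        "\n",
        "def mlp_parameters(layer_sizes):\n",
        "    # Dense layers: a * b weights plus b biases between layers of size a and b\n",
        "    return sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))\n",
        "\n",
        "\n",
        "kan_coeffs = kan_spline_coefficients([9, 27, 2], grid=5, k=2)\n",
        "# For binary targets MLPClassifier uses a single logistic output unit\n",
        "mlp_params = mlp_parameters([9, 27, 1])\n",
        "print(kan_coeffs, mlp_params)  # 2079 298\n",
        "```\n"
      ]
    },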
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RXqFvEH1BNB1"
      },
      "source": [
        "# Functions\n",
        "Helper functions for computing accuracy, loss, precision, recall and the confusion matrix, plus the function that plots the results."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gEgXBgcVMZhE"
      },
      "source": [
        "## Model Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "L7ex_yBiBW3P"
      },
      "outputs": [],
      "source": [
        "def MLP_run(mlp,X_train,y_train,X_test,y_test,epochs):\n",
        "  # torch.as_tensor avoids the copy warning torch.tensor() raises when given a tensor\n",
        "  y_train = torch.as_tensor(y_train, dtype=torch.float32)\n",
        "  y_test = torch.as_tensor(y_test, dtype=torch.float32)\n",
        "  train_accuracies,test_accuracies = [],[]\n",
        "  train_losses,test_losses = [],[]\n",
        "  # MSE between hard 0/1 predictions and labels equals the error rate (1 - accuracy)\n",
        "  criterion = nn.MSELoss()\n",
        "  start_time = time.time()\n",
        "\n",
        "  for epoch in range(epochs): # training: one partial_fit pass over the training set per epoch\n",
        "    mlp.partial_fit(X_train, y_train, classes=np.unique(y_train))\n",
        "    y_pred_train = torch.as_tensor(mlp.predict(X_train), dtype=torch.float32)\n",
        "    train_acc = accuracy_score(y_train, y_pred_train)\n",
        "    train_loss = criterion(y_pred_train, y_train)\n",
        "    train_accuracies.append(train_acc)\n",
        "    train_losses.append(train_loss.item())\n",
        "    if (epoch + 1) % 10 == 0:\n",
        "      print(f'Epoch [{epoch+1}/{epochs}],MLP Train Loss: {train_loss.item():.4f}, MLP Train Accuracy: {train_acc:.4f}')\n",
        "\n",
        "  # testing: the trained model no longer changes here, so every epoch records the same test metrics\n",
        "  for epoch in range(epochs):\n",
        "    y_pred_test = torch.as_tensor(mlp.predict(X_test), dtype=torch.float32)\n",
        "    test_acc = accuracy_score(y_test, y_pred_test)\n",
        "    test_loss = criterion(y_pred_test, y_test)\n",
        "    test_accuracies.append(test_acc)\n",
        "    test_losses.append(test_loss.item())\n",
        "    if (epoch + 1) % 10 == 0:\n",
        "      print(f'Epoch [{epoch+1}/{epochs}],MLP Test Loss: {test_loss.item():.4f}, MLP Test Accuracy: {test_acc:.4f}')\n",
        "  end_time = time.time()\n",
        "  MLP_Execution_Time = end_time - start_time\n",
        "  return train_accuracies,test_accuracies,train_losses,test_losses,MLP_Execution_Time\n",
        "\n",
        "\n",
        "def KAN_run(model,X_train,y_train,X_test,y_test,epochs):\n",
        "  y_train = torch.as_tensor(y_train, dtype=torch.long)\n",
        "  y_test = torch.as_tensor(y_test, dtype=torch.long)\n",
        "  optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
        "  loss_fn = nn.CrossEntropyLoss()\n",
        "  kan_train_accuracy = []\n",
        "  kan_test_accuracy = []\n",
        "  kan_train_loss = []\n",
        "  kan_test_loss = []\n",
        "  starting_time = time.time()\n",
        "\n",
        "  for epoch in range(epochs): # training: one full-batch Adam step per epoch\n",
        "    optimizer.zero_grad()\n",
        "    outputs = model(X_train)\n",
        "    predicted = torch.argmax(outputs, dim=1)\n",
        "    train_accuracy = accuracy_score(y_train, predicted)\n",
        "    train_loss = loss_fn(outputs, y_train)\n",
        "    train_loss.backward()\n",
        "    optimizer.step()\n",
        "    kan_train_accuracy.append(train_accuracy)\n",
        "    kan_train_loss.append(train_loss.item())\n",
        "    if (epoch + 1) % 10 == 0:\n",
        "      print(f'Epoch [{epoch+1}/{epochs}],KAN Train Loss: {train_loss.item():.4f}, KAN Train Accuracy: {train_accuracy:.4f}')\n",
        "\n",
        "  # testing: use the X_test/y_test arguments rather than the global tensors; as with the MLP,\n",
        "  # this evaluates the final trained model, so the test metrics are the same every epoch\n",
        "  with torch.no_grad():\n",
        "    for epoch in range(epochs):\n",
        "      predicted = model(X_test)\n",
        "      test_accuracy = accuracy_score(y_test, predicted.argmax(1))\n",
        "      test_loss = loss_fn(predicted, y_test)\n",
        "      kan_test_accuracy.append(test_accuracy)\n",
        "      kan_test_loss.append(test_loss.item())\n",
        "      if (epoch + 1) % 10 == 0:\n",
        "        print(f'Epoch [{epoch+1}/{epochs}],KAN Test Loss: {test_loss.item():.4f},KAN Test Accuracy: {test_accuracy:.4f}')\n",
        "  final_time = time.time()\n",
        "  KAN_Execution_Time = final_time - starting_time\n",
        "  return kan_train_accuracy, kan_test_accuracy, kan_train_loss, kan_test_loss, KAN_Execution_Time"
      ]
    },
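    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both functions above score the test set only after their training loop has finished, so the recorded test curves are flat by construction. The sketch below shows the alternative of scoring the held-out split inside the training loop. It uses a plain NumPy logistic-regression model trained by full-batch gradient descent on synthetic data as a hypothetical stand-in for the MLP and KAN; the point is only where the evaluation sits.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def sigmoid(z):\n",
        "    return 1 / (1 + np.exp(-z))\n",
        "\n",
        "\n",
        "def cross_entropy(p, y):\n",
        "    # Binary cross-entropy, clipped to avoid log(0)\n",
        "    p = np.clip(p, 1e-12, 1 - 1e-12)\n",
        "    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))\n",
        "\n",
        "\n",
        "def train_with_per_epoch_eval(X_train, y_train, X_test, y_test, epochs, lr=0.1):\n",
        "    w = np.zeros(X_train.shape[1])\n",
        "    b = 0.0\n",
        "    history = {'train_acc': [], 'test_acc': [], 'train_loss': [], 'test_loss': []}\n",
        "    for _ in range(epochs):\n",
        "        # One full-batch gradient step on the training split only\n",
        "        grad = sigmoid(X_train @ w + b) - y_train\n",
        "        w -= lr * X_train.T @ grad / len(y_train)\n",
        "        b -= lr * grad.mean()\n",
        "        # Score both splits with the current parameters, every epoch\n",
        "        for split, X, y in (('train', X_train, y_train), ('test', X_test, y_test)):\n",
        "            p = sigmoid(X @ w + b)\n",
        "            history[split + '_acc'].append(float(np.mean((p >= 0.5) == y)))\n",
        "            history[split + '_loss'].append(cross_entropy(p, y))\n",
        "    return history\n",
        "\n",
        "\n",
        "# Synthetic data (hypothetical, not the breast cancer dataset)\n",
        "rng = np.random.default_rng(42)\n",
        "X = rng.normal(size=(200, 4))\n",
        "y = (X[:, 0] + X[:, 1] > 0).astype(float)\n",
        "hist = train_with_per_epoch_eval(X[:120], y[:120], X[120:], y[120:], epochs=50)\n",
        "print(hist['test_acc'][0], hist['test_acc'][-1])\n",
        "```\n",
        "\n",
        "With this layout the test curve changes from epoch to epoch and can reveal overfitting, which a single evaluation of the final model cannot show."
      ]
    },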
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QuDnSAvzc2lt"
      },
      "source": [
        "## Model Run Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "lxvrmcWvC1Ue"
      },
      "outputs": [],
      "source": [
        "def run_models(mlp_model,kan_model,epochs):\n",
        "  # Uses the global x_train_tensor, y_train_tensor, x_test_tensor and y_test_tensor\n",
        "  # created in the 'Prepare Dataset' section\n",
        "  mlp_model_run = MLP_run(mlp_model,x_train_tensor,y_train_tensor,x_test_tensor,y_test_tensor,epochs)\n",
        "  mlp_metrics = list(mlp_model_run[:4])  # train acc, test acc, train loss, test loss\n",
        "  print(f'MLP Execution time: {mlp_model_run[4]:.2f} seconds')\n",
        "\n",
        "  print()\n",
        "\n",
        "  kan_model_run = KAN_run(kan_model,x_train_tensor,y_train_tensor,x_test_tensor,y_test_tensor,epochs)\n",
        "  kan_metrics = list(kan_model_run[:4])  # train acc, test acc, train loss, test loss\n",
        "  print(f'KAN Execution time: {kan_model_run[4]:.2f} seconds')\n",
        "  return mlp_metrics,kan_metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hI3eyPDc1gbD"
      },
      "source": [
        "## Plot Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "LUsqvwwO1eq7"
      },
      "outputs": [],
      "source": [
        "def plot_comparison_over_epochs(mlp_metrics,kan_metrics, num_epochs):\n",
        "    epochs = range(1, num_epochs + 1)\n",
        "    mlp_train_acc, mlp_test_acc, mlp_train_loss, mlp_test_loss = mlp_metrics\n",
        "    kan_train_acc, kan_test_acc, kan_train_loss, kan_test_loss = kan_metrics\n",
        "    plt.figure(figsize=(10,6))\n",
        "\n",
        "    # Train and test accuracy of both models over the epochs\n",
        "    plt.subplot(121)\n",
        "    plt.plot(epochs, mlp_train_acc, label='MLP Train Accuracy',color='blue',marker='*')\n",
        "    plt.plot(epochs, kan_train_acc, label='KAN Train Accuracy',color='red',marker='*')\n",
        "    plt.plot(epochs, mlp_test_acc, label='MLP Test Accuracy',color='green',marker='*')\n",
        "    plt.plot(epochs, kan_test_acc, label='KAN Test Accuracy',color='yellow',marker='*')\n",
        "    plt.xlabel('Epochs')\n",
        "    plt.ylabel('Accuracy')\n",
        "    plt.title('MLP and KAN Train and Test Accuracy Over Epochs')\n",
        "    plt.legend()\n",
        "\n",
        "    # Training loss of both models (MLP: MSE on predicted labels; KAN: cross-entropy on logits)\n",
        "    plt.subplot(122)\n",
        "    plt.plot(epochs, mlp_train_loss, label='MLP Train Loss',color='blue',marker='*')\n",
        "    plt.plot(epochs, kan_train_loss, label='KAN Train Loss',color='red',marker='*')\n",
        "    plt.ylim(0, 1)\n",
        "    plt.xlabel('Epochs')\n",
        "    plt.ylabel('Loss')\n",
        "    plt.title('Training Loss Over Epochs')\n",
        "    plt.legend()\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ]
    },
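    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss panel plots the MLP's `nn.MSELoss` on hard 0/1 predictions next to the KAN's cross-entropy on logits. For hard binary predictions each squared error is either 0 or 1, so that MSE is exactly the misclassification rate, 1 - accuracy, and it is not on the same scale as a cross-entropy. The sketch below checks this identity on made-up labels and, as one comparable alternative, computes binary cross-entropy from predicted probabilities (the probabilities here are hypothetical; for an `MLPClassifier` they would come from `predict_proba`).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "y_true = np.array([0, 1, 1, 0, 1, 0, 0, 1])\n",
        "y_pred = np.array([0, 1, 0, 0, 1, 1, 0, 1])  # hard class predictions\n",
        "\n",
        "accuracy = np.mean(y_true == y_pred)\n",
        "mse_on_labels = np.mean((y_pred - y_true) ** 2)\n",
        "# Each squared error is 0 (correct) or 1 (wrong), so the MSE equals the error rate\n",
        "assert np.isclose(mse_on_labels, 1 - accuracy)\n",
        "\n",
        "# Hypothetical predicted probabilities of class 1, e.g. predict_proba(X)[:, 1]\n",
        "p1 = np.array([0.1, 0.9, 0.4, 0.2, 0.8, 0.6, 0.3, 0.7])\n",
        "eps = 1e-12  # avoids log(0)\n",
        "bce = -np.mean(y_true * np.log(p1 + eps) + (1 - y_true) * np.log(1 - p1 + eps))\n",
        "print(accuracy, mse_on_labels, round(float(bce), 4))\n",
        "```\n"
      ]
    },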
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TMS-nQNQaqwE"
      },
      "source": [
        "## Comparison Criteria Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "O5i0dI9bavhf"
      },
      "outputs": [],
      "source": [
        "def criteria_comparison(mlp,kan,x_test_tensor,y_test_tensor):\n",
        "  # Predict once per model and reuse the labels for both the report and the matrix\n",
        "  mlp_pred = mlp.predict(x_test_tensor)\n",
        "  with torch.no_grad():\n",
        "    kan_pred = kan(x_test_tensor).argmax(1)\n",
        "  mlp_classification = classification_report(y_test_tensor, mlp_pred)\n",
        "  kan_classification = classification_report(y_test_tensor, kan_pred)\n",
        "  mlp_confusion = confusion_matrix(y_test_tensor, mlp_pred)\n",
        "  kan_confusion = confusion_matrix(y_test_tensor, kan_pred)\n",
        "  print(f'MLP Classification Report:\\n{mlp_classification}')\n",
        "  print(f'MLP Confusion Matrix:\\n{mlp_confusion}')\n",
        "  print(f'KAN Classification Report:\\n{kan_classification}')\n",
        "  print(f'KAN Confusion Matrix:\\n{kan_confusion}')"
      ]
    },
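    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`classification_report` derives per-class precision, recall and F1 from the same counts that `confusion_matrix` arranges as rows = true class, columns = predicted class. A minimal sketch of that arithmetic for the positive class (label 1, malignant under this notebook's encoding), on made-up labels rather than the notebook's results:\n",
        "\n",
        "```python\n",
        "def binary_confusion(y_true, y_pred):\n",
        "    # Same layout as sklearn's confusion_matrix for labels [0, 1]:\n",
        "    # rows = true class, columns = predicted class\n",
        "    pairs = list(zip(y_true, y_pred))\n",
        "    tn = sum(t == 0 and p == 0 for t, p in pairs)\n",
        "    fp = sum(t == 0 and p == 1 for t, p in pairs)\n",
        "    fn = sum(t == 1 and p == 0 for t, p in pairs)\n",
        "    tp = sum(t == 1 and p == 1 for t, p in pairs)\n",
        "    return [[tn, fp], [fn, tp]]\n",
        "\n",
        "\n",
        "def precision_recall_f1(cm):\n",
        "    (tn, fp), (fn, tp) = cm\n",
        "    precision = tp / (tp + fp) if tp + fp else 0.0\n",
        "    recall = tp / (tp + fn) if tp + fn else 0.0  # also called sensitivity\n",
        "    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
        "    return precision, recall, f1\n",
        "\n",
        "\n",
        "y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]\n",
        "y_pred = [0, 0, 1, 0, 1, 1, 0, 0, 1, 0]\n",
        "cm = binary_confusion(y_true, y_pred)\n",
        "print(cm, precision_recall_f1(cm))  # [[5, 1], [1, 3]] (0.75, 0.75, 0.75)\n",
        "```\n",
        "\n",
        "For a cancer screen, recall on the malignant class (how many malignant cases are caught) is usually the row of the report to read first, since a false negative is the costlier error."
      ]
    },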
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yNbcf_WhsvNQ"
      },
      "source": [
        "## KAN Hyperparameter Tuning Function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "RAybot1JQsZQ"
      },
      "outputs": [],
      "source": [
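        "class KANWrapper:\n",
        "    # Minimal scikit-learn-style estimator so that GridSearchCV can tune pykan's KAN\n",
        "    def __init__(self, width, grid=3, k=2, noise_scale=0.1, mult_arity=2, seed=123, steps=100, lr=0.1):\n",
        "        self.width = width\n",
        "        self.grid = grid\n",
        "        self.k = k\n",
        "        self.noise_scale = noise_scale\n",
        "        self.mult_arity = mult_arity\n",
        "        self.seed = seed\n",
        "        self.steps = steps  # full-batch Adam steps per fit, mirroring KAN_run\n",
        "        self.lr = lr\n",
        "        self.model = KAN(width=width, grid=grid, k=k, noise_scale=noise_scale, mult_arity=mult_arity, seed=seed)\n",
        "    def fit(self, X, y):\n",
        "        X_tensor = torch.as_tensor(X, dtype=torch.float32)  # GridSearchCV passes NumPy arrays\n",
        "        y_tensor = torch.as_tensor(y, dtype=torch.long)\n",
        "        # Adapt the spline grids to the input range, then train the network;\n",
        "        # without this loop every candidate would be scored with untrained weights\n",
        "        self.model.update_grid_from_samples(X_tensor)\n",
        "        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)\n",
        "        loss_fn = nn.CrossEntropyLoss()\n",
        "        for _ in range(self.steps):\n",
        "            optimizer.zero_grad()\n",
        "            loss = loss_fn(self.model(X_tensor), y_tensor)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "        return self\n",
        "    def predict(self, X):\n",
        "        X_tensor = torch.as_tensor(X, dtype=torch.float32)  # GridSearchCV passes NumPy arrays\n",
        "        with torch.no_grad():\n",
        "            outputs = self.model(X_tensor)\n",
        "        return outputs.argmax(1).numpy()\n",
        "    def get_params(self, deep=True):\n",
        "        return {\"width\": self.width,\n",
        "                \"grid\": self.grid,\n",
        "                \"k\": self.k,\n",
        "                \"noise_scale\": self.noise_scale,\n",
        "                \"mult_arity\": self.mult_arity,\n",
        "                \"seed\": self.seed,\n",
        "                \"steps\": self.steps,\n",
        "                \"lr\": self.lr\n",
        "               }\n",
        "    def set_params(self, **parameters):\n",
        "        for parameter, value in parameters.items():\n",
        "            setattr(self, parameter, value)\n",
        "        # Rebuild the KAN so the new hyperparameters take effect\n",
        "        self.model = KAN(width=self.width, grid=self.grid, k=self.k, noise_scale=self.noise_scale, mult_arity=self.mult_arity, seed=self.seed)\n",
        "        return self\n",
        "\n",
        "def kan_hyperparameter(kan_tune,x_train_tensor,y_train_tensor):\n",
        "  param_grid_kan = {'grid':[4,5,6],\n",
        "                      'k':[2,3,4],\n",
        "                      'noise_scale':[0.1,0.01,0.001,0.2]}\n",
        "  # 5-fold cross-validation over every grid combination, scored by accuracy\n",
        "  clfkan = GridSearchCV(kan_tune, param_grid_kan, cv=5, scoring='accuracy')\n",
        "  clfkan.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n",
        "  return f\"Best parameters: {clfkan.best_params_}\""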
      ]
    },
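    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Grid search cost grows multiplicatively: `GridSearchCV` fits every combination in the grid once per CV fold, then refits the best one on the full training set (its default `refit=True`). The sketch below enumerates `param_grid_kan` with `itertools.product`, similar to what scikit-learn's `ParameterGrid` does, to count the KAN fits that `kan_hyperparameter` would trigger with `cv=5`.\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "param_grid_kan = {'grid': [4, 5, 6],\n",
        "                  'k': [2, 3, 4],\n",
        "                  'noise_scale': [0.1, 0.01, 0.001, 0.2]}\n",
        "\n",
        "# Every combination of values, one dict per candidate\n",
        "candidates = [dict(zip(param_grid_kan, values))\n",
        "              for values in product(*param_grid_kan.values())]\n",
        "\n",
        "cv_folds = 5\n",
        "total_fits = len(candidates) * cv_folds + 1  # +1 for the final refit on all training data\n",
        "print(len(candidates), total_fits)  # 36 181\n",
        "```\n"
      ]
    },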
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3wE_1Z3gMOp7"
      },
      "source": [
        "# Dataset\n",
        "The dataset used for the analysis is loaded here and prepared with StandardScaler, which rescales every feature to a mean of 0 and a standard deviation of 1 for faster convergence and improved performance"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cDe6EHenCQT0"
      },
      "source": [
        "## Load Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "VLD4BDBHCTRr"
      },
      "outputs": [],
      "source": [
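        "url_wisconsin_breast_cancer = 'https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data'\n",
        "# Columns: 0 = sample code number (ID), 1-9 = cytology features, 10 = class (2 = benign, 4 = malignant)\n",
        "df = pd.read_csv(url_wisconsin_breast_cancer, header=None)\n",
        "# Encode the class as 0 = benign, 1 = malignant\n",
        "df[10] = df[10].map({2: 0, 4: 1})\n",
        "# Column 6 (Bare Nuclei) contains '?' for missing values and is dropped; note that\n",
        "# column 0 (the sample ID) is kept, so it is one of the 9 input features\n",
        "x = df.drop([6, 10], axis=1).values\n",
        "y = df[10].values"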
      ]
    },
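    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Column 6 of `breast-cancer-wisconsin.data` is Bare Nuclei, which the UCI file encodes with `?` where a value is missing, and column 0 is the sample code number, an identifier rather than a measurement. One alternative preparation, sketched below under those assumptions, reads `?` as missing, drops incomplete rows, and keeps the nine cytology features (columns 1-9) while excluding the ID, so the input width stays at 9. It runs on a few hypothetical rows in the file's layout instead of downloading the dataset.\n",
        "\n",
        "```python\n",
        "import io\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "# Hypothetical rows in the UCI file's layout:\n",
        "# id, 9 cytology features, class (2 = benign, 4 = malignant); '?' marks a missing value\n",
        "rows = [\n",
        "    '1,5,1,1,1,2,1,3,1,1,2',\n",
        "    '2,5,4,4,5,7,10,3,2,1,2',\n",
        "    '3,8,4,5,1,2,?,7,3,1,4',\n",
        "    '4,8,10,10,8,7,10,9,7,1,4',\n",
        "]\n",
        "df_demo = pd.read_csv(io.StringIO('\\n'.join(rows)), header=None, na_values='?')\n",
        "df_demo = df_demo.dropna()                   # drop rows with a missing Bare Nuclei value\n",
        "df_demo[10] = df_demo[10].map({2: 0, 4: 1})  # benign -> 0, malignant -> 1\n",
        "x_demo = df_demo.drop(columns=[0, 10]).to_numpy(dtype=float)  # 9 features, ID excluded\n",
        "y_demo = df_demo[10].to_numpy()\n",
        "print(x_demo.shape, y_demo)\n",
        "```\n",
        "\n",
        "Imputing the missing Bare Nuclei values (for example with the column median) instead of dropping rows would keep every sample; which option suits the comparison is a design choice."
      ]
    },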
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qSsKZk7TCafa"
      },
      "source": [
        "## Prepare Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "bHLgU-5yCZLj"
      },
      "outputs": [],
      "source": [
        "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, random_state=42)\n",
        "# Standardize features\n",
        "scaler = StandardScaler()\n",
        "X_train = scaler.fit_transform(x_train)\n",
        "X_test = scaler.transform(x_test)\n",
        "# Converting data to tensors\n",
        "x_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n",
        "y_train_tensor = torch.tensor(y_train)\n",
        "x_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n",
        "y_test_tensor = torch.tensor(y_test)"
      ]
    },
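    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`StandardScaler` learns each feature's mean and standard deviation from the training split only and reuses them for the test split, z = (x - μ_train) / σ_train, so no statistics from the test set leak into preprocessing. The NumPy sketch below reproduces that transform on random stand-in data; like scikit-learn it uses the population standard deviation (`ddof=0`) and treats zero-variance features as having scale 1.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "train = rng.uniform(1, 10, size=(100, 9))  # stand-in for x_train (UCI features range 1-10)\n",
        "test = rng.uniform(1, 10, size=(40, 9))    # stand-in for x_test\n",
        "\n",
        "mu = train.mean(axis=0)\n",
        "sigma = train.std(axis=0)       # population std (ddof=0), as StandardScaler uses\n",
        "sigma[sigma == 0] = 1.0         # constant features keep scale 1\n",
        "train_z = (train - mu) / sigma  # fit_transform on the training split\n",
        "test_z = (test - mu) / sigma    # transform: reuse the training statistics\n",
        "\n",
        "print(train_z.mean(axis=0).round(6), train_z.std(axis=0).round(6))\n",
        "```\n",
        "\n",
        "Only the training split is exactly centred and unit-variance; the test split is close but not exact, which is expected."
      ]
    },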
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WgRQmGitCypQ"
      },
      "source": [
        "# Running the Models\n",
        "Both models are run for 100 epochs with different numbers of neurons in the hidden layer.\n",
        "\n",
        "The accuracy and loss are plotted over the epochs.\n",
        "\n",
        "Other comparison criteria (classification report and confusion matrix) are then computed.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "QHDiSpNY8RB4"
      },
      "outputs": [],
      "source": [
        "epochs = 100"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hWPr9ED5WqPq"
      },
      "source": [
        "### Model Run with 1 hidden layer of 27 neurons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-6EPziiK5kMf",
        "outputId": "284dcc2b-5cce-4bcf-ed6d-3503f44c5041"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Best parameters: {'activation': 'relu', 'alpha': 0.01, 'hidden_layer_sizes': 27, 'learning_rate': 'adaptive', 'learning_rate_init': 0.001, 'solver': 'adam'}\n"
          ]
        }
      ],
      "source": [
        "# Hyperparameter tuning for the hidden layer\n",
        "param_grid_mlp_0 = {'hidden_layer_sizes': [27],  # one hidden layer of 27 neurons\n",
        "    'activation': ['relu', 'tanh', 'logistic'],  # 'logistic' is sklearn's name for the sigmoid\n",
        "    'alpha':[0.01,0.1],\n",
        "    'learning_rate_init':[0.001,0.01,0.1,1],\n",
        "    'solver':['adam','sgd'],\n",
        "    'learning_rate':['adaptive','constant']}\n",
        "mlp_0 = MLPClassifier(max_iter=1000, random_state=123)\n",
        "clfmlp_0 = GridSearchCV(mlp_0, param_grid_mlp_0, cv=5, scoring='accuracy')\n",
        "clfmlp_0.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n",
        "print(f\"Best parameters: {clfmlp_0.best_params_}\")"
      ]
    },
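    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Grid searches grow multiplicatively, and misspelled option names fail quietly: with `GridSearchCV`'s default `error_score=np.nan`, a candidate whose parameters `MLPClassifier` rejects is scored as NaN with a warning (which the warnings filter set at import time hides) instead of stopping the search. The pure-Python sketch below, run on a grid shaped like the one above but with one invalid activation name, counts the candidates and fits for `cv=5` and lists activation values outside the four that `MLPClassifier` accepts.\n",
        "\n",
        "```python\n",
        "from math import prod\n",
        "\n",
        "# Activation names accepted by scikit-learn's MLPClassifier\n",
        "VALID_ACTIVATIONS = {'identity', 'logistic', 'tanh', 'relu'}\n",
        "\n",
        "\n",
        "def grid_summary(param_grid, cv):\n",
        "    n_candidates = prod(len(values) for values in param_grid.values())\n",
        "    n_fits = n_candidates * cv + 1  # +1 for the final refit (refit=True)\n",
        "    invalid = [a for a in param_grid.get('activation', []) if a not in VALID_ACTIVATIONS]\n",
        "    return n_candidates, n_fits, invalid\n",
        "\n",
        "\n",
        "grid = {'hidden_layer_sizes': [27],\n",
        "        'activation': ['relu', 'tanh', 'logistics'],  # 'logistics' is not a valid option\n",
        "        'alpha': [0.01, 0.1],\n",
        "        'learning_rate_init': [0.001, 0.01, 0.1, 1],\n",
        "        'solver': ['adam', 'sgd'],\n",
        "        'learning_rate': ['adaptive', 'constant']}\n",
        "print(grid_summary(grid, cv=5))  # (96, 481, ['logistics'])\n",
        "```\n"
      ]
    },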
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "eBomuvIg5SOE"
      },
      "outputs": [],
      "source": [
        "# # Select these lines and press Ctrl + / to uncomment and run the KAN tuning\n",
        "# kan_tune_0 = KANWrapper(width=[9,27,2], seed=123)\n",
        "# print(kan_hyperparameter( kan_tune_0,x_train_tensor,y_train_tensor))\n",
        "# # Result:  Best parameters: {'grid': 5, 'k': 2, 'noise_scale': 0.001}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "CCNVHkA8lKgp",
        "outputId": "ab34ec91-b406-42bc-a0bf-60a382ee36ee",
        "collapsed": true
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "checkpoint directory created: ./model\n",
            "saving model version 0.0\n",
            "Epoch [10/100],MLP Train Loss: 0.0477, MLP Train Accuracy: 0.9523\n",
            "Epoch [20/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [30/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [40/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [50/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [60/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [70/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [80/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [90/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [100/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [10/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [20/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [30/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [40/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [50/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [60/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [70/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [80/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [90/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [100/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "MLP Execution time: 1.68 seconds\n",
            "\n",
            "Epoch [10/100],KAN Train Loss: 0.0919, KAN Train Accuracy: 0.9666\n",
            "Epoch [20/100],KAN Train Loss: 0.0125, KAN Train Accuracy: 0.9976\n",
            "Epoch [30/100],KAN Train Loss: 0.0038, KAN Train Accuracy: 0.9976\n",
            "Epoch [40/100],KAN Train Loss: 0.0025, KAN Train Accuracy: 1.0000\n",
            "Epoch [50/100],KAN Train Loss: 0.0028, KAN Train Accuracy: 0.9976\n",
            "Epoch [60/100],KAN Train Loss: 0.0018, KAN Train Accuracy: 1.0000\n",
            "Epoch [70/100],KAN Train Loss: 0.0007, KAN Train Accuracy: 1.0000\n",
            "Epoch [80/100],KAN Train Loss: 0.0002, KAN Train Accuracy: 1.0000\n",
            "Epoch [90/100],KAN Train Loss: 0.0001, KAN Train Accuracy: 1.0000\n",
            "Epoch [100/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [10/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [20/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [30/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [40/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [50/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [60/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [70/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [80/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [90/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "Epoch [100/100],KAN Test Loss: 0.4149,KAN Test Accuracy: 0.9464\n",
            "KAN Execution time: 24.70 seconds\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x600 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "MLP Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.98      0.97      0.98       185\n",
            "           1       0.95      0.97      0.96        95\n",
            "\n",
            "    accuracy                           0.97       280\n",
            "   macro avg       0.97      0.97      0.97       280\n",
            "weighted avg       0.97      0.97      0.97       280\n",
            "\n",
            "MLP Confusion Matrix:\n",
            "[[180   5]\n",
            " [  3  92]]\n",
            "KAN Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.95      0.97      0.96       185\n",
            "           1       0.94      0.89      0.92        95\n",
            "\n",
            "    accuracy                           0.95       280\n",
            "   macro avg       0.95      0.93      0.94       280\n",
            "weighted avg       0.95      0.95      0.95       280\n",
            "\n",
            "KAN Confusion Matrix:\n",
            "[[180   5]\n",
            " [ 10  85]]\n"
          ]
        }
      ],
      "source": [
        "mlp_2 = MLPClassifier(hidden_layer_sizes=(27),\n",
        "                    max_iter=1000,\n",
        "                    activation='relu',\n",
        "                    learning_rate='adaptive',\n",
        "                    learning_rate_init=0.001,\n",
        "                    alpha= 0.01,\n",
        "                    solver='adam',\n",
        "                    random_state=123,\n",
        "                    verbose = False)\n",
        "kan_2= KAN(width=[9,27,2],\n",
        "           grid=5,\n",
        "           k=2,\n",
        "           noise_scale=0.001,\n",
        "           seed=123)\n",
        "\n",
        "mlp_metrics_2,kan_metrics_2 = run_models(mlp_2,kan_2,epochs)\n",
        "plot_comparison_over_epochs(mlp_metrics_2, kan_metrics_2,epochs)\n",
        "mlp_2.fit(x_train_tensor, y_train_tensor)\n",
        "criteria_comparison(mlp_2,kan_2,x_test_tensor,y_test_tensor)"
      ]
    },
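    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a check on the reports above, the per-class metrics can be recomputed by hand from the KAN confusion matrix (rows are true labels and columns are predicted labels, as in scikit-learn's `confusion_matrix`). Taking class 1 as the positive class, the KAN has TP = 85, FN = 10, FP = 5 and TN = 180:\n",
        "\n",
        "$$\n",
        "\\text{precision}_1 = \\frac{85}{85 + 5} \\approx 0.94, \\qquad \\text{recall}_1 = \\frac{85}{85 + 10} \\approx 0.89, \\qquad \\text{accuracy} = \\frac{180 + 85}{280} \\approx 0.946\n",
        "$$\n",
        "\n",
        "These match the KAN report (0.94, 0.89, 0.95) and the logged KAN test accuracy of 0.9464. The 10 class-1 samples predicted as class 0, against 3 for the MLP, account for the KAN's lower class-1 recall at this width."
      ]
    },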
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kWipDkJG2CPT"
      },
      "source": [
        "### Model Run with 1 hidden layer1 of 64 neurons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DClPPVtr7WWh",
        "outputId": "a68ed202-3050-41ba-9248-ad2abb63a9d5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Best parameters: {'activation': 'relu', 'alpha': 0.01, 'hidden_layer_sizes': 64, 'learning_rate': 'adaptive', 'learning_rate_init': 0.001, 'solver': 'sgd'}\n"
          ]
        }
      ],
      "source": [
        "# Hyperparameter Tunning for the hidden layer\n",
        "param_grid_mlp_1 = {'hidden_layer_sizes': [(64)],\n",
        "    'activation': ['relu', 'tanh','logistics'],\n",
        "    'alpha':[0.01,0.1],\n",
        "    'learning_rate_init':[0.001,0.01,0.1,1],\n",
        "    'solver':['adam','sgd'],\n",
        "    'learning_rate':['adaptive','constant']}\n",
        "mlp_1 = MLPClassifier(max_iter=1000, random_state=123)\n",
        "clfmlp_1 = GridSearchCV(mlp_1, param_grid_mlp_1, cv=5, scoring='accuracy')\n",
        "clfmlp_1.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n",
        "print(f\"Best parameters: {clfmlp_1.best_params_}\")"
      ]
    },
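    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The grid above holds one value of `hidden_layer_sizes`, three activations, two values of `alpha`, four values of `learning_rate_init`, two solvers and two `learning_rate` schedules. With `cv=5`, the search therefore trains\n",
        "\n",
        "$$\n",
        "1 \\times 3 \\times 2 \\times 4 \\times 2 \\times 2 = 96 \\text{ candidates}, \\qquad 96 \\times 5 \\text{ folds} = 480 \\text{ fits},\n",
        "$$\n",
        "\n",
        "plus one final refit of the best candidate on the whole training set, since `GridSearchCV` refits by default (`refit=True`). Because `learning_rate` is only used by the `'sgd'` solver, the two schedules give identical models for `'adam'`, so 24 of the 96 candidates repeat another candidate."
      ]
    },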
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "bfnuLwAp4xVU"
      },
      "outputs": [],
      "source": [
        "# # Remove comment to run the code using ctrl + /\n",
        "# kan_tune_1 = KANWrapper(width=[9,64,2], seed=123)\n",
        "# print(kan_hyperparameter(kan_tune_1,x_train_tensor,y_train_tensor))\n",
        "# # Result: Best parameters: {'grid': 6, 'k': 2, 'noise_scale': 0.001}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "8rb3CLxS2Ij4",
        "outputId": "e38149d7-4875-456b-e81f-7186dfdbf846"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "checkpoint directory created: ./model\n",
            "saving model version 0.0\n",
            "Epoch [10/100],MLP Train Loss: 0.4678, MLP Train Accuracy: 0.5322\n",
            "Epoch [20/100],MLP Train Loss: 0.0788, MLP Train Accuracy: 0.9212\n",
            "Epoch [30/100],MLP Train Loss: 0.0477, MLP Train Accuracy: 0.9523\n",
            "Epoch [40/100],MLP Train Loss: 0.0453, MLP Train Accuracy: 0.9547\n",
            "Epoch [50/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [60/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [70/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [80/100],MLP Train Loss: 0.0477, MLP Train Accuracy: 0.9523\n",
            "Epoch [90/100],MLP Train Loss: 0.0453, MLP Train Accuracy: 0.9547\n",
            "Epoch [100/100],MLP Train Loss: 0.0453, MLP Train Accuracy: 0.9547\n",
            "Epoch [10/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [20/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [30/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [40/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [50/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [60/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [70/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [80/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [90/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "Epoch [100/100],MLP Test Loss: 0.0286, MLP Test Accuracy: 0.9714\n",
            "MLP Execution time: 1.07 seconds\n",
            "\n",
            "Epoch [10/100],KAN Train Loss: 0.0533, KAN Train Accuracy: 0.9928\n",
            "Epoch [20/100],KAN Train Loss: 0.0061, KAN Train Accuracy: 0.9976\n",
            "Epoch [30/100],KAN Train Loss: 0.0031, KAN Train Accuracy: 0.9976\n",
            "Epoch [40/100],KAN Train Loss: 0.0016, KAN Train Accuracy: 1.0000\n",
            "Epoch [50/100],KAN Train Loss: 0.0002, KAN Train Accuracy: 1.0000\n",
            "Epoch [60/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [70/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [80/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [90/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [100/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n",
            "Epoch [10/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [20/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [30/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [40/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [50/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [60/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [70/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [80/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [90/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "Epoch [100/100],KAN Test Loss: 0.3729,KAN Test Accuracy: 0.9571\n",
            "KAN Execution time: 57.68 seconds\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x600 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "MLP Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.98      0.97      0.98       185\n",
            "           1       0.95      0.96      0.95        95\n",
            "\n",
            "    accuracy                           0.97       280\n",
            "   macro avg       0.96      0.97      0.96       280\n",
            "weighted avg       0.97      0.97      0.97       280\n",
            "\n",
            "MLP Confusion Matrix:\n",
            "[[180   5]\n",
            " [  4  91]]\n",
            "KAN Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.97      0.97      0.97       185\n",
            "           1       0.94      0.94      0.94        95\n",
            "\n",
            "    accuracy                           0.96       280\n",
            "   macro avg       0.95      0.95      0.95       280\n",
            "weighted avg       0.96      0.96      0.96       280\n",
            "\n",
            "KAN Confusion Matrix:\n",
            "[[179   6]\n",
            " [  6  89]]\n"
          ]
        }
      ],
      "source": [
        "mlp_1 = MLPClassifier(hidden_layer_sizes=(64),\n",
        "                    max_iter=1000,\n",
        "                    activation='relu',\n",
        "                    learning_rate_init=0.001,\n",
        "                    learning_rate='adaptive',\n",
        "                    alpha= 0.01,\n",
        "                    solver='sgd',\n",
        "                    random_state=123,\n",
        "                    verbose = False)\n",
        "kan_1= KAN(width=[9,64,2],\n",
        "           grid=6,\n",
        "           k=2,\n",
        "           noise_scale=0.001,\n",
        "           seed=123)\n",
        "\n",
        "mlp_metrics_1,kan_metrics_1 = run_models(mlp_1,kan_1,epochs)\n",
        "plot_comparison_over_epochs(mlp_metrics_1, kan_metrics_1,epochs)\n",
        "mlp_1.fit(x_train_tensor, y_train_tensor)\n",
        "criteria_comparison(mlp_1,kan_1,x_test_tensor,y_test_tensor)"
      ]
    },
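    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the training logs above, every MLP loss equals one minus the MLP accuracy on the same line, for example\n",
        "\n",
        "$$\n",
        "1 - 0.5322 = 0.4678, \\qquad 1 - 0.9714 = 0.0286,\n",
        "$$\n",
        "\n",
        "whereas the KAN values are not complementary ($0.0533 + 0.9928 = 1.0461 \\neq 1$). This pattern is consistent with the MLP column reporting the misclassification rate and the KAN column reporting a different training loss. That reading is an inference from the logged numbers only (the metrics are computed in `run_models`, defined earlier in the notebook), but it means the two loss columns may not be directly comparable."
      ]
    },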
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ogd9G4eXRTXg"
      },
      "source": [
        "### Model Run with 1 hidden layer of 18 neuron"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CPe6TU-c9_wu",
        "outputId": "1da9007c-be44-4af5-f680-b4c58e6a259e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Best parameters: {'activation': 'relu', 'alpha': 0.01, 'hidden_layer_sizes': 18, 'learning_rate': 'adaptive', 'learning_rate_init': 0.001, 'solver': 'sgd'}\n"
          ]
        }
      ],
      "source": [
        "# Hyperparameter Tunning for the hidden layer\n",
        "param_grid_mlp_2 = {'hidden_layer_sizes': [(18)],\n",
        "    'activation': ['relu', 'tanh','logistics'],\n",
        "    'alpha':[0.01,0.1],\n",
        "    'learning_rate_init':[0.001,0.01,0.1,1],\n",
        "    'solver':['adam','sgd'],\n",
        "    'learning_rate':['adaptive','constant']}\n",
        "mlp_2 = MLPClassifier(max_iter=1000, random_state=123)\n",
        "clfmlp_2 = GridSearchCV(mlp_2, param_grid_mlp_2, cv=5, scoring='accuracy')\n",
        "clfmlp_2.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n",
        "print(f\"Best parameters: {clfmlp_2.best_params_}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "tL_-GPeT6GMp"
      },
      "outputs": [],
      "source": [
        "# # Remove comment to run the code using ctrl + /\n",
        "# kan_tune_2 = KANWrapper(width=[9,18,2], seed=123)\n",
        "# print(kan_hyperparameter(kan_tune_2,x_train_tensor,y_train_tensor))\n",
        "# # Result:  Best parameters: {'grid': 5, 'k': 2, 'noise_scale': 0.01}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "M7PNrcYfRYjx",
        "outputId": "4b1aed6d-e1a9-4a91-f05d-441d27d60271"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "checkpoint directory created: ./model\n",
            "saving model version 0.0\n",
            "Epoch [10/100],MLP Train Loss: 0.1002, MLP Train Accuracy: 0.8998\n",
            "Epoch [20/100],MLP Train Loss: 0.0621, MLP Train Accuracy: 0.9379\n",
            "Epoch [30/100],MLP Train Loss: 0.0453, MLP Train Accuracy: 0.9547\n",
            "Epoch [40/100],MLP Train Loss: 0.0477, MLP Train Accuracy: 0.9523\n",
            "Epoch [50/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [60/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [70/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [80/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [90/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [100/100],MLP Train Loss: 0.0430, MLP Train Accuracy: 0.9570\n",
            "Epoch [10/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [20/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [30/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [40/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [50/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [60/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [70/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [80/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [90/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "Epoch [100/100],MLP Test Loss: 0.0357, MLP Test Accuracy: 0.9643\n",
            "MLP Execution time: 1.69 seconds\n",
            "\n",
            "Epoch [10/100],KAN Train Loss: 0.0842, KAN Train Accuracy: 0.9737\n",
            "Epoch [20/100],KAN Train Loss: 0.0191, KAN Train Accuracy: 0.9952\n",
            "Epoch [30/100],KAN Train Loss: 0.0051, KAN Train Accuracy: 0.9976\n",
            "Epoch [40/100],KAN Train Loss: 0.0064, KAN Train Accuracy: 0.9976\n",
            "Epoch [50/100],KAN Train Loss: 0.0030, KAN Train Accuracy: 0.9976\n",
            "Epoch [60/100],KAN Train Loss: 0.0021, KAN Train Accuracy: 1.0000\n",
            "Epoch [70/100],KAN Train Loss: 0.0020, KAN Train Accuracy: 0.9976\n",
            "Epoch [80/100],KAN Train Loss: 0.0007, KAN Train Accuracy: 1.0000\n",
            "Epoch [90/100],KAN Train Loss: 0.0003, KAN Train Accuracy: 1.0000\n",
            "Epoch [100/100],KAN Train Loss: 0.0002, KAN Train Accuracy: 1.0000\n",
            "Epoch [10/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [20/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [30/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [40/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [50/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [60/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [70/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [80/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [90/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "Epoch [100/100],KAN Test Loss: 0.2966,KAN Test Accuracy: 0.9571\n",
            "KAN Execution time: 18.61 seconds\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<Figure size 1000x600 with 2 Axes>"
            ],
            "image/png": "\n"
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "MLP Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.98      0.97      0.98       185\n",
            "           1       0.95      0.96      0.95        95\n",
            "\n",
            "    accuracy                           0.97       280\n",
            "   macro avg       0.96      0.97      0.96       280\n",
            "weighted avg       0.97      0.97      0.97       280\n",
            "\n",
            "MLP Confusion Matrix:\n",
            "[[180   5]\n",
            " [  4  91]]\n",
            "KAN Classification Report:\n",
            "              precision    recall  f1-score   support\n",
            "\n",
            "           0       0.97      0.97      0.97       185\n",
            "           1       0.94      0.94      0.94        95\n",
            "\n",
            "    accuracy                           0.96       280\n",
            "   macro avg       0.95      0.95      0.95       280\n",
            "weighted avg       0.96      0.96      0.96       280\n",
            "\n",
            "KAN Confusion Matrix:\n",
            "[[179   6]\n",
            " [  6  89]]\n"
          ]
        }
      ],
      "source": [
        "mlp_2 = MLPClassifier(hidden_layer_sizes=(18),\n",
        "                    max_iter=1000,\n",
        "                    activation='relu',\n",
        "                    learning_rate='adaptive',\n",
        "                    learning_rate_init=0.001,\n",
        "                    alpha= 0.01,\n",
        "                    solver='sgd',\n",
        "                    random_state=123,\n",
        "                    verbose = False)\n",
        "kan_2= KAN(width=[9,18,2],\n",
        "           grid=5,\n",
        "           k=2,\n",
        "           noise_scale=0.01,\n",
        "           seed=123)\n",
        "\n",
        "mlp_metrics_2,kan_metrics_2 = run_models(mlp_2,kan_2,epochs)\n",
        "plot_comparison_over_epochs(mlp_metrics_2, kan_metrics_2,epochs)\n",
        "mlp_2.fit(x_train_tensor, y_train_tensor)\n",
        "criteria_comparison(mlp_2,kan_2,x_test_tensor,y_test_tensor)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "FRVmujAS_5KM",
        "viFNCJ0n_4Gw",
        "RXqFvEH1BNB1",
        "gEgXBgcVMZhE",
        "QuDnSAvzc2lt",
        "hI3eyPDc1gbD",
        "TMS-nQNQaqwE",
        "yNbcf_WhsvNQ",
        "3wE_1Z3gMOp7",
        "cDe6EHenCQT0",
        "qSsKZk7TCafa"
      ],
      "provenance": [],
      "authorship_tag": "ABX9TyNAM1z3A3rzAL7kT81nAKf5",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}