Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
"Guiding Future STEM Leaders through Innovative Research Training" ~ thinkingbeyond.education
Project: stephanie's main branch
Path: ThinkingBeyond Activities / BeyondAI-2024-Mentee-Projects / emeka / Wine_Dataset_Code_Implementation.ipynb~1
{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/github/ThinkingBeyond/BeyondAI-2024/blob/main/emeka/Wine_Dataset_Code_Implementation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "IsMaxq6AVXqf" }, "source": [ "This notebook compares Kolmogorov-Arnold Networks (KANs) with Multi-Layer Perceptrons (MLPs) using the wine dataset.\n", "\n", "I aim to use graphical representations to understand how each model converges, the maximum accuracy each model achieves, and the loss each model incurs while training on the dataset.\n", "\n", "The comparison also uses criteria such as precision, recall, F1 score, and the confusion matrix." ] }, { "cell_type": "markdown", "metadata": { "id": "FRVmujAS_5KM" }, "source": [ "## Import Necessary Libraries" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "collapsed": true, "id": "BcXEEdI7_JeR", "outputId": "292bab53-281e-4b7a-a15d-fca7e0fb5410" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting git+https://github.com/KindXiaoming/pykan.git\n", " Cloning https://github.com/KindXiaoming/pykan.git to /tmp/pip-req-build-iop8nlv8\n", " Running command git clone --filter=blob:none --quiet https://github.com/KindXiaoming/pykan.git /tmp/pip-req-build-iop8nlv8\n", " Resolved https://github.com/KindXiaoming/pykan.git to commit f871c26d4df788ec1ba309c2c9c1803d82606b06\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Reading package lists... Done\n", "Building dependency tree... Done\n", "Reading state information... 
Done\n", "python3.10-venv is already the newest version (3.10.12-1~22.04.7).\n", "0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.\n", "Python 3.10.12\n", "Requirement already satisfied: scikit-learn==1.6.0 in /usr/local/lib/python3.10/dist-packages (1.6.0)\n", "Requirement already satisfied: pandas==2.2.2 in /usr/local/lib/python3.10/dist-packages (2.2.2)\n", "Requirement already satisfied: matplotlib==3.8.0 in /usr/local/lib/python3.10/dist-packages (3.8.0)\n", "Requirement already satisfied: torch==2.5.1+cu121 in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", "Requirement already satisfied: numpy==1.26.4 in /usr/local/lib/python3.10/dist-packages (1.26.4)\n", "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (1.13.1)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (1.4.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn==1.6.0) (3.5.0)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas==2.2.2) (2024.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (1.3.1)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (4.55.3)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (1.4.7)\n", "Requirement 
already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (24.2)\n", "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (11.0.0)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib==3.8.0) (3.2.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.16.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (3.1.4)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (2024.10.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.5.1+cu121) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch==2.5.1+cu121) (1.3.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas==2.2.2) (1.17.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.5.1+cu121) (3.0.2)\n", "Collecting git+https://github.com/trevorstephens/gplearn.git\n", " Cloning https://github.com/trevorstephens/gplearn.git to /tmp/pip-req-build-3byfndk4\n", " Running command git clone --filter=blob:none --quiet https://github.com/trevorstephens/gplearn.git /tmp/pip-req-build-3byfndk4\n", " Resolved https://github.com/trevorstephens/gplearn.git to commit 64517a85fd6d6c50f9ee9e5599f97458da278951\n", " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: scikit-learn>=1.0.2 in /usr/local/lib/python3.10/dist-packages (from gplearn==0.5.dev0) (1.6.0)\n", "Requirement already satisfied: joblib>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from gplearn==0.5.dev0) (1.4.2)\n", "Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (1.26.4)\n", "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (1.13.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=1.0.2->gplearn==0.5.dev0) (3.5.0)\n" ] } ], "source": [ "!pip install git+https://github.com/KindXiaoming/pykan.git\n", "!apt-get install python3.10-venv # Install the Python 3.10 virtual environment package\n", "!python3.10 -m venv .venv310 # Create a virtual environment named .venv310 using Python 3.10\n", "!source .venv310/bin/activate\n", "!python --version # 3.10.12\n", "!pip install scikit-learn==1.6.0 pandas==2.2.2 matplotlib==3.8.0 torch==2.5.1+cu121 numpy==1.26.4\n", "!pip install git+https://github.com/trevorstephens/gplearn.git" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "OtDshUPVgoba" }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "from kan import *\n", "from sklearn.neural_network import MLPClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.metrics import accuracy_score, confusion_matrix, recall_score, classification_report\n", "from sklearn.preprocessing import StandardScaler\n", "from gplearn.genetic import SymbolicRegressor, SymbolicTransformer\n", "from sklearn.utils.estimator_checks import check_estimator\n", "import time\n", 
"torch.manual_seed(123)\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": { "id": "viFNCJ0n_4Gw" }, "source": [ "## Models\n", "The models are first instantiated with their default hyperparameters and layers\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-jPhfJ1l_1sT", "outputId": "8b197afb-2727-4ae0-e08b-563be349613e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "checkpoint directory created: ./model\n", "saving model version 0.0\n" ] } ], "source": [ "# MLP Model\n", "mlp = MLPClassifier()\n", "# KAN Model\n", "kan = KAN(width=[2,2,1]) # should be adjusted based on dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "RXqFvEH1BNB1" }, "source": [ "# Functions\n", "All the functions used to calculate the accuracies, losses, precision, recall, and confusion matrix, and to plot the graphs, are written here." ] }, { "cell_type": "markdown", "metadata": { "id": "gEgXBgcVMZhE" }, "source": [ "## Model Functions" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "id": "xt0grRxSyGU8" }, "outputs": [], "source": [ "def MLP_run(mlp,X_train,y_train,X_test,y_test,epochs):\n", " y_train = torch.tensor(y_train, dtype=torch.float32)\n", " y_test = torch.tensor(y_test, dtype=torch.float32)\n", " train_accuracies,test_accuracies = [],[]\n", " train_losses,test_losses = [],[]\n", " criterion = nn.MSELoss()\n", " start_time = time.time()\n", "\n", " for epoch in range(epochs):\n", " mlp.partial_fit(X_train, y_train, classes=np.unique(y_train))\n", " y_pred_train = mlp.predict(X_train)\n", " train_acc = accuracy_score(y_train, torch.tensor(y_pred_train))\n", " train_loss = criterion(torch.tensor(y_pred_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32))\n", " train_accuracies.append(train_acc)\n", " train_losses.append(train_loss.item())\n", " if (epoch + 1) % 10 == 0:\n", " print(f'Epoch 
[{epoch+1}/{epochs}],MLP Train Loss: {train_loss.item():.4f}, MLP Train Accuracy: {train_acc:.4f}')\n", "\n", " for epoch in range(epochs):\n", " y_pred_test = mlp.predict(X_test)\n", " test_acc = accuracy_score(torch.tensor(y_test), torch.tensor(y_pred_test))\n", " test_loss = criterion(torch.tensor(y_pred_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32))\n", " test_accuracies.append(test_acc)\n", " test_losses.append(test_loss.item())\n", " if (epoch + 1) % 10 == 0:\n", " print(f'Epoch [{epoch+1}/{epochs}],MLP Test Loss: {test_loss.item():.4f}, MLP Test Accuracy: {test_acc:.4f}')\n", " end_time = time.time()\n", " MLP_Execution_Time = end_time - start_time\n", " return train_accuracies,test_accuracies,train_losses,test_losses,MLP_Execution_Time\n", "\n", "\n", "def KAN_run(model,X_train,y_train,X_test,y_test,epochs):\n", " y_train = torch.tensor(y_train, dtype=torch.long)\n", " y_test = torch.tensor(y_test, dtype=torch.long)\n", " optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n", " loss_fn = nn.CrossEntropyLoss()\n", " kan_train_accuracy = []\n", " kan_test_accuracy = []\n", " kan_train_loss = []\n", " kan_test_loss = []\n", " starting_time = time.time()\n", "\n", " for epoch in range(epochs): #training\n", " optimizer.zero_grad()\n", " outputs = model(X_train)\n", " predicted = torch.argmax(outputs, dim=1)\n", " train_accuracy = accuracy_score(y_train, predicted)\n", " train_loss = loss_fn(outputs, y_train)\n", " train_loss.backward()\n", " optimizer.step()\n", " kan_train_accuracy.append(train_accuracy)\n", " kan_train_loss.append(train_loss.item())\n", " if (epoch + 1) % 10 == 0:\n", " print(f'Epoch [{epoch+1}/{epochs}],KAN Train Loss: {train_loss.item():.4f}, KAN Train Accuracy: {train_accuracy:.4f}')\n", "\n", " for epoch in range(epochs): # testing\n", " predicted = model(X_test) # evaluate on the X_test argument rather than the global x_test_tensor\n", " test_accuracy = accuracy_score(y_test, predicted.argmax(1))\n", " test_loss = loss_fn(predicted, y_test)\n", 
kan_test_accuracy.append(test_accuracy)\n", " kan_test_loss.append(test_loss.item())\n", " if (epoch + 1) % 10 == 0:\n", " print(f'Epoch [{epoch+1}/{epochs}],KAN Test Loss: {test_loss.item():.4f},KAN Test Accuracy: {test_accuracy:.4f}')\n", " final_time = time.time()\n", " KAN_Execution_Time = final_time - starting_time\n", " return kan_train_accuracy, kan_test_accuracy, kan_train_loss, kan_test_loss, KAN_Execution_Time" ] }, { "cell_type": "markdown", "metadata": { "id": "QuDnSAvzc2lt" }, "source": [ "##Model Run Function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "lxvrmcWvC1Ue" }, "outputs": [], "source": [ "def run_models(mlp_model,kan_model,epochs):\n", " mlp_model_run = MLP_run(mlp_model,x_train_tensor,y_train_tensor,x_test_tensor,y_test_tensor,epochs)\n", " mlp_metrics = [mlp_model_run[0],mlp_model_run[1],mlp_model_run[2],mlp_model_run[3]]\n", " print(f'MLP Execution time: {mlp_model_run[4]:.2f} seconds')\n", "\n", " kan_model_run = KAN_run(kan_model,x_train_tensor,y_train_long,x_test_tensor,y_test_long,epochs)\n", " kan_metrics = [kan_model_run[0],kan_model_run[1],kan_model_run[2],kan_model_run[3]]\n", " print(f'KAN Execution time: {kan_model_run[4]:.2f} seconds')\n", " return mlp_metrics,kan_metrics" ] }, { "cell_type": "markdown", "metadata": { "id": "hI3eyPDc1gbD" }, "source": [ "##Plot Function" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "LUsqvwwO1eq7" }, "outputs": [], "source": [ "def plot_comparison_over_epochs(mlp_metrics,kan_metrics, num_epochs):\n", " epochs = range(1, num_epochs + 1)\n", " mlp_train_acc, mlp_test_acc, mlp_train_loss, mlp_test_loss = mlp_metrics\n", " kan_train_acc, kan_test_acc, kan_train_loss, kan_test_loss = kan_metrics\n", " plt.figure(figsize=(10,5))\n", "\n", " # Graph of Train Accuracy for both models over number of epochs\n", " plt.subplot(121)\n", " plt.plot(epochs, mlp_train_acc, label='MLP Train Accuracy',color='blue',marker='x')\n", " plt.plot(epochs, kan_train_acc, 
label='KAN Train Accuracy',color='red',marker='x')\n", " plt.plot(epochs, mlp_test_acc, label='MLP Test Accuracy',color='green',marker='x')\n", " plt.plot(epochs, kan_test_acc, label='KAN Test Accuracy',color='yellow',marker='x')\n", " plt.ylim(0.5, 1)\n", " plt.xlabel('Epochs')\n", " plt.ylabel('Accuracy')\n", " plt.title('MLP and KAN Accuracy Over Epochs')\n", " plt.legend()\n", "\n", " # Graph of loss of both models during training\n", " plt.subplot(122)\n", " plt.plot(epochs, mlp_train_loss, label='MLP Train Loss',color='blue',marker='x')\n", " plt.plot(epochs, kan_train_loss, label='KAN Train Loss',color='red',marker='x')\n", " plt.ylim(0, 1)\n", " plt.xlabel('Epochs')\n", " plt.ylabel('Loss')\n", " plt.title('Training Loss Over Epochs')\n", " plt.legend()\n", "\n", " plt.tight_layout()\n", " plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "TMS-nQNQaqwE" }, "source": [ "## Comparison Criteria Functions" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "O5i0dI9bavhf" }, "outputs": [], "source": [ "def criteria_comparison(mlp,kan,x_test_tensor,y_test_tensor):\n", " y_pred_mlp = mlp.predict(x_test_tensor)\n", " y_pred_kan = kan(x_test_tensor).argmax(1)\n", " mlp_classification= classification_report(y_test_tensor, y_pred_mlp)\n", " kan_classification = classification_report(y_test_long, y_pred_kan)\n", " mlp_confusion = confusion_matrix(y_test_tensor, y_pred_mlp)\n", " kan_confusion = confusion_matrix(y_test_long, y_pred_kan)\n", " print(f'MLP Classification Report:\\n{mlp_classification}')\n", " print(f'MLP Confusion Matrix:\\n{mlp_confusion}')\n", " print(f'KAN Classification Report:\\n{kan_classification}')\n", " print(f'KAN Confusion Matrix:\\n{kan_confusion}')" ] }, { "cell_type": "markdown", "metadata": { "id": "yNbcf_WhsvNQ" }, "source": [ "## KAN Hyperparameter Tuning Function" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "RAybot1JQsZQ" }, "outputs": [], "source": [ "class KANWrapper:\n", " def 
__init__(self, width, grid=3, k=2, noise_scale=0.1, mult_arity=2, seed=123):\n", " self.width = width\n", " self.grid = grid\n", " self.k = k\n", " self.noise_scale = noise_scale\n", " self.mult_arity = mult_arity\n", " self.seed = seed\n", " self.model = KAN(width=width, grid=grid, k=k, noise_scale=noise_scale, mult_arity=mult_arity, seed=seed)\n", " def fit(self, X, y):\n", " self.model.update_grid_from_samples(torch.tensor(X, dtype=torch.float32)) # Assuming X is a NumPy array\n", " return self\n", " def predict(self, X):\n", " X_tensor = torch.tensor(X, dtype=torch.float32) # Assuming X is a NumPy array\n", " outputs = self.model(X_tensor)\n", " return outputs.argmax(1).numpy()\n", " def get_params(self, deep=True):\n", " return {\"width\": self.width,\n", " \"grid\": self.grid,\n", " \"k\": self.k,\n", " \"noise_scale\": self.noise_scale,\n", " \"mult_arity\": self.mult_arity,\n", " \"seed\": self.seed\n", " }\n", " def set_params(self, **parameters):\n", " for parameter, value in parameters.items():\n", " setattr(self, parameter, value)\n", " self.model = KAN(width=self.width, grid=self.grid, k=self.k, noise_scale=self.noise_scale, mult_arity=self.mult_arity, seed=self.seed)\n", " return self\n", "\n", "def kan_hyperparameter(kan_tune,x_train_tensor,y_train_tensor):\n", " param_grid_kan0 = {'grid':[4,5,6],\n", " 'k':[2,3,4],\n", " 'noise_scale':[0.0001,0.01,0.1,0.001]}\n", " clfkan = GridSearchCV(kan_tune, param_grid_kan0, cv=5, scoring='accuracy')\n", " clfkan.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n", " return f\"Best parameters: {clfkan.best_params_}\"" ] }, { "cell_type": "markdown", "metadata": { "id": "3wE_1Z3gMOp7" }, "source": [ "# Dataset\n", "The dataset used for the analysis is loaded here and prepared with a standard scaler (so that every feature has a mean of 0 and a standard deviation of 1) for faster convergence and improved performance" ] }, { "cell_type": "markdown", "metadata": { "id": 
"cDe6EHenCQT0" }, "source": [ "##Load Dataset" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "VLD4BDBHCTRr" }, "outputs": [], "source": [ "url_wine = 'https://gist.githubusercontent.com/tijptjik/9408623/raw/b237fa5848349a14a14e5d4107dc7897c21951f5/wine.csv'\n", "df = pd.read_csv(url_wine)\n", "df['Wine'] = df['Wine'].replace({1: 0, 2: 1, 3: 2})\n", "x = df.drop('Wine', axis=1).values\n", "y = df['Wine'].values" ] }, { "cell_type": "markdown", "metadata": { "id": "qSsKZk7TCafa" }, "source": [ "##Prepare Dataset" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "bHLgU-5yCZLj" }, "outputs": [], "source": [ "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, random_state=42)\n", "# Standardize features\n", "scaler = StandardScaler()\n", "X_train = scaler.fit_transform(x_train)\n", "X_test = scaler.transform(x_test)\n", "# Converting data to tensors\n", "x_train_tensor = torch.tensor(X_train, dtype=torch.float32)\n", "y_train_tensor = torch.tensor(y_train, dtype=torch.float32)\n", "y_train_long = torch.tensor(y_train, dtype=torch.long)\n", "x_test_tensor = torch.tensor(X_test, dtype=torch.float32)\n", "y_test_tensor = torch.tensor(y_test, dtype=torch.float32)\n", "y_test_long = torch.tensor(y_test, dtype=torch.long)" ] }, { "cell_type": "markdown", "metadata": { "id": "6nF2UeWRHF6_" }, "source": [ "# Model Run\n", "Running the models over 100 epochs with different hidden layers for MLP and KAN with that of KAN being lower than MLP at all points\n", "\n", "Plotting the accuracies and loss over epochs\n", "\n", "Calculating other comparison criteria like the classification report and confusion matrix" ] }, { "cell_type": "code", "execution_count": 54, "metadata": { "id": "4cb1C0sFySfW" }, "outputs": [], "source": [ "epochs = 100" ] }, { "cell_type": "markdown", "metadata": { "id": "WgRQmGitCypQ" }, "source": [ "## Model Run with 2 hidden layers of 27 and 10 neurons respectively\n", "\n", "\n" ] }, { 
"cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kb8rWnAAqRH3", "outputId": "40bf9a37-9b59-4c05-a7bb-e0198eb7c51e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Best parameters: {'activation': 'tanh', 'alpha': 0.001, 'hidden_layer_sizes': (27, 10), 'learning_rate': 'adaptive', 'learning_rate_init': 0.1, 'solver': 'adam'}\n" ] } ], "source": [ "# mlp hyperparameter tunning for mlp0\n", "param_grid_mlp0 = {\n", " 'hidden_layer_sizes': [(27,10)],\n", " 'activation': ['relu', 'tanh'],\n", " 'alpha':[0.001,0.01],\n", " 'solver':['adam','sgd'],\n", " 'learning_rate_init':[0.1,0.00001],\n", " 'learning_rate':['adaptive','constant']}\n", "mlp_tune0 = MLPClassifier(max_iter=1000, random_state=123)\n", "clfmlp0 = GridSearchCV(mlp_tune0, param_grid_mlp0, cv=5, scoring='accuracy')\n", "clfmlp0.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n", "print(f\"Best parameters: {clfmlp0.best_params_}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mhOCH9QGJYPh" }, "outputs": [], "source": [ "# Remove comment to run the code using ctrl + /\n", "kan_tune0 = KANWrapper(width=[13,27,10,3], seed=123)\n", "print(kan_hyperparameter(kan_tune0,x_train_tensor,y_train_tensor))\n", "# Result: Best parameters: {'grid': 5, 'k': 2, 'noise_scale': 0.0001}" ] }, { "cell_type": "code", "execution_count": 60, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "FjZENRdneG24", "outputId": "4d5e8574-68bd-4e5d-bdf9-75f062094fdc" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "checkpoint directory created: ./model\n", "saving model version 0.0\n", "Epoch [10/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [20/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [30/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [40/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 
1.0000\n", "Epoch [50/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [60/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [70/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [80/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [90/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [100/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [10/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [20/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [30/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [40/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [50/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [60/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [70/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [80/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [90/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [100/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "MLP Execution time: 0.69 seconds\n", "Epoch [10/100],KAN Train Loss: 0.0010, KAN Train Accuracy: 1.0000\n", "Epoch [20/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [30/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [40/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [50/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [60/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [70/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [80/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [90/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [100/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [10/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [20/100],KAN Test Loss: 3.9198,KAN Test 
Accuracy: 0.9167\n", "Epoch [30/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [40/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [50/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [60/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [70/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [80/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [90/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "Epoch [100/100],KAN Test Loss: 3.9198,KAN Test Accuracy: 0.9167\n", "KAN Execution time: 30.04 seconds\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "<Figure size 1000x500 with 2 Axes>" ], "image/png": "\n" }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "MLP Classification Report:\n", " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 26\n", " 1.0 1.00 0.96 0.98 27\n", " 2.0 0.95 1.00 0.97 19\n", "\n", " accuracy 0.99 72\n", " macro avg 0.98 0.99 0.99 72\n", "weighted avg 0.99 0.99 0.99 72\n", "\n", "MLP Confusion Matrix:\n", "[[26 0 0]\n", " [ 0 26 1]\n", " [ 0 0 19]]\n", "KAN Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 0.84 1.00 0.91 26\n", " 1 1.00 0.85 0.92 27\n", " 2 0.94 0.89 0.92 19\n", "\n", " accuracy 0.92 72\n", " macro avg 0.93 0.92 0.92 72\n", "weighted avg 0.93 0.92 0.92 72\n", "\n", "KAN Confusion Matrix:\n", "[[26 0 0]\n", " [ 3 23 1]\n", " [ 2 0 17]]\n" ] } ], "source": [ "# MLP Model\n", "mlp_0 = MLPClassifier(hidden_layer_sizes=(27,10),\n", " max_iter= 1000,\n", " activation='tanh',\n", " learning_rate_init= 0.1,\n", " learning_rate='adaptive',\n", " alpha= 0.001,\n", " solver='adam',\n", " random_state=123,\n", " verbose = False)\n", "\n", "# KAN Model\n", "kan_0 = KAN(width=[13,27,10,3], grid=5, k=2, noise_scale=0.0001,seed=123)\n", "mlp_metrics_0,kan_metrics_0 = run_models(mlp_0,kan_0,epochs)\n", "plot_comparison_over_epochs(mlp_metrics_0, 
kan_metrics_0,epochs)\n", "mlp_0.fit(x_train_tensor, y_train_tensor)\n", "criteria_comparison(mlp_0,kan_0,x_test_tensor,y_test_tensor)" ] }, { "cell_type": "markdown", "metadata": { "id": "YbpyCTy8yB21" }, "source": [ "## Model Run with 2 hidden layers of 40 and 15 neurons respectively" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "up71XoAz3s7c", "outputId": "9f04628c-d95c-4ebc-c131-a00bfe7d0380" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Best parameters: {'activation': 'tanh', 'alpha': 0.1, 'hidden_layer_sizes': (40, 15), 'learning_rate': 'adaptive', 'learning_rate_init': 0.01, 'solver': 'adam'}\n" ] } ], "source": [ "# MLP hyperparameter tuning for mlp1\n", "param_grid_mlp1 = {\n", " 'hidden_layer_sizes': [(40,15)],\n", " 'activation': ['relu', 'tanh'],\n", " 'alpha':[0.001,0.01,0.1,1],\n", " 'solver':['adam','sgd'],\n", " 'learning_rate_init':[0.01,0.0001],\n", " 'learning_rate':['adaptive','constant']}\n", "mlp_tune1 = MLPClassifier(max_iter=1000, random_state=123)\n", "clfmlp1 = GridSearchCV(mlp_tune1, param_grid_mlp1, cv=5, scoring='accuracy')\n", "clfmlp1.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n", "print(f\"Best parameters: {clfmlp1.best_params_}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Dd4u2ScZaWo4" }, "outputs": [], "source": [ "# # Remove the comment markers (Ctrl + /) to run this code\n", "# kan_tune1 = KANWrapper(width=[13,40,15,3], seed=123)\n", "# print(kan_hyperparameter(kan_tune1,x_train_tensor,y_train_tensor))\n", "# # Result: Best parameters: {'grid': 6, 'k': 2, 'noise_scale': 0.0001}" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "8rb3CLxS2Ij4", "outputId": "22b9e67e-495f-469f-f5a7-7d7a414c5607" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "checkpoint directory created: ./model\n", 
"saving model version 0.0\n", "Epoch [10/100],MLP Train Loss: 0.0283, MLP Train Accuracy: 0.9717\n", "Epoch [20/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [30/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [40/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [50/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [60/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [70/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [80/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [90/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [100/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [10/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [20/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [30/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [40/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [50/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [60/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [70/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [80/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [90/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [100/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "MLP Execution time: 0.42 seconds\n", "Epoch [10/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [20/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [30/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [40/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [50/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [60/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [70/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [80/100],KAN Train Loss: 
0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [90/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [100/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [10/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [20/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [30/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [40/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [50/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [60/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [70/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [80/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [90/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "Epoch [100/100],KAN Test Loss: 1.1705,KAN Test Accuracy: 0.9722\n", "KAN Execution time: 53.67 seconds\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "<Figure size 1000x500 with 2 Axes>" ], "image/png": "\n" }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "MLP Classification Report:\n", " precision recall f1-score support\n", "\n", " 0.0 0.96 1.00 0.98 26\n", " 1.0 1.00 0.96 0.98 27\n", " 2.0 1.00 1.00 1.00 19\n", "\n", " accuracy 0.99 72\n", " macro avg 0.99 0.99 0.99 72\n", "weighted avg 0.99 0.99 0.99 72\n", "\n", "MLP Confusion Matrix:\n", "[[26 0 0]\n", " [ 1 26 0]\n", " [ 0 0 19]]\n", "KAN Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 0.93 1.00 0.96 26\n", " 1 1.00 0.96 0.98 27\n", " 2 1.00 0.95 0.97 19\n", "\n", " accuracy 0.97 72\n", " macro avg 0.98 0.97 0.97 72\n", "weighted avg 0.97 0.97 0.97 72\n", "\n", "KAN Confusion Matrix:\n", "[[26 0 0]\n", " [ 1 26 0]\n", " [ 1 0 18]]\n" ] } ], "source": [ "mlp_1 = MLPClassifier(hidden_layer_sizes=(40,15),\n", " max_iter=1000,\n", " activation='tanh',\n", " learning_rate_init= 0.01,\n", " learning_rate='adaptive',\n", " alpha= 0.1,\n", " solver='adam',\n", 
" random_state=123,\n", " verbose = False)\n", "\n", "\n", "kan_1= KAN(width=[13,40,15,3], grid=6, k=2, noise_scale=0.0001, seed=123)\n", "kan_1.update_grid_from_samples(x_train_tensor)\n", "num_epochs = 100\n", "mlp_metrics_1,kan_metrics_1 = run_models(mlp_1,kan_1,num_epochs)\n", "plot_comparison_over_epochs(mlp_metrics_1, kan_metrics_1,num_epochs)\n", "mlp_1.fit(x_train_tensor, y_train_tensor)\n", "criteria_comparison(mlp_1,kan_1,x_test_tensor,y_test_tensor)" ] }, { "cell_type": "markdown", "metadata": { "id": "N3yuk8W21J-u" }, "source": [ "## Model Run with 2 hidden layers of 15 and 7 neurons respectively" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OJk5Cm-n1stv", "outputId": "aca6c902-57f9-43d9-a585-696e472b050c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Best parameters: {'activation': 'relu', 'alpha': 1, 'hidden_layer_sizes': (15, 7), 'learning_rate': 'adaptive', 'learning_rate_init': 0.1, 'solver': 'adam'}\n" ] } ], "source": [ "# MLP hyperparameter tuning for mlp_2 (5-fold grid search on the training split)\n", "param_grid_mlp2 = {\n", " 'hidden_layer_sizes': [(15,7)],\n", " 'activation': ['relu', 'tanh'],\n", " 'alpha':[0.001,0.01,0.1,1],\n", " 'solver':['adam','sgd'],\n", " 'learning_rate_init':[0.01,0.001,0.1],\n", " 'learning_rate':['adaptive','constant']}\n", "mlp_tune2 = MLPClassifier(max_iter=1000, random_state=123)\n", "clfmlp2 = GridSearchCV(mlp_tune2, param_grid_mlp2, cv=5, scoring='accuracy')\n", "clfmlp2.fit(x_train_tensor.numpy(), y_train_tensor.numpy())\n", "print(f\"Best parameters: {clfmlp2.best_params_}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sJy1xDM3t1tX" }, "outputs": [], "source": [ "# # Uncomment the lines below (Ctrl + /) to run the KAN hyperparameter search\n", "# kan_tune2 = KANWrapper(width=[13,15,7,3], seed=123)\n", "# print(kan_hyperparameter(kan_tune2,x_train_tensor,y_train_tensor))\n", "# # Result: Best parameters: {'grid': 6, 'k': 3, 'noise_scale': 
0.1}" ] }, { "cell_type": "code", "execution_count": 62, "metadata": { "id": "sy6SpkkW1RaN", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "8c764c2c-d53e-44c1-b9ad-2bd84710d10d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "checkpoint directory created: ./model\n", "saving model version 0.0\n", "Epoch [10/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [20/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [30/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [40/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [50/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [60/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [70/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [80/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [90/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [100/100],MLP Train Loss: 0.0000, MLP Train Accuracy: 1.0000\n", "Epoch [10/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [20/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [30/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [40/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [50/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [60/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [70/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [80/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [90/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "Epoch [100/100],MLP Test Loss: 0.0139, MLP Test Accuracy: 0.9861\n", "MLP Execution time: 0.38 seconds\n", "Epoch [10/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [20/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [30/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 
1.0000\n", "Epoch [40/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [50/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [60/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [70/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [80/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [90/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [100/100],KAN Train Loss: 0.0000, KAN Train Accuracy: 1.0000\n", "Epoch [10/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [20/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [30/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [40/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [50/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [60/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [70/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [80/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [90/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "Epoch [100/100],KAN Test Loss: 0.2545,KAN Test Accuracy: 0.9861\n", "KAN Execution time: 15.29 seconds\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "<Figure size 1000x500 with 2 Axes>" ], "image/png": "\n" }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "MLP Classification Report:\n", " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 26\n", " 1.0 1.00 1.00 1.00 27\n", " 2.0 1.00 1.00 1.00 19\n", "\n", " accuracy 1.00 72\n", " macro avg 1.00 1.00 1.00 72\n", "weighted avg 1.00 1.00 1.00 72\n", "\n", "MLP Confusion Matrix:\n", "[[26 0 0]\n", " [ 0 27 0]\n", " [ 0 0 19]]\n", "KAN Classification Report:\n", " precision recall f1-score support\n", "\n", " 0 0.96 1.00 0.98 26\n", " 1 1.00 0.96 0.98 27\n", " 2 1.00 1.00 1.00 19\n", "\n", " accuracy 0.99 72\n", " macro avg 0.99 0.99 0.99 72\n", "weighted 
avg 0.99 0.99 0.99 72\n", "\n", "KAN Confusion Matrix:\n", "[[26 0 0]\n", " [ 1 26 0]\n", " [ 0 0 19]]\n" ] } ], "source": [ "mlp_2 = MLPClassifier(hidden_layer_sizes=(15,7),\n", " max_iter=1000,\n", " activation='relu',\n", " learning_rate_init= 0.1,\n", " learning_rate='adaptive',\n", " alpha= 1,\n", " solver='adam',\n", " random_state=42,\n", " verbose = False)\n", "\n", "\n", "kan_2= KAN(width=[13,15,7,3], grid=6, k=3, noise_scale=0.1, seed=123)\n", "kan_2.update_grid_from_samples(x_train_tensor)\n", "num_epochs = 100\n", "mlp_metrics_2,kan_metrics_2 = run_models(mlp_2,kan_2,num_epochs)\n", "plot_comparison_over_epochs(mlp_metrics_2, kan_metrics_2,num_epochs)\n", "# Fit on the training split only; the test split is held out for evaluation\n", "mlp_2.fit(x_train_tensor, y_train_tensor)\n", "criteria_comparison(mlp_2,kan_2,x_test_tensor,y_test_tensor)" ] } ], "metadata": { "colab": { "collapsed_sections": [ "FRVmujAS_5KM", "viFNCJ0n_4Gw", "RXqFvEH1BNB1", "gEgXBgcVMZhE", "QuDnSAvzc2lt", "hI3eyPDc1gbD", "TMS-nQNQaqwE", "yNbcf_WhsvNQ", "3wE_1Z3gMOp7", "WgRQmGitCypQ" ], "provenance": [], "authorship_tag": "ABX9TyPTC94244GratjCrT3VGaGD", "include_colab_link": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }