"Guiding Future STEM Leaders through Innovative Research Training" ~ thinkingbeyond.education

Image: ubuntu2204
Kernel: Python 3
from tensorflow import lite
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import random, os
import shutil
import matplotlib.pyplot as plt
from matplotlib.image import imread
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.metrics import categorical_accuracy
from sklearn.model_selection import train_test_split
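The splits and generator shuffling below are stochastic; as a minimal optional sketch (not part of the original notebook), the usual seeds can be fixed to make runs repeatable. The seed value 42 is an arbitrary choice:

# Optional: fix random seeds so that splits and training are repeatable.
# The seed value is arbitrary; the original notebook did not set one.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)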
import kagglehub

# Download the latest version of the dataset
path = kagglehub.dataset_download("sovitrath/diabetic-retinopathy-224x224-2019-data")
print("Path to dataset files:", path)
Downloading from https://www.kaggle.com/api/v1/datasets/download/sovitrath/diabetic-retinopathy-224x224-2019-data?dataset_version_number=4...
100%|██████████| 238M/238M [00:02<00:00, 95.5MB/s]
Extracting files...
Path to dataset files: /root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4
# Load the training labels and add two readable class columns:
# 'binary_type' collapses the five grades into DR / No_DR,
# 'type' keeps the full five-grade diagnosis.
df = pd.read_csv(r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/train.csv')

diagnosis_dict_binary = {
    0: 'No_DR',
    1: 'DR',
    2: 'DR',
    3: 'DR',
    4: 'DR'
}

diagnosis_dict = {
    0: 'No_DR',
    1: 'Mild',
    2: 'Moderate',
    3: 'Severe',
    4: 'Proliferate_DR',
}

df['binary_type'] = df['diagnosis'].map(diagnosis_dict_binary.get)
df['type'] = df['diagnosis'].map(diagnosis_dict.get)
df.head()
# Stratified 70/15/15 split: hold out 15% for validation, then take
# 0.15 / (1 - 0.15) of the remainder so the test set is also 15% of the total
train_intermediate, val = train_test_split(df, test_size=0.15, stratify=df['type'])
train, test = train_test_split(train_intermediate, test_size=0.15 / (1 - 0.15), stratify=train_intermediate['type'])

# Print the per-class counts in the training split
print(train['type'].value_counts(), '\n')
type
No_DR             1263
Moderate           699
Mild               258
Proliferate_DR     207
Severe             135
Name: count, dtype: int64
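A quick optional check, not in the original run, that stratification kept the class proportions consistent across the three splits:

# Optional sanity check: class proportions should be nearly identical per split
for name, split in [('train', train), ('val', val), ('test', test)]:
    print(name, split['type'].value_counts(normalize=True).round(3).to_dict())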
# Create directories for train, val and test; remove any previous run's output
# first (shutil.rmtree) so the folders are rebuilt every time the code runs.
# Note: on this Linux image the backslash is a literal character, so the
# directory is actually named '.\dataset', as the listings below show.
base_dir = r'.\dataset'
train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

if os.path.exists(base_dir):
    shutil.rmtree(base_dir)

# Make directories for train, val and test
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Keep only rows whose grade is one of the five expected classes
valid_types = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferate_DR']
df = df[df['type'].isin(valid_types)]
assert all(train['type'].isin(valid_types))
assert all(val['type'].isin(valid_types))
assert all(test['type'].isin(valid_types))
# Copy each image from the source directory into its split/class folder
src_dir = r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/colored_images'

for split_df, split_dir in [(train, train_dir), (val, val_dir), (test, test_dir)]:
    for index, row in split_df.iterrows():
        diagnosis = row['type']
        id_code = row['id_code'] + ".png"
        srcfile = os.path.join(src_dir, diagnosis, id_code)
        dstfile = os.path.join(split_dir, diagnosis)
        os.makedirs(dstfile, exist_ok=True)
        if os.path.exists(srcfile):
            shutil.copy(srcfile, dstfile)

# Report how many files landed in each split/class folder
for subdir in [train_dir, val_dir, test_dir]:
    print(f"\nContents of {subdir}:")
    for root, dirs, files in os.walk(subdir):
        print(f"{root}: {len(files)} files")
Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/Moderate: 699 files
.\dataset/train/Mild: 258 files
.\dataset/train/Severe: 135 files
.\dataset/train/Proliferate_DR: 207 files
.\dataset/train/No_DR: 1263 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/Moderate: 150 files
.\dataset/val/Mild: 56 files
.\dataset/val/Severe: 29 files
.\dataset/val/Proliferate_DR: 44 files
.\dataset/val/No_DR: 271 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/Moderate: 150 files
.\dataset/test/Mild: 56 files
.\dataset/test/Severe: 29 files
.\dataset/test/Proliferate_DR: 44 files
.\dataset/test/No_DR: 271 files
train_path = train_dir
val_path = val_dir
test_path = test_dir

train_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(train_path, target_size=(224, 224), shuffle=True)
val_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(val_path, target_size=(224, 224), shuffle=True)
# shuffle=False so that test labels stay aligned with the prediction order below
test_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(test_path, target_size=(224, 224), shuffle=False)
Found 2562 images belonging to 5 classes.
Found 550 images belonging to 5 classes.
Found 550 images belonging to 5 classes.
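flow_from_directory assigns class indices alphabetically, which determines what "Class 0" through "Class 4" mean in the ROC plot later. A quick way to check the mapping (not part of the original run):

# Inspect the folder-name -> index mapping used by the generators;
# with these folder names the expected result is alphabetical:
# {'Mild': 0, 'Moderate': 1, 'No_DR': 2, 'Proliferate_DR': 3, 'Severe': 4}
print(train_batches.class_indices)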
# CNN model: three conv / max-pool / batch-norm blocks followed by a small dense head
model = tf.keras.Sequential([
    layers.Conv2D(8, (3, 3), padding="valid", input_shape=(224, 224, 3), activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.BatchNormalization(),
    layers.Conv2D(16, (3, 3), padding="valid", activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.BatchNormalization(),
    layers.Conv2D(32, (4, 4), padding="valid", activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.BatchNormalization(),
    layers.Flatten(),
    layers.Dense(32, activation='relu'),
    layers.Dropout(0.15),
    layers.Dense(5, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_batches, epochs=10, validation_data=val_batches)
/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead. super().__init__(activity_regularizer=activity_regularizer, **kwargs)
Epoch 1/10
/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
81/81 ━━━━━━━━━━━━━━━━━━━━ 101s 1s/step - accuracy: 0.4029 - loss: 1.5315 - val_accuracy: 0.5673 - val_loss: 1.5060
Epoch 2/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 145s 1s/step - accuracy: 0.6735 - loss: 0.9081 - val_accuracy: 0.6727 - val_loss: 1.3945
Epoch 3/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 139s 1s/step - accuracy: 0.7026 - loss: 0.8166 - val_accuracy: 0.6945 - val_loss: 1.2591
Epoch 4/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - accuracy: 0.7062 - loss: 0.8137 - val_accuracy: 0.6855 - val_loss: 1.1358
Epoch 5/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.7130 - loss: 0.7965 - val_accuracy: 0.7127 - val_loss: 0.9882
Epoch 6/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 139s 1s/step - accuracy: 0.7532 - loss: 0.7223 - val_accuracy: 0.7327 - val_loss: 0.8859
Epoch 7/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.7584 - loss: 0.7100 - val_accuracy: 0.7382 - val_loss: 0.8226
Epoch 8/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 94s 1s/step - accuracy: 0.7351 - loss: 0.7226 - val_accuracy: 0.7327 - val_loss: 0.7908
Epoch 9/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.7378 - loss: 0.7089 - val_accuracy: 0.7545 - val_loss: 0.7670
Epoch 10/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 140s 1s/step - accuracy: 0.7330 - loss: 0.7138 - val_accuracy: 0.7545 - val_loss: 0.7665
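To double-check the architecture defined above, Keras can print the layer output shapes and parameter counts; this call was not in the original run:

# Prints each layer's output shape and parameter count for the model above
model.summary()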
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Custom callback to compute macro-averaged validation metrics after each epoch
class MetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self, validation_batches):
        super().__init__()
        self.val_batches = validation_batches
        self.sensitivity = []
        self.specificity = []
        self.f1_scores = []

    def on_epoch_end(self, epoch, logs=None):
        # Get predictions and ground truths.
        # NOTE: .labels lists images in directory order; because val_batches was
        # created with shuffle=True, a shuffle=False validation iterator would be
        # needed to guarantee exact alignment between labels and predictions.
        y_pred_probs = self.model.predict(self.val_batches)  # probabilities
        y_pred = np.argmax(y_pred_probs, axis=1)             # predicted classes
        y_true = self.val_batches.labels                     # true labels

        # Compute the confusion matrix and per-class counts
        cm = confusion_matrix(y_true, y_pred)
        tp = np.diag(cm)                  # true positives for each class
        fp = np.sum(cm, axis=0) - tp      # false positives for each class
        fn = np.sum(cm, axis=1) - tp      # false negatives for each class
        tn = np.sum(cm) - (tp + fp + fn)  # true negatives for each class

        # Add eps to the denominators to avoid division by zero
        recall = tp / (tp + fn + np.finfo(float).eps)  # sensitivity (recall)
        specificity = tn / (tn + fp + np.finfo(float).eps)
        precision = tp / (tp + fp + np.finfo(float).eps)
        f1 = 2 * (precision * recall) / (precision + recall + np.finfo(float).eps)

        # Store macro averages across classes
        self.sensitivity.append(np.mean(recall))
        self.specificity.append(np.mean(specificity))
        self.f1_scores.append(np.mean(f1))

# Initialize the callback
metrics_callback = MetricsCallback(validation_batches=val_batches)

# Continue training the same model for 50 more epochs, with the callback attached
history = model.fit(train_batches, epochs=50, validation_data=val_batches, callbacks=[metrics_callback])

# Plot accuracy
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Epochs')
plt.legend()
plt.savefig('cnn acc vs epoch.png')
plt.show()

# Plot sensitivity, specificity and F1-score
epochs = range(1, len(metrics_callback.sensitivity) + 1)
plt.plot(epochs, metrics_callback.sensitivity, label='Sensitivity')
plt.plot(epochs, metrics_callback.specificity, label='Specificity')
plt.plot(epochs, metrics_callback.f1_scores, label='F1 Score')
plt.xlabel('Epochs')
plt.ylabel('Metrics')
plt.title('Sensitivity, Specificity, and F1-Score vs Epochs')
plt.legend()
plt.savefig('cnn spec f1 vs epoch.png')
plt.show()
Epoch 1/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 440ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - accuracy: 0.8073 - loss: 0.5413 - val_accuracy: 0.7327 - val_loss: 0.7587
Epoch 2/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 418ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 138s 1s/step - accuracy: 0.8078 - loss: 0.5459 - val_accuracy: 0.7400 - val_loss: 0.7507
Epoch 3/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 327ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.8205 - loss: 0.5216 - val_accuracy: 0.7327 - val_loss: 0.7526
Epoch 4/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 320ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 104s 1s/step - accuracy: 0.8220 - loss: 0.5207 - val_accuracy: 0.7327 - val_loss: 0.7552
Epoch 5/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 332ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 102s 1s/step - accuracy: 0.8213 - loss: 0.5000 - val_accuracy: 0.7364 - val_loss: 0.7481
Epoch 6/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 316ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - accuracy: 0.8227 - loss: 0.5007 - val_accuracy: 0.7400 - val_loss: 0.7574
Epoch 7/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 429ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 101s 1s/step - accuracy: 0.8283 - loss: 0.4953 - val_accuracy: 0.7364 - val_loss: 0.7528
Epoch 8/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 324ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.8353 - loss: 0.4914 - val_accuracy: 0.7418 - val_loss: 0.7406
Epoch 9/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 326ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 104s 1s/step - accuracy: 0.8286 - loss: 0.4643 - val_accuracy: 0.7382 - val_loss: 0.7535
Epoch 10/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 322ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - accuracy: 0.8268 - loss: 0.4771 - val_accuracy: 0.7436 - val_loss: 0.7535
Epoch 11/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 432ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 101s 1s/step - accuracy: 0.8484 - loss: 0.4625 - val_accuracy: 0.7382 - val_loss: 0.7537
Epoch 12/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 331ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.8515 - loss: 0.4411 - val_accuracy: 0.7345 - val_loss: 0.7652
Epoch 13/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 322ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 106s 1s/step - accuracy: 0.8459 - loss: 0.4426 - val_accuracy: 0.7364 - val_loss: 0.7536
Epoch 14/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 326ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 111s 1s/step - accuracy: 0.8416 - loss: 0.4430 - val_accuracy: 0.7345 - val_loss: 0.7559
Epoch 15/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 324ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 138s 1s/step - accuracy: 0.8490 - loss: 0.4353 - val_accuracy: 0.7455 - val_loss: 0.7509
Epoch 16/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 436ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 103s 1s/step - accuracy: 0.8572 - loss: 0.4196 - val_accuracy: 0.7400 - val_loss: 0.7536
Epoch 17/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 337ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 145s 1s/step - accuracy: 0.8584 - loss: 0.4279 - val_accuracy: 0.7509 - val_loss: 0.7486
Epoch 18/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 7s 380ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 109s 1s/step - accuracy: 0.8700 - loss: 0.4032 - val_accuracy: 0.7491 - val_loss: 0.7484
Epoch 19/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 317ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 105s 1s/step - accuracy: 0.8678 - loss: 0.4065 - val_accuracy: 0.7436 - val_loss: 0.7501
Epoch 20/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 329ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.8708 - loss: 0.3996 - val_accuracy: 0.7473 - val_loss: 0.7510
Epoch 21/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 433ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 103s 1s/step - accuracy: 0.8808 - loss: 0.3935 - val_accuracy: 0.7382 - val_loss: 0.7557
Epoch 22/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 434ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 105s 1s/step - accuracy: 0.8671 - loss: 0.4118 - val_accuracy: 0.7400 - val_loss: 0.7632
Epoch 23/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 429ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 102s 1s/step - accuracy: 0.8693 - loss: 0.3885 - val_accuracy: 0.7527 - val_loss: 0.7595
Epoch 24/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 314ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - accuracy: 0.8671 - loss: 0.3895 - val_accuracy: 0.7418 - val_loss: 0.7526
Epoch 25/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 433ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 147s 1s/step - accuracy: 0.8796 - loss: 0.3842 - val_accuracy: 0.7364 - val_loss: 0.7678
Epoch 26/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 429ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 139s 1s/step - accuracy: 0.8792 - loss: 0.3600 - val_accuracy: 0.7491 - val_loss: 0.7655
Epoch 27/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 431ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 101s 1s/step - accuracy: 0.8941 - loss: 0.3391 - val_accuracy: 0.7400 - val_loss: 0.7691
Epoch 28/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 321ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 145s 1s/step - accuracy: 0.8877 - loss: 0.3532 - val_accuracy: 0.7418 - val_loss: 0.7658
Epoch 29/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 325ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 102s 1s/step - accuracy: 0.8948 - loss: 0.3400 - val_accuracy: 0.7436 - val_loss: 0.7761
Epoch 30/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 324ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 147s 1s/step - accuracy: 0.8936 - loss: 0.3344 - val_accuracy: 0.7527 - val_loss: 0.7652
Epoch 31/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 320ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 137s 1s/step - accuracy: 0.8954 - loss: 0.3314 - val_accuracy: 0.7327 - val_loss: 0.7747
Epoch 32/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 7s 370ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - accuracy: 0.9006 - loss: 0.3306 - val_accuracy: 0.7436 - val_loss: 0.7662
Epoch 33/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 433ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 137s 1s/step - accuracy: 0.8943 - loss: 0.3268 - val_accuracy: 0.7436 - val_loss: 0.7721
Epoch 34/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 314ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - accuracy: 0.9086 - loss: 0.3117 - val_accuracy: 0.7436 - val_loss: 0.7863
Epoch 35/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 315ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 100s 1s/step - accuracy: 0.9065 - loss: 0.2979 - val_accuracy: 0.7418 - val_loss: 0.7767
Epoch 36/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 334ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 106s 1s/step - accuracy: 0.9149 - loss: 0.3005 - val_accuracy: 0.7455 - val_loss: 0.7790
Epoch 37/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 431ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 136s 1s/step - accuracy: 0.9074 - loss: 0.2953 - val_accuracy: 0.7364 - val_loss: 0.7779
Epoch 38/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 320ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.9103 - loss: 0.3055 - val_accuracy: 0.7509 - val_loss: 0.7785
Epoch 39/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 9s 491ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 105s 1s/step - accuracy: 0.9185 - loss: 0.2735 - val_accuracy: 0.7418 - val_loss: 0.7780
Epoch 40/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 422ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 137s 1s/step - accuracy: 0.9141 - loss: 0.2989 - val_accuracy: 0.7364 - val_loss: 0.7859
Epoch 41/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 9s 506ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 102s 1s/step - accuracy: 0.9229 - loss: 0.2660 - val_accuracy: 0.7382 - val_loss: 0.7951
Epoch 42/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 429ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 100s 1s/step - accuracy: 0.9228 - loss: 0.2747 - val_accuracy: 0.7491 - val_loss: 0.7982
Epoch 43/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 422ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 100s 1s/step - accuracy: 0.9129 - loss: 0.2866 - val_accuracy: 0.7345 - val_loss: 0.7970
Epoch 44/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 321ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 145s 1s/step - accuracy: 0.9183 - loss: 0.2825 - val_accuracy: 0.7382 - val_loss: 0.7967
Epoch 45/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 321ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 100s 1s/step - accuracy: 0.9215 - loss: 0.2690 - val_accuracy: 0.7436 - val_loss: 0.7990
Epoch 46/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 317ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 102s 1s/step - accuracy: 0.9264 - loss: 0.2446 - val_accuracy: 0.7364 - val_loss: 0.7988
Epoch 47/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 7s 395ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.9116 - loss: 0.2727 - val_accuracy: 0.7436 - val_loss: 0.7939
Epoch 48/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 317ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 140s 1s/step - accuracy: 0.9241 - loss: 0.2640 - val_accuracy: 0.7509 - val_loss: 0.7995
Epoch 49/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 10s 537ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.9366 - loss: 0.2348 - val_accuracy: 0.7545 - val_loss: 0.7973
Epoch 50/50
18/18 ━━━━━━━━━━━━━━━━━━━━ 7s 396ms/step
81/81 ━━━━━━━━━━━━━━━━━━━━ 140s 1s/step - accuracy: 0.9305 - loss: 0.2520 - val_accuracy: 0.7455 - val_loss: 0.7954
[Figure: Accuracy vs Epochs]
[Figure: Sensitivity, Specificity, and F1-Score vs Epochs]
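The callback derives per-class sensitivity and specificity from the confusion matrix. A small hand-checkable sketch confirms the formulas; the 3-class matrix below is made up for illustration, not data from this notebook:

import numpy as np

# Hypothetical 3-class confusion matrix (rows = true class, columns = predicted)
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [0, 2, 8]])

tp = np.diag(cm)                # [8, 6, 8]
fp = cm.sum(axis=0) - tp        # [2, 3, 3]
fn = cm.sum(axis=1) - tp        # [2, 4, 2]
tn = cm.sum() - (tp + fp + fn)  # [18, 17, 17]

sensitivity = tp / (tp + fn)    # [0.80, 0.60, 0.80] -> macro average 0.733
specificity = tn / (tn + fp)    # [0.90, 0.85, 0.85] -> macro average 0.867

print(sensitivity.mean(), specificity.mean())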
from google.colab import files

files.download('cnn acc vs epoch.png')
files.download('cnn spec f1 vs epoch.png')
model.save('64x3-CNN.keras')
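The saved model can later be restored without retraining; a minimal sketch using the same path as above:

# Reload the saved model for later inference
reloaded = tf.keras.models.load_model('64x3-CNN.keras')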
# Evaluate the final model on the held-out test set
test_loss, test_accuracy = model.evaluate(test_batches)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")
18/18 ━━━━━━━━━━━━━━━━━━━━ 6s 329ms/step - accuracy: 0.7009 - loss: 0.8493
Test Accuracy: 73.82%
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import numpy as np

y_test = test_batches.labels

# Binarize the labels (one column per class) for one-vs-rest ROC curves
y_test_bin = label_binarize(y_test, classes=[0, 1, 2, 3, 4])
n_classes = y_test_bin.shape[1]  # number of classes

# Class probabilities for the whole test set
y_pred_probs = model.predict(test_batches)

# Compute the ROC curve and AUC for each class
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_bin[:, i], y_pred_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot all ROC curves
plt.figure()
colors = ['blue', 'green', 'red', 'cyan', 'magenta']
for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], color=colors[i], lw=2, label=f'Class {i} (AUC = {roc_auc[i]:.2f})')

# Add the random-guess diagonal
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic for Multi-Class')
plt.legend(loc='lower right')
plt.savefig('cnn roc curve.png')
plt.show()
/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 435ms/step
[Figure: Receiver Operating Characteristic for Multi-Class]
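A single summary number can complement the five per-class curves; a sketch using scikit-learn's macro-averaged AUC, which was not computed in the original run:

from sklearn.metrics import roc_auc_score

# Macro-averaged one-vs-rest AUC over the five binarized classes
macro_auc = roc_auc_score(y_test_bin, y_pred_probs, average='macro')
print(f"Macro-average AUC: {macro_auc:.3f}")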
from google.colab import files

files.download('cnn roc curve.png')
# Generate the classification report and confusion matrix on the test set
from sklearn.metrics import classification_report

y_pred = np.argmax(y_pred_probs, axis=1)  # predicted class labels
print(classification_report(y_test, y_pred, target_names=test_batches.class_indices.keys()))

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=test_batches.class_indices.keys())
disp.plot(cmap="Blues")
plt.savefig('cnn confusion matrix.png')
plt.show()
                precision    recall  f1-score   support

          Mild       0.47      0.50      0.49        56
      Moderate       0.60      0.73      0.66       150
         No_DR       0.89      0.97      0.93       271
Proliferate_DR       0.33      0.02      0.04        44
        Severe       0.60      0.21      0.31        29

      accuracy                           0.74       550
     macro avg       0.58      0.49      0.49       550
  weighted avg       0.71      0.74      0.71       550
[Figure: Confusion matrix of the five-class CNN on the test set]
from google.colab import files

files.download('cnn confusion matrix.png')
# Evaluate training and validation accuracy while iteratively halving the
# training set, using the binary DR / No_DR labels
from math import ceil, pow

src_dir = r'../root/.cache/kagglehub/datasets/sovitrath/diabetic-retinopathy-224x224-2019-data/versions/4/colored_images'

for i in range(5):
    total_files = len(train)
    denom = pow(2, i)
    print(denom)
    files_to_copy = ceil(total_files / denom)  # this iteration's training-set size, rounded up
    print(files_to_copy)

    # Rebuild the training directory from scratch for this iteration
    if os.path.exists(train_dir):
        shutil.rmtree(train_dir)
    os.makedirs(train_dir, exist_ok=True)

    # Counter to track the number of files copied
    copied_files = 0
    for index, row in train.iterrows():
        if copied_files >= files_to_copy:
            break  # stop once this iteration's quota is reached
        diagnosis = row['type']
        binary_diagnosis = row['binary_type']
        id_code = row['id_code'] + ".png"
        srcfile = os.path.join(src_dir, diagnosis, id_code)
        dstfile = os.path.join(train_dir, binary_diagnosis)
        os.makedirs(dstfile, exist_ok=True)
        if os.path.exists(srcfile):
            shutil.copy(srcfile, dstfile)
            copied_files += 1  # increment the counter

    # Validation and test sets stay at full size, organized by binary label
    for index, row in val.iterrows():
        diagnosis = row['type']
        binary_diagnosis = row['binary_type']
        id_code = row['id_code'] + ".png"
        srcfile = os.path.join(src_dir, diagnosis, id_code)
        dstfile = os.path.join(val_dir, binary_diagnosis)
        os.makedirs(dstfile, exist_ok=True)
        if os.path.exists(srcfile):
            shutil.copy(srcfile, dstfile)

    for index, row in test.iterrows():
        diagnosis = row['type']
        binary_diagnosis = row['binary_type']
        id_code = row['id_code'] + ".png"
        srcfile = os.path.join(src_dir, diagnosis, id_code)
        dstfile = os.path.join(test_dir, binary_diagnosis)
        os.makedirs(dstfile, exist_ok=True)
        if os.path.exists(srcfile):
            shutil.copy(srcfile, dstfile)

    for subdir in [train_dir, val_dir, test_dir]:
        print(f"\nContents of {subdir}:")
        for root, dirs, files in os.walk(subdir):
            print(f"{root}: {len(files)} files")

    train_path = train_dir
    val_path = val_dir
    test_path = test_dir

    train_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(train_path, target_size=(224, 224), shuffle=True)
    val_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(val_path, target_size=(224, 224), shuffle=True)
    test_batches = ImageDataGenerator(rescale=1./255).flow_from_directory(test_path, target_size=(224, 224), shuffle=False)

    # Same CNN as above, but with a two-way softmax head for DR / No_DR
    model = tf.keras.Sequential([
        layers.Conv2D(8, (3, 3), padding="valid", input_shape=(224, 224, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.BatchNormalization(),
        layers.Conv2D(16, (3, 3), padding="valid", activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.BatchNormalization(),
        layers.Conv2D(32, (4, 4), padding="valid", activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.BatchNormalization(),
        layers.Flatten(),
        layers.Dense(32, activation='relu'),
        layers.Dropout(0.15),
        layers.Dense(2, activation='softmax')
    ])

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(train_batches, epochs=10, validation_data=val_batches)
1.0
2562

Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/No_DR: 1263 files
.\dataset/train/DR: 1299 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/No_DR: 271 files
.\dataset/val/DR: 279 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/No_DR: 271 files
.\dataset/test/DR: 279 files

Found 2562 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead. super().__init__(activity_regularizer=activity_regularizer, **kwargs)
Epoch 1/10
/usr/local/lib/python3.10/dist-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:122: UserWarning: Your `PyDataset` class should call `super().__init__(**kwargs)` in its constructor. `**kwargs` can include `workers`, `use_multiprocessing`, `max_queue_size`. Do not pass these arguments to `fit()`, as they will be ignored. self._warn_if_super_not_called()
81/81 ━━━━━━━━━━━━━━━━━━━━ 95s 1s/step - accuracy: 0.7844 - loss: 0.4764 - val_accuracy: 0.5073 - val_loss: 0.8788
Epoch 2/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 138s 1s/step - accuracy: 0.9100 - loss: 0.2328 - val_accuracy: 0.5073 - val_loss: 1.0356
Epoch 3/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 87s 1s/step - accuracy: 0.9173 - loss: 0.2120 - val_accuracy: 0.5073 - val_loss: 1.1552
Epoch 4/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 144s 1s/step - accuracy: 0.9285 - loss: 0.1944 - val_accuracy: 0.5109 - val_loss: 0.9722
Epoch 5/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 87s 1s/step - accuracy: 0.9311 - loss: 0.1797 - val_accuracy: 0.5491 - val_loss: 0.6548
Epoch 6/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.9346 - loss: 0.1593 - val_accuracy: 0.7582 - val_loss: 0.4198
Epoch 7/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 86s 1s/step - accuracy: 0.9482 - loss: 0.1526 - val_accuracy: 0.9164 - val_loss: 0.2454
Epoch 8/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 89s 1s/step - accuracy: 0.9395 - loss: 0.1521 - val_accuracy: 0.9236 - val_loss: 0.1954
Epoch 9/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 87s 1s/step - accuracy: 0.9404 - loss: 0.1574 - val_accuracy: 0.9309 - val_loss: 0.1698
Epoch 10/10
81/81 ━━━━━━━━━━━━━━━━━━━━ 143s 1s/step - accuracy: 0.9452 - loss: 0.1475 - val_accuracy: 0.9327 - val_loss: 0.1669

2.0
1281

Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/No_DR: 664 files
.\dataset/train/DR: 617 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/No_DR: 271 files
.\dataset/val/DR: 279 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/No_DR: 271 files
.\dataset/test/DR: 279 files

Found 1281 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Found 550 images belonging to 2 classes.

Epoch 1/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 52s 1s/step - accuracy: 0.7730 - loss: 0.4892 - val_accuracy: 0.5073 - val_loss: 0.6800
Epoch 2/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 47s 1s/step - accuracy: 0.8863 - loss: 0.2912 - val_accuracy: 0.5073 - val_loss: 0.6994
Epoch 3/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 51s 1s/step - accuracy: 0.8942 - loss: 0.2338 - val_accuracy: 0.5073 - val_loss: 0.7084
Epoch 4/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 79s 1s/step - accuracy: 0.9099 - loss: 0.2347 - val_accuracy: 0.5073 - val_loss: 0.7110
Epoch 5/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 47s 1s/step - accuracy: 0.9302 - loss: 0.2254 - val_accuracy: 0.5127 - val_loss: 0.6887
Epoch 6/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 52s 1s/step - accuracy: 0.9352 - loss: 0.1886 - val_accuracy: 0.5218 - val_loss: 0.6459
Epoch 7/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 79s 1s/step - accuracy: 0.9431 - loss: 0.1787 - val_accuracy: 0.5327 - val_loss: 0.6414
Epoch 8/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 80s 1s/step - accuracy: 0.9385 - loss: 0.1534 - val_accuracy: 0.5982 - val_loss: 0.5673
Epoch 9/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 82s 1s/step - accuracy: 0.9442 - loss: 0.1668 - val_accuracy: 0.7182 - val_loss: 0.4592
Epoch 10/10
41/41 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - accuracy: 0.9510 - loss: 0.1555 - val_accuracy: 0.7745 - val_loss: 0.4030

4.0
641

Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/No_DR: 335 files
.\dataset/train/DR: 306 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/No_DR: 271 files
.\dataset/val/DR: 279 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/No_DR: 271 files
.\dataset/test/DR: 279 files

Found 641 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Epoch 1/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 30s 1s/step - accuracy: 0.6456 - loss: 0.6968 - val_accuracy: 0.6873 - val_loss: 0.6828
Epoch 2/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.8668 - loss: 0.3267 - val_accuracy: 0.7309 - val_loss: 0.6757
Epoch 3/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.8835 - loss: 0.2830 - val_accuracy: 0.5982 - val_loss: 0.6689
Epoch 4/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.8973 - loss: 0.2757 - val_accuracy: 0.5145 - val_loss: 0.6677
Epoch 5/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.9076 - loss: 0.2304 - val_accuracy: 0.5145 - val_loss: 0.6660
Epoch 6/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.9293 - loss: 0.1915 - val_accuracy: 0.5145 - val_loss: 0.6673
Epoch 7/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.9245 - loss: 0.1830 - val_accuracy: 0.5127 - val_loss: 0.6822
Epoch 8/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.9402 - loss: 0.1654 - val_accuracy: 0.5436 - val_loss: 0.6704
Epoch 9/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 27s 1s/step - accuracy: 0.9411 - loss: 0.1661 - val_accuracy: 0.5455 - val_loss: 0.6825
Epoch 10/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 30s 1s/step - accuracy: 0.9319 - loss: 0.1554 - val_accuracy: 0.5618 - val_loss: 0.6682

8.0
321

Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/No_DR: 168 files
.\dataset/train/DR: 153 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/No_DR: 271 files
.\dataset/val/DR: 279 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/No_DR: 271 files
.\dataset/test/DR: 279 files

Found 321 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Found 550 images belonging to 2 classes.

Epoch 1/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 19s 1s/step - accuracy: 0.6672 - loss: 0.6967 - val_accuracy: 0.5073 - val_loss: 0.6918
Epoch 2/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 20s 2s/step - accuracy: 0.8493 - loss: 0.3656 - val_accuracy: 0.5073 - val_loss: 0.6956
Epoch 3/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 16s 1s/step - accuracy: 0.8947 - loss: 0.3088 - val_accuracy: 0.5073 - val_loss: 0.7002
Epoch 4/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 16s 1s/step - accuracy: 0.8497 - loss: 0.3160 - val_accuracy: 0.5073 - val_loss: 0.7053
Epoch 5/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 21s 1s/step - accuracy: 0.8942 - loss: 0.2618 - val_accuracy: 0.5073 - val_loss: 0.7103
Epoch 6/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 16s 1s/step - accuracy: 0.9385 - loss: 0.2016 - val_accuracy: 0.5073 - val_loss: 0.7169
Epoch 7/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 21s 2s/step - accuracy: 0.9371 - loss: 0.1914 - val_accuracy: 0.5073 - val_loss: 0.7235
Epoch 8/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 17s 1s/step - accuracy: 0.9385 - loss: 0.1765 - val_accuracy: 0.5073 - val_loss: 0.7315
Epoch 9/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 16s 1s/step - accuracy: 0.9635 - loss: 0.1654 - val_accuracy: 0.5073 - val_loss: 0.7362
Epoch 10/10
11/11 ━━━━━━━━━━━━━━━━━━━━ 16s 1s/step - accuracy: 0.9161 - loss: 0.1862 - val_accuracy: 0.5073 - val_loss: 0.7401

16.0
161

Contents of .\dataset/train:
.\dataset/train: 0 files
.\dataset/train/No_DR: 89 files
.\dataset/train/DR: 72 files

Contents of .\dataset/val:
.\dataset/val: 0 files
.\dataset/val/No_DR: 271 files
.\dataset/val/DR: 279 files

Contents of .\dataset/test:
.\dataset/test: 0 files
.\dataset/test/No_DR: 271 files
.\dataset/test/DR: 279 files

Found 161 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Found 550 images belonging to 2 classes.
Epoch 1/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 15s 2s/step - accuracy: 0.6488 - loss: 0.7023 - val_accuracy: 0.5691 - val_loss: 0.6915
Epoch 2/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 20s 2s/step - accuracy: 0.7992 - loss: 0.4822 - val_accuracy: 0.5364 - val_loss: 0.6905
Epoch 3/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 20s 2s/step - accuracy: 0.8545 - loss: 0.3240 - val_accuracy: 0.5073 - val_loss: 0.6910
Epoch 4/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 12s 2s/step - accuracy: 0.8747 - loss: 0.3692 - val_accuracy: 0.5073 - val_loss: 0.6940
Epoch 5/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 20s 2s/step - accuracy: 0.9120 - loss: 0.2714 - val_accuracy: 0.5073 - val_loss: 0.6983
Epoch 6/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 21s 2s/step - accuracy: 0.9072 - loss: 0.2942 - val_accuracy: 0.5073 - val_loss: 0.7036
Epoch 7/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 12s 2s/step - accuracy: 0.8756 - loss: 0.2688 - val_accuracy: 0.5073 - val_loss: 0.7078
Epoch 8/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 12s 2s/step - accuracy: 0.9226 - loss: 0.2576 - val_accuracy: 0.5073 - val_loss: 0.7089
Epoch 9/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 20s 2s/step - accuracy: 0.9151 - loss: 0.2522 - val_accuracy: 0.5073 - val_loss: 0.7091
Epoch 10/10
6/6 ━━━━━━━━━━━━━━━━━━━━ 12s 2s/step - accuracy: 0.9447 - loss: 0.2030 - val_accuracy: 0.5073 - val_loss: 0.7103
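The plotting cell below hard-codes the final-epoch numbers. A sketch of how they could be collected programmatically instead, assuming the loop above had appended each history to a hypothetical histories list (it does not in the original):

# Hypothetical alternative to hand-copying the results: assumes the training
# loop above ran `histories.append(history)` once per iteration
database_size = [2562, 1281, 641, 321, 161]
cnn_train_acc = [h.history['accuracy'][-1] * 100 for h in histories]
cnn_val_acc = [h.history['val_accuracy'][-1] * 100 for h in histories]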
import matplotlib.pyplot as plt
import numpy as np

# Final-epoch accuracies (in %) from the halving runs above, together with
# accuracies for a ViT model that is not trained in this notebook
database_size = [2562, 1281, 641, 321, 161]
cnn_train_acc = [94.52, 95.10, 93.19, 91.61, 94.47]
cnn_val_acc = [93.27, 77.45, 56.18, 50.73, 50.73]
vit_train_acc = [91.67, 88.21, 70.20, 54.34, 62.67]
vit_val_acc = [92.73, 94, 69.27, 63.09, 63.09]

# Plot
figure, axis = plt.subplots(1, 2, figsize=(10, 5))

# Shared y-ticks in increments of 5
y_ticks = np.arange(50, 101, 5)

# CNN plot
axis[0].plot(database_size, cnn_train_acc, label='CNN Train Accuracy')
axis[0].plot(database_size, cnn_val_acc, label='CNN Validation Accuracy')
axis[0].set_xlabel('Database Size')
axis[0].set_ylabel('Accuracy')
axis[0].set_title('CNN Accuracy vs Database Size')
#axis[0].set_xticks(database_size)  # set the x-ticks explicitly
axis[0].set_yticks(y_ticks)  # consistent y-ticks
axis[0].invert_xaxis()  # reverse the x-axis so the dataset shrinks left to right
axis[0].legend()

# ViT plot
axis[1].plot(database_size, vit_train_acc, label='ViT Train Accuracy')
axis[1].plot(database_size, vit_val_acc, label='ViT Validation Accuracy')
axis[1].set_xlabel('Database Size')
axis[1].set_ylabel('Accuracy')
axis[1].set_title('ViT Accuracy vs Database Size')
#axis[1].set_xticks(database_size)  # set the x-ticks explicitly
axis[1].set_yticks(y_ticks)  # consistent y-ticks
axis[1].invert_xaxis()  # reverse the x-axis
axis[1].legend()

plt.tight_layout()  # apply the layout before saving so the file matches the display
plt.savefig('cnn vs vit.png')
plt.show()
[Figure: CNN Accuracy vs Database Size (left) and ViT Accuracy vs Database Size (right)]
from google.colab import files

files.download('cnn vs vit.png')