Path: blob/master/notebooks/book1/19/finetune_cnn_jax.ipynb
A PyTorch implementation of this notebook can be found here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/19/finetune_cnn_torch.ipynb
Author of the notebook: Susnato Dhar (GitHub: https://github.com/susnato)
This notebook is a JAX-compatible version of the main notebook, which can be found here.
All credit goes to the author of the main notebook; I only converted it to JAX.
I used this repository to implement the pre-trained version of ResNet18 in order to fine-tune it.
I used the HotDog vs. No HotDog dataset from this link.
In [1]:
Out[1]:
Building wheel for augmax (setup.py) ... done
Building wheel for jax-resnet (setup.py) ... done
In [2]:
Out[2]:
--2022-04-10 05:04:39-- http://d2l-data.s3-accelerate.amazonaws.com/hotdog.zip
Resolving d2l-data.s3-accelerate.amazonaws.com (d2l-data.s3-accelerate.amazonaws.com)... 108.156.127.60
Connecting to d2l-data.s3-accelerate.amazonaws.com (d2l-data.s3-accelerate.amazonaws.com)|108.156.127.60|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 261292301 (249M) [application/zip]
Saving to: ‘hotdog.zip’
hotdog.zip 100%[===================>] 249.19M 65.2MB/s in 4.1s
2022-04-10 05:04:43 (60.4 MB/s) - ‘hotdog.zip’ saved [261292301/261292301]
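Before loading, it can help to check how many images each split and class contains. A minimal sketch using only the standard library, assuming the archive unpacks into `hotdog/<split>/<class>/` folders of image files (this layout is an assumption; `count_images` is a hypothetical helper, not part of the notebook):

```python
import os
from collections import Counter

IMAGE_EXTS = (".png", ".jpg", ".jpeg")

def count_images(root):
    """Count image files per (split, class) under root/<split>/<class>/.

    Assumes a two-level folder layout such as hotdog/train/<class>/*.png.
    Non-image files and stray files at the split level are ignored.
    """
    counts = Counter()
    for split in sorted(os.listdir(root)):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            continue
        for cls in sorted(os.listdir(split_dir)):
            cls_dir = os.path.join(split_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            counts[(split, cls)] = sum(
                1 for f in os.listdir(cls_dir) if f.lower().endswith(IMAGE_EXTS)
            )
    return counts
```

Usage, after extracting with `zipfile.ZipFile("hotdog.zip").extractall(".")`, would be `count_images("hotdog")`, which returns a `Counter` keyed by `(split, class)`.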
Imports
In [3]:
In [4]:
Load Data
In [5]:
In [6]:
Let's view some images. Because the images are normalized, we first need to convert them back to the range 0 to 1 before we can display them.
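Undoing the normalization means inverting `x_norm = (x - mean) / std` per channel and clipping to [0, 1]. A minimal NumPy sketch, assuming the ImageNet channel statistics commonly used with torchvision's pre-trained ResNet18 and channels-last `(H, W, C)` images; the notebook's actual constants are not shown here, and `denormalize` is a hypothetical name:

```python
import numpy as np

# Assumption: ImageNet mean/std, the usual choice for torchvision-pretrained ResNet18.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def denormalize(img, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert per-channel normalization of an (H, W, C) image and clip to [0, 1]."""
    img = np.asarray(img, dtype=np.float32)
    # Broadcasting applies the length-3 mean/std along the last (channel) axis.
    return np.clip(img * std + mean, 0.0, 1.0)
```

The result can be passed directly to `matplotlib.pyplot.imshow`, which expects float images in [0, 1].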
In [7]:
Out[7]:
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py:371: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation with independent=True.
warnings.warn(msg, FutureWarning)
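The FutureWarning above comes from `jax.random.shuffle`. A minimal sketch of the replacement it suggests: shuffle an index vector with `jax.random.permutation` and use it to slice both images and labels, so the pairs stay aligned. `iterate_batches` is a hypothetical helper, not the notebook's own code:

```python
import jax
import numpy as np

def iterate_batches(key, x, y, batch_size, drop_last=True):
    """Yield shuffled (x, y) mini-batches.

    jax.random.permutation(key, n) returns a random permutation of arange(n);
    converting it to NumPy lets the indices slice NumPy or JAX arrays.
    """
    n = x.shape[0]
    perm = np.asarray(jax.random.permutation(key, n))
    # Optionally drop the final partial batch so every batch has the same shape.
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]
```

Passing a fresh key each epoch (for example from `jax.random.split`) gives a different ordering per epoch. Keeping batch shapes fixed with `drop_last=True` also avoids recompiling a jitted training step for a smaller last batch.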
Model
In [8]:
In [9]:
In [30]:
Out[30]:
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
In [11]:
In [12]:
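A common fine-tuning choice is to keep the pre-trained backbone fixed and train only the new classification head. Whether this notebook freezes any layers is not visible here. The sketch below shows one way to do it in plain JAX with `jax.lax.stop_gradient`, which treats its (pytree) input as a constant under `jax.grad`. The tiny two-layer `forward` model and its parameter names are hypothetical stand-ins for ResNet18:

```python
import jax
import jax.numpy as jnp

def forward(params, x, freeze_backbone=True):
    """Toy model: one 'backbone' layer followed by a linear 'head' (hypothetical)."""
    backbone = params["backbone"]
    if freeze_backbone:
        # Gradients do not flow into anything wrapped by stop_gradient.
        backbone = jax.lax.stop_gradient(backbone)
    h = jax.nn.relu(x @ backbone["w"])
    return h @ params["head"]["w"] + params["head"]["b"]

def loss_fn(params, x, y, freeze_backbone=True):
    """Mean softmax cross-entropy for integer labels y."""
    log_probs = jax.nn.log_softmax(forward(params, x, freeze_backbone))
    return -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))
```

With `freeze_backbone=True`, the gradient for every backbone parameter is exactly zero, so an SGD update leaves it unchanged. If Optax is used instead, `optax.multi_transform` combined with `optax.set_to_zero()` for the frozen parameters is a common alternative that achieves the same effect at the optimizer level.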
Training
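The training log reports a loss and an accuracy for each iteration. A minimal sketch of a jitted JAX training step that returns both, assuming softmax cross-entropy and plain SGD. The notebook's actual loss, optimizer, and learning rate are not shown here, and the linear model on pre-extracted features is a hypothetical simplification:

```python
import jax
import jax.numpy as jnp

def cross_entropy_and_accuracy(params, x, y):
    """Return (mean cross-entropy, accuracy) for a linear classifier."""
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.take_along_axis(log_probs, y[:, None], axis=1))
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == y)
    return loss, acc

@jax.jit
def train_step(params, x, y, lr=1e-2):
    """One SGD step. Loss and accuracy are measured before the update."""
    (loss, acc), grads = jax.value_and_grad(
        cross_entropy_and_accuracy, has_aux=True)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, acc
```

`has_aux=True` lets the accuracy travel alongside the loss without being differentiated.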
In [47]:
Out[47]:
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py:371: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation with independent=True.
warnings.warn(msg, FutureWarning)
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Epoch : 1/10 Iteration : 31/31 Loss : 0.6226798295974731 Accuracy : 0.8125
Validation Results : Epoch : 1 Validation Loss : 0.6446382999420166 Validation Accuracy : 0.72265625
Epoch : 2/10 Iteration : 31/31 Loss : 0.5625612735748291 Accuracy : 0.828125
Validation Results : Epoch : 2 Validation Loss : 0.5996769666671753 Validation Accuracy : 0.7786458730697632
Epoch : 3/10 Iteration : 31/31 Loss : 0.5067282915115356 Accuracy : 0.859375
Validation Results : Epoch : 3 Validation Loss : 0.557037353515625 Validation Accuracy : 0.7994791865348816
Epoch : 4/10 Iteration : 31/31 Loss : 0.4547579288482666 Accuracy : 0.921875
Validation Results : Epoch : 4 Validation Loss : 0.5183985829353333 Validation Accuracy : 0.8111979365348816
Epoch : 5/10 Iteration : 31/31 Loss : 0.40734273195266724 Accuracy : 0.9375
Validation Results : Epoch : 5 Validation Loss : 0.48585638403892517 Validation Accuracy : 0.8216146230697632
Epoch : 6/10 Iteration : 31/31 Loss : 0.3637435734272003 Accuracy : 0.953125
Validation Results : Epoch : 6 Validation Loss : 0.45888015627861023 Validation Accuracy : 0.828125
Epoch : 7/10 Iteration : 31/31 Loss : 0.32245326042175293 Accuracy : 0.953125
Validation Results : Epoch : 7 Validation Loss : 0.43768686056137085 Validation Accuracy : 0.8268229365348816
Epoch : 8/10 Iteration : 31/31 Loss : 0.2837451696395874 Accuracy : 0.96875
Validation Results : Epoch : 8 Validation Loss : 0.4221525192260742 Validation Accuracy : 0.8268229365348816
Epoch : 9/10 Iteration : 31/31 Loss : 0.2477167248725891 Accuracy : 0.96875
Validation Results : Epoch : 9 Validation Loss : 0.41047853231430054 Validation Accuracy : 0.828125
Epoch : 10/10 Iteration : 31/31 Loss : 0.21442091464996338 Accuracy : 0.96875
Validation Results : Epoch : 10 Validation Loss : 0.4030081629753113 Validation Accuracy : 0.8307291865348816
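The validation results above are per-epoch summaries. How the notebook combines its validation batches is not shown here. One option, sketched below under that caveat, is to weight each batch's mean loss and accuracy by its size, so the result equals the mean over all examples even when the last batch is smaller. `weighted_epoch_metrics` is a hypothetical helper:

```python
import numpy as np

def weighted_epoch_metrics(batch_losses, batch_accs, batch_sizes):
    """Combine per-batch means into epoch-level loss and accuracy.

    Each batch contributes in proportion to its number of examples.
    """
    sizes = np.asarray(batch_sizes, dtype=np.float64)
    total = sizes.sum()
    loss = float(np.dot(batch_losses, sizes) / total)
    acc = float(np.dot(batch_accs, sizes) / total)
    return loss, acc
```

If every batch has the same size, this reduces to a simple average of the per-batch values.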
Testing the Model
- 0 = No HotDog
- 1 = HotDog
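Given this label mapping, the predicted class is the argmax over the model's two output logits. A minimal sketch; `predict_labels` and `LABEL_NAMES` are hypothetical names, and the mapping simply restates the list above:

```python
import jax.numpy as jnp

# Label mapping as listed in this notebook.
LABEL_NAMES = {0: "No HotDog", 1: "HotDog"}

def predict_labels(logits):
    """Return (class index, class name) for each row of a (batch, 2) logits array."""
    idx = jnp.argmax(logits, axis=-1)
    return [(int(i), LABEL_NAMES[int(i)]) for i in idx]
```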
In [51]:
Out[51]:
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py:371: FutureWarning: jax.random.shuffle is deprecated and will be removed in a future release. Use jax.random.permutation with independent=True.
warnings.warn(msg, FutureWarning)
In [52]:
Out[52]:
True Label : 0
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Prediction : 0
In [53]:
Out[53]:
True Label : 1
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0
Prediction : 1
In [16]: