A PyTorch implementation of this notebook is available here: https://colab.research.google.com/github/probml/pyprobml/blob/master/notebooks/book1/15/cnn1d_sentiment_torch.ipynb
1d CNNs for sentiment classification
We use 1d CNNs to classify IMDB movie reviews. This notebook is based on Sec. 15.3 of http://d2l.ai/chapter_natural-language-processing-applications/sentiment-analysis-cnn.html
Data
We use the IMDB dataset. Details are in this colab.
We tokenize by words and, when creating the vocab, drop words that occur fewer than 5 times in the training set.
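A minimal pure-Python sketch of the vocab step, assuming lowercased whitespace tokenization and reserved `<unk>`/`<pad>` tokens at ids 0 and 1. The helper names `tokenize`, `build_vocab` and `encode` are hypothetical and not the notebook's own code. Any word seen fewer than 5 times in training maps to `<unk>`.

```python
from collections import Counter

def tokenize(lines):
    """Split each review into lowercase word tokens (simple whitespace split; an assumption)."""
    return [line.lower().split() for line in lines]

def build_vocab(tokenized, min_freq=5, reserved=("<unk>", "<pad>")):
    """Map every token seen at least `min_freq` times to an integer id.

    Reserved tokens come first (ids 0 and 1); kept words are ordered by
    descending frequency, ties broken alphabetically.
    """
    counts = Counter(tok for seq in tokenized for tok in seq)
    kept = sorted((tok for tok, c in counts.items() if c >= min_freq),
                  key=lambda t: (-counts[t], t))
    itos = list(reserved) + kept
    return {tok: i for i, tok in enumerate(itos)}

def encode(tokens, vocab):
    """Convert tokens to ids, sending out-of-vocab (rare) words to <unk>."""
    unk = vocab["<unk>"]
    return [vocab.get(t, unk) for t in tokens]

# Toy training set: "good"/"movie" appear 5 times, "bad"/"plot" only twice.
train = ["good movie"] * 5 + ["bad plot"] * 2
vocab = build_vocab(tokenize(train), min_freq=5)
ids = encode(["good", "bad", "movie"], vocab)
print(vocab)  # {'<unk>': 0, '<pad>': 1, 'good': 2, 'movie': 3}
print(ids)    # [2, 0, 3]
```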
We pad all sequences to length 500 for efficient minibatching.
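A fixed length of 500 lets every minibatch be a single rectangular integer array. The sketch below assumes reviews longer than 500 tokens are cut (in the style of d2l's `truncate_pad`) and that the pad id is 1, matching a vocab with `<unk>` at 0 and `<pad>` at 1. Both choices, and the helper names, are illustrative assumptions.

```python
import numpy as np

def truncate_pad(ids, num_steps=500, pad_id=1):
    """Cut `ids` to `num_steps` tokens, or right-pad it with `pad_id` (hypothetical helper)."""
    if len(ids) > num_steps:
        return ids[:num_steps]
    return ids + [pad_id] * (num_steps - len(ids))

def make_batch(sequences, num_steps=500, pad_id=1):
    """Stack variable-length id lists into one (batch, num_steps) int32 array."""
    return np.array([truncate_pad(s, num_steps, pad_id) for s in sequences],
                    dtype=np.int32)

# One short review and one review longer than 500 tokens.
batch = make_batch([[5, 6, 7], list(range(600))], num_steps=500, pad_id=1)
print(batch.shape)  # (2, 500)
```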
Putting it all together.
Model
We load pretrained GloVe vectors and use them to initialize both embedding layers.
We use 2 embedding layers, one with frozen weights and one with learnable weights, and feed their concatenation to the 1d CNN. We then apply average pooling over time before passing the result to the final MLP, which maps it to the 2 output logits.
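The forward pass described above can be sketched in NumPy so it runs without a JAX install (swapping `numpy` for `jax.numpy` gives the same computation). Several details are assumptions rather than the notebook's code: random matrices stand in for GloVe, the kernel sizes `(3, 4, 5)`, the ReLU after each convolution, and a single dense layer standing in for the final MLP. All function names are hypothetical.

```python
import numpy as np

def conv1d_valid(x, w, b):
    """'Valid' 1d convolution (cross-correlation, as in deep-learning libraries).

    x: (batch, time, in_ch); w: (kernel, in_ch, out_ch); b: (out_ch,)
    returns: (batch, time - kernel + 1, out_ch)
    """
    k = w.shape[0]
    steps = x.shape[1] - k + 1
    # windows[:, t, j, :] == x[:, t + j, :] for each kernel offset j
    windows = np.stack([x[:, j:j + steps, :] for j in range(k)], axis=2)
    return np.einsum("btki,kio->bto", windows, w) + b

def init_params(rng, vocab_size, embed_dim=100, kernel_sizes=(3, 4, 5), channels=100):
    # Random stand-in for the pretrained GloVe matrix (assumption).
    glove = rng.normal(size=(vocab_size, embed_dim)).astype(np.float32)
    params = {
        "frozen_embed": glove.copy(),   # kept fixed during training
        "learned_embed": glove.copy(),  # fine-tuned during training
        "convs": [
            (rng.normal(scale=0.1, size=(k, 2 * embed_dim, channels)).astype(np.float32),
             np.zeros(channels, np.float32))
            for k in kernel_sizes
        ],
    }
    feat = channels * len(kernel_sizes)
    params["dense_w"] = rng.normal(scale=0.1, size=(feat, 2)).astype(np.float32)
    params["dense_b"] = np.zeros(2, np.float32)
    return params

def forward(params, tokens):
    """tokens: (batch, time) int ids -> (batch, 2) logits."""
    # Look up both tables and concatenate along the feature axis: (B, T, 2D)
    x = np.concatenate([params["frozen_embed"][tokens],
                        params["learned_embed"][tokens]], axis=-1)
    pooled = []
    for w, b in params["convs"]:
        h = np.maximum(conv1d_valid(x, w, b), 0.0)  # ReLU after the convolution
        pooled.append(h.mean(axis=1))               # average pooling over time: (B, C)
    features = np.concatenate(pooled, axis=-1)      # (B, C * num_kernel_sizes)
    return features @ params["dense_w"] + params["dense_b"]

rng = np.random.default_rng(0)
params = init_params(rng, vocab_size=50, embed_dim=8, channels=4)
tokens = rng.integers(0, 50, size=(2, 500))
logits = forward(params, tokens)
print(logits.shape)  # (2, 2)
```

Freezing only matters during training. In a JAX training loop, one way to keep `frozen_embed` fixed is to wrap its lookup in `jax.lax.stop_gradient`, another is to mask its updates with `optax.multi_transform`; which mechanism the notebook itself uses is not shown in this section.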