"Guiding Future STEM Leaders through Innovative Research Training" ~ thinkingbeyond.education

Kernel: Python 3

LSTM-based Sentiment Analysis on IMDB Dataset

This notebook implements a sentiment analysis model using LSTM (Long Short-Term Memory) neural networks to classify IMDB movie reviews as either positive or negative.

Key components:

  • PyTorch for deep learning implementation

  • Hugging Face's datasets library for loading the IMDB dataset

  • DistilBERT tokenizer for text preprocessing

  • LSTM architecture for sequence processing

The model will learn to classify movie reviews as positive (1) or negative (0) based on the text content.

1. Import Required Libraries

We need several Python libraries:

  • torch: Main PyTorch library for deep learning

  • transformers: For using pre-trained tokenizers

  • datasets: For loading the IMDB dataset

  • matplotlib: For visualizing training progress

  • numpy: For numerical operations

!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers
!pip install datasets
!pip install scikit-learn
!pip install ipywidgets
!pip install tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
Install output (condensed): torch 2.5.1+cu121, torchvision 0.20.1+cu121, torchaudio 2.5.1+cu121, transformers 4.47.1, scikit-learn 1.6.0, ipywidgets 7.7.1, and tqdm 4.67.1 were already installed under Python 3.10. datasets 3.2.0 was newly installed together with dill 0.3.8, fsspec 2024.9.0, multiprocess 0.70.16, and xxhash 3.5.0; jedi 0.19.2 was installed as an IPython dependency.

pip reported one dependency conflict: gcsfs 2024.10.0 requires fsspec==2024.10.0, but installing datasets downgraded fsspec to 2024.9.0.

2. Custom Dataset Class

The SentimentDataset class handles data preprocessing:

  1. Accepts raw text data and tokenizer

  2. Converts text to token IDs using the tokenizer

  3. Handles padding and truncation to ensure fixed length

  4. Creates attention masks for valid tokens

  5. Converts labels to tensor format

Key parameters:

  • max_length: Maximum sequence length (default: 128)

  • padding: Set to 'max_length' to ensure uniform sizes

  • truncation: True to handle reviews longer than max_length

class SentimentDataset(Dataset):
    """Tokenizes IMDB reviews on the fly and returns fixed-length tensors."""

    def __init__(self, dataset_split, tokenizer, max_length=128):
        self.data = dataset_split
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        # Pad or truncate every review to exactly max_length tokens.
        encoding = self.tokenizer(
            example['text'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            # squeeze(0) drops the batch dimension added by return_tensors='pt'.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            # Float labels, as expected by BCEWithLogitsLoss.
            'labels': torch.tensor(example['label'], dtype=torch.float)
        }
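The effect of padding='max_length' and truncation=True can be illustrated without downloading a tokenizer. The sketch below is a simplified stand-in, not the Hugging Face implementation: pad_or_truncate is a hypothetical helper, the token IDs are made up for illustration, and a real BERT-style tokenizer additionally keeps its closing special token when it truncates.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), each exactly max_length long.

    Hypothetical helper for illustration only; pad_id=0 is an assumption.
    """
    ids = list(token_ids[:max_length])      # truncation: drop tokens past max_length
    mask = [1] * len(ids)                   # 1 marks a real token
    n_pad = max_length - len(ids)
    # Right-padding: pad tokens and mask zeros are appended at the end.
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


short_ids, short_mask = pad_or_truncate([101, 2023, 102], max_length=5)
long_ids, long_mask = pad_or_truncate([101, 7, 8, 9, 10, 11, 102], max_length=5)

print(short_ids, short_mask)  # [101, 2023, 102, 0, 0] [1, 1, 1, 0, 0]
print(long_ids, long_mask)    # [101, 7, 8, 9, 10] [1, 1, 1, 1, 1]
```

Every example ends up with the same length, which is what lets the DataLoader stack them into a single batch tensor.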

3. LSTM Model Architecture

The LSTMClassifier implements a neural network with:

  1. Embedding Layer: Converts token IDs to dense vectors

    • Input: Vocabulary size

    • Output: 300-dimensional embeddings

  2. LSTM Layer: Processes the sequence

    • Input: 300-dimensional vectors

    • Hidden size: 256 dimensions

    • 2 stacked LSTM layers

    • Includes dropout for regularization

  3. Output Layer: Final classification

    • Linear layer converting to single score

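    • No explicit sigmoid: the model returns a raw logit, and BCEWithLogitsLoss applies the sigmoid internally

The attention mask is used twice: it zeroes the embeddings of padding tokens, and it locates the last real token of each review so that the classifier reads the LSTM output there rather than at a padding position.

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=256, n_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM applies dropout only between stacked layers, so disable it for a single layer.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, batch_first=True,
                            dropout=dropout if n_layers > 1 else 0)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask):
        embedded = self.embedding(input_ids)               # (batch, seq_len, embedding_dim)
        # Zero the embeddings of padding positions.
        embedded = embedded * attention_mask.unsqueeze(-1)
        lstm_out, _ = self.lstm(embedded)                  # (batch, seq_len, hidden_dim)
        # Read the top-layer output at the last real token of each review.
        # lstm_out[:, -1, :] would read a padding position for short reviews,
        # because the tokenizer pads on the right.
        last_idx = (attention_mask.sum(dim=1) - 1).clamp(min=0)
        batch_idx = torch.arange(lstm_out.size(0), device=lstm_out.device)
        final_hidden_state = lstm_out[batch_idx, last_idx]  # (batch, hidden_dim)
        # Apply dropout and the classification layer.
        output = self.dropout(final_hidden_state)
        return self.fc(output)                             # raw logit, shape (batch, 1)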
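For right-padded batches, the position of the last real token follows directly from the attention mask. The sketch below uses plain Python lists in place of tensors; last_valid_index and the string "outputs" are hypothetical and only illustrate the indexing, not the LSTM itself.

```python
def last_valid_index(attention_mask):
    """Index of the last real token in a right-padded sequence.

    Hypothetical helper: assumes 1s (real tokens) come before 0s (padding).
    """
    return max(sum(attention_mask) - 1, 0)


mask = [1, 1, 1, 0, 0]                        # three real tokens, two pads
outputs = ["h0", "h1", "h2", "h3_pad", "h4_pad"]  # stand-ins for per-step LSTM outputs

print(outputs[-1])                      # h4_pad  (last time step is padding)
print(outputs[last_valid_index(mask)])  # h2      (last real token)
```

With a maximum length of 128, any review shorter than 128 tokens has padding at the final time step, so the two choices differ for those reviews.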

4. Training and Evaluation Functions

The training loop includes:

  1. Per Epoch:

    • Training phase with gradient updates

    • Validation phase without gradients

    • Loss computation and optimization

    • Progress tracking and metrics calculation

  2. Key Features:

    • Gradient clipping (max norm: 1.0)

    • Checkpointing of the best model (lowest validation loss so far)

    • Training and validation loss tracking

    • Accuracy monitoring

  3. Hyperparameters:

    • Learning rate: 2e-5

    • Batch size: 32

    • Number of epochs: 5

    • Loss function: BCEWithLogitsLoss
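BCEWithLogitsLoss combines a sigmoid with binary cross-entropy, so the model can output raw logits. The sketch below reproduces the per-example arithmetic in plain Python using the usual numerically stable rearrangement; the function names are hypothetical, and it is meant as an illustration of the formula rather than PyTorch's actual code. It also shows that a probability threshold of 0.5 corresponds to a logit threshold of 0.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_logits(z, y):
    """Binary cross-entropy on a raw logit z with target y in {0, 1}.

    Stable form: max(z, 0) - z*y + log(1 + exp(-|z|)).
    """
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def bce_from_prob(p, y):
    """Textbook form on a probability p; breaks down when p reaches 0 or 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


z, y = 2.0, 1.0
print(f"{bce_with_logits(z, y):.4f}")         # 0.1269
print(f"{bce_from_prob(sigmoid(z), y):.4f}")  # 0.1269

# A logit of 0.3 is a positive prediction (probability above 0.5),
# even though the logit itself is below 0.5.
logit = 0.3
print(sigmoid(logit) > 0.5, logit > 0.0, logit > 0.5)  # True True False
```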

def train_model(model, train_loader, valid_loader, criterion, optimizer, device, num_epochs=5):
    best_valid_loss = float('inf')
    train_losses, valid_losses = [], []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            # squeeze(-1) keeps the batch dimension even for a batch of one.
            logits = model(input_ids, attention_mask).squeeze(-1)
            loss = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / len(train_loader)
        train_losses.append(train_loss)

        # Validation phase (no gradient tracking)
        model.eval()
        valid_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in valid_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                logits = model(input_ids, attention_mask).squeeze(-1)
                loss = criterion(logits, labels)
                valid_loss += loss.item()

                # The model outputs logits, so a probability of 0.5 corresponds to a logit of 0.
                predictions = (logits > 0).float()
                correct += (predictions == labels).sum().item()
                total += labels.size(0)

        valid_loss = valid_loss / len(valid_loader)
        valid_losses.append(valid_loss)
        accuracy = correct / total

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss:.4f}')
        print(f'Validation Loss: {valid_loss:.4f}')
        print(f'Validation Accuracy: {accuracy:.4f}')

        # Keep the weights with the lowest validation loss seen so far.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best_model.pt')

    return train_losses, valid_losses
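Gradient clipping with a maximum norm of 1.0 rescales all gradients by a common factor whenever their combined L2 norm exceeds the limit, and leaves them unchanged otherwise. The following is a minimal sketch of that arithmetic on plain lists; clip_by_global_norm is a hypothetical helper, and the small eps in the denominator is an assumption modelled on how such routines usually avoid division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Never scale up: the factor is capped at 1.0.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]      # combined norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

print(norm_before)             # 5.0
print(round(norm_after, 4))    # 1.0
```

Because every component is multiplied by the same factor, the direction of the update is preserved and only its magnitude shrinks.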

5. Main Training Pipeline

The training process follows these steps:

  1. Setup:

    • GPU/CPU device selection

    • Dataset loading and preprocessing

    • Model initialization

  2. Training:

    • Batched processing of reviews

    • Forward and backward passes

    • Model parameter updates

  3. Monitoring:

    • Loss tracking for both training and validation

    • Learning curves visualization

  4. Results:

    • Best model saved during training

    • Final performance metrics

    • Training progress plots
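The notebook imports precision, recall, and F1 from scikit-learn, and these could serve as the final performance metrics. As a reference for what those numbers mean, here is a small plain-Python sketch that computes them from binary labels and predictions; binary_metrics is a hypothetical helper and the sample labels are invented for illustration.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall, and F1 for 0/1 labels (positive class = 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Invented example: 2 true positives, 2 true negatives, 1 false positive, 1 false negative.
y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 1, 0]
for name, value in binary_metrics(y_true, y_pred).items():
    print(f"{name}: {value:.3f}")
```

In this example all four metrics come out to 2/3, since the single false positive and single false negative balance each other.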

def main():
    # Detailed CUDA diagnostics
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("CUDA version:", torch.version.cuda)

    if not torch.cuda.is_available():
        print("\nWARNING: CUDA not available. Checking system:")
        print("1. Check if NVIDIA GPU is present:")
        import subprocess
        try:
            nvidia_smi = subprocess.check_output(["nvidia-smi"]).decode('utf-8')
            print(nvidia_smi)
        except (OSError, subprocess.CalledProcessError):
            print("nvidia-smi command failed - GPU may not be present or drivers not installed")

    # Try to use the first CUDA device; fall back to CPU if that fails
    try:
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
        print(f"\nSuccessfully set device to: {device}")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    except Exception as e:
        print(f"\nFailed to set CUDA device: {e}")
        print("Falling back to CPU")
        device = torch.device("cpu")

    # Load dataset and tokenizer
    dataset = load_dataset('imdb')
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    # Create datasets (the IMDB test split serves as the validation set here)
    train_dataset = SentimentDataset(dataset['train'], tokenizer)
    valid_dataset = SentimentDataset(dataset['test'], tokenizer)

    # Create dataloaders
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size)

    # Initialize model
    model = LSTMClassifier(
        vocab_size=tokenizer.vocab_size,
        embedding_dim=300,
        hidden_dim=256,
        n_layers=2
    ).to(device)

    # Training parameters
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # Train model
    train_losses, valid_losses = train_model(
        model, train_loader, valid_loader, criterion, optimizer, device, num_epochs=5
    )

    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()
PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1

Successfully set device to: cuda:0
Using GPU: Tesla T4
GPU Memory: 15.84 GB
Download output (condensed): the IMDB dataset was fetched as three splits: train (25,000 examples), test (25,000 examples), and unsupervised (50,000 examples), followed by the distilbert-base-uncased tokenizer files. The run also emitted a warning that no HF_TOKEN secret was set; authentication is recommended but optional for public models and datasets.
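With 25,000 training reviews and a batch size of 32, the number of optimizer steps per epoch follows from simple arithmetic. The sketch below works it out; batches_per_epoch is a hypothetical helper, and its drop_last flag mirrors the DataLoader option of the same name, which the notebook leaves at its default of False.

```python
import math


def batches_per_epoch(n_examples, batch_size, drop_last=False):
    """Number of batches a loader yields in one pass over the data."""
    if drop_last:
        return n_examples // batch_size        # incomplete final batch is discarded
    return math.ceil(n_examples / batch_size)  # incomplete final batch is kept


n_train, batch_size, num_epochs = 25_000, 32, 5

steps = batches_per_epoch(n_train, batch_size)
print(steps)                 # 782
print(n_train % batch_size)  # 8  (size of the final, smaller batch)
print(steps * num_epochs)    # 3910 optimizer steps over 5 epochs
```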