Path: blob/master/labml_nn/hypernetworks/experiment.py
import torch
import torch.nn as nn
from labml import experiment
from labml.configs import option
from labml.utils.pytorch import get_modules

from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
from labml_nn.lstm import LSTM


class AutoregressiveModel(nn.Module):
    """
    ## Auto-regressive model
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # Recurrent model (either `HyperLSTM` or plain `LSTM`)
        self.lstm = rnn_model
        # Final layer that projects the RNN output to logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        x = self.src_embed(x)
        # Run the embeddings through the recurrent model
        res, state = self.lstm(x)
        # Generate logits of the next token
        return self.generator(res), state


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be overridden when we start the experiment.
    """

    model: AutoregressiveModel
    rnn_model: nn.Module

    # Embedding and hidden size
    d_model: int = 512
    # Hidden size of the smaller hyper-network LSTM
    n_rhn: int = 16
    # Size of the feature vectors z
    n_z: int = 16


@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    m = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
    return m.to(c.device)


@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    # Create a single-layer HyperLSTM
    return HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, 1)


@option(Configs.rnn_model)
def lstm(c: Configs):
    # Create a single-layer LSTM baseline
    return LSTM(c.d_model, c.d_model, 1)


def main():
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # Load configurations
    experiment.configs(conf,
                       # A dictionary of configurations to override
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'Adam',
                        'prompt': 'It is',
                        'prompt_separator': '',

                        'rnn_model': 'hyper_lstm',

                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',

                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 25})

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()
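
Since both `hyper_lstm` and `lstm` are registered with `@option(Configs.rnn_model)`, the same experiment can be pointed at the plain LSTM baseline by overriding `'rnn_model'`. Below is a minimal sketch of such a run, not part of the original file; it reuses the `experiment.configs` dictionary overrides shown in `main()`, and the experiment name `"hyper_lstm_baseline"` is just an illustrative choice.

# Sketch (assumption): run the baseline LSTM instead of the HyperLSTM
def main_lstm_baseline():
    # Hypothetical experiment name, chosen for illustration
    experiment.create(name="hyper_lstm_baseline", comment='plain LSTM baseline')
    conf = Configs()
    experiment.configs(conf,
                       {'tokenizer': 'character',
                        'text': 'tiny_shakespeare',
                        'optimizer.learning_rate': 2.5e-4,
                        'optimizer.optimizer': 'Adam',
                        # Select the `lstm` option instead of `hyper_lstm`
                        'rnn_model': 'lstm',
                        'train_loader': 'shuffled_train_loader',
                        'valid_loader': 'shuffled_valid_loader',
                        'seq_len': 512,
                        'epochs': 128,
                        'batch_size': 2,
                        'inner_iterations': 25})
    experiment.add_pytorch_models(get_modules(conf))
    with experiment.start():
        conf.run()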