Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
labmlai
GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/hypernetworks/experiment.py
4918 views
1
import torch
2
import torch.nn as nn
3
from labml import experiment
4
from labml.configs import option
5
from labml.utils.pytorch import get_modules
6
7
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
8
from labml_nn.hypernetworks.hyper_lstm import HyperLSTM
9
from labml_nn.lstm import LSTM
10
11
12
class AutoregressiveModel(nn.Module):
    """
    ## Auto regressive model

    Wraps a recurrent core (`rnn_model`) with a token embedding on the input
    side and a linear projection to vocabulary logits on the output side.
    """

    def __init__(self, n_vocab: int, d_model: int, rnn_model: nn.Module):
        super().__init__()
        # Token embedding module
        self.src_embed = nn.Embedding(n_vocab, d_model)
        # The recurrent core (e.g. `HyperLSTM` or `LSTM`)
        self.lstm = rnn_model
        # Projects hidden states to logits over the vocabulary
        self.generator = nn.Linear(d_model, n_vocab)

    def forward(self, x: torch.Tensor):
        # Embed the tokens
        embeddings = self.src_embed(x)
        # Run the embeddings through the recurrent model
        output, state = self.lstm(embeddings)
        # Generate logits of the next token, and pass the recurrent state through
        return self.generator(output), state
32
class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    The default configs can and will be over-ridden when we start the experiment
    """

    # The full auto-regressive model (built by the `autoregressive_model` option)
    model: AutoregressiveModel
    # The recurrent core; selected via the `hyper_lstm` / `lstm` options
    rnn_model: nn.Module

    # Embedding size and recurrent hidden size
    d_model: int = 512
    # Forwarded to `HyperLSTM` (third constructor argument)
    n_rhn: int = 16
    # Forwarded to `HyperLSTM` (fourth constructor argument)
    n_z: int = 16
@option(Configs.model)
def autoregressive_model(c: Configs):
    """
    Initialize the auto-regressive model
    """
    # Build the model from the configured vocabulary size, embedding size,
    # and the selected recurrent core, then move it to the configured device.
    model = AutoregressiveModel(c.n_tokens, c.d_model, c.rnn_model)
    return model.to(c.device)
@option(Configs.rnn_model)
def hyper_lstm(c: Configs):
    """
    Create a single-layer `HyperLSTM` to use as the recurrent core.
    """
    # Input size and hidden size are both `d_model`
    rnn = HyperLSTM(c.d_model, c.d_model, c.n_rhn, c.n_z, 1)
    return rnn
@option(Configs.rnn_model)
def lstm(c: Configs):
    """
    Create a single-layer plain `LSTM` to use as the recurrent core.
    """
    # Input size and hidden size are both `d_model`
    rnn = LSTM(c.d_model, c.d_model, 1)
    return rnn
def main():
    """
    Run the HyperLSTM character-level language-modeling experiment.
    """
    # Create experiment
    experiment.create(name="hyper_lstm", comment='')
    # Create configs
    conf = Configs()
    # A dictionary of configurations to override
    overrides = {
        'tokenizer': 'character',
        'text': 'tiny_shakespeare',
        'optimizer.learning_rate': 2.5e-4,
        'optimizer.optimizer': 'Adam',
        'prompt': 'It is',
        'prompt_separator': '',

        # Select the HyperLSTM as the recurrent core
        'rnn_model': 'hyper_lstm',

        'train_loader': 'shuffled_train_loader',
        'valid_loader': 'shuffled_valid_loader',

        'seq_len': 512,
        'epochs': 128,
        'batch_size': 2,
        'inner_iterations': 25,
    }
    # Load configurations
    experiment.configs(conf, overrides)

    # Set models for saving and loading
    experiment.add_pytorch_models(get_modules(conf))

    # Start the experiment
    with experiment.start():
        # `TrainValidConfigs.run`
        conf.run()


if __name__ == '__main__':
    main()