GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/rwkv/experiment.py

import inspect
import math

import torch
import torch.nn as nn

from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
from labml_nn.rwkv import RWKV, TimeMixing
from labml_nn.rwkv.configs import RWKVConfigs


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from
    [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
    """

    # RWKV model
    model: RWKV
    # RWKV model configurations
    rwkv: RWKVConfigs
    # Number of warmup iterations
    warmup_iters: int = 2000
    # Total number of training iterations
    max_iters: int = 600000
    # Weight decay
    weight_decay: float = 1e-1
    # AdamW betas for the custom optimizer
    beta1: float = 0.9
    beta2: float = 0.95
    # Use the custom optimizer defined below
    optimizer = 'rwkv_optimizer'
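

# `warmup_iters` and `max_iters` are declared in `Configs` above but are not consumed
# anywhere in this file. The helper below is only an illustrative sketch (an assumption,
# not part of the original training loop) of the linear-warmup plus cosine-decay
# schedule such fields usually parameterize; `_lr_schedule_sketch` is a hypothetical
# name and is never called.
def _lr_schedule_sketch(it: int, max_lr: float, min_lr: float,
                        warmup_iters: int = 2000, max_iters: int = 600000) -> float:
    # Linear warmup from 0 to `max_lr` over the first `warmup_iters` steps
    if it < warmup_iters:
        return max_lr * it / warmup_iters
    # Past `max_iters`, hold the floor learning rate
    if it > max_iters:
        return min_lr
    # Cosine decay from `max_lr` down to `min_lr` in between
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)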


@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
    """
    ### RWKV configurations
    """
    # We use our
    # [configurable RWKV implementation](../configs.html#RWKVConfigs)
    conf = RWKVConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf
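

# The calculator above is registered under the name 'RWKV' for `Configs.rwkv`,
# so labml resolves it to build the RWKV configurations when they are needed,
# e.g. when `_model` below constructs the model.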


def _init_weights(module, rwkv: RWKVConfigs):
    # Initialize the vector parameters of `TimeMixing` blocks
    if isinstance(module, TimeMixing):
        layer_id = module.layer_id
        n_layer = module.n_layer
        n_embd = module.n_embd
        attn_sz = n_embd

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
            # Per-channel ramp 0, 1/n_embd, ..., (n_embd - 1)/n_embd
            ddd = torch.ones(1, 1, n_embd)
            for i in range(n_embd):
                ddd[0, 0, i] = i / n_embd

            # Decay speeds ramp from -5 (first channel) to 3 (last channel)
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            module.time_decay = nn.Parameter(decay_speed)

            # Repeating 0, +0.5, -0.5 offsets around log(0.3)
            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            # Token-mixing ratios for key, value and receptance
            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
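

# A small illustration (hypothetical helper, never called, not part of the original
# file) of the values `_init_weights` assigns, recomputed directly for a tiny
# configuration so the ranges are easy to see: `time_decay` ramps from -5 (first
# channel) to 3 (last channel), and the `time_mix_*` exponents depend on layer depth.
def _init_values_sketch(layer_id: int = 0, n_layer: int = 12, n_embd: int = 8):
    ratio_0_to_1 = layer_id / (n_layer - 1)
    ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)
    # Per-channel ramp 0, 1/n_embd, ..., (n_embd - 1)/n_embd, as in `_init_weights`
    ddd = torch.arange(n_embd, dtype=torch.float32) / n_embd
    # Same decay formula as `_init_weights`
    decay_speed = torch.tensor(
        [-5 + 8 * (h / (n_embd - 1)) ** (0.7 + 1.3 * ratio_0_to_1) for h in range(n_embd)])
    print('time_decay          ', decay_speed)
    print('time_mix_key        ', torch.pow(ddd, ratio_1_to_almost0))
    print('time_mix_value      ', torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
    print('time_mix_receptance ', torch.pow(ddd, 0.5 * ratio_1_to_almost0))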


@option(Configs.model)
def _model(c: Configs):
    """
    Create RWKV model and initialize weights
    """
    m = RWKV(c.rwkv).to(c.device)

    # Apply custom weight initialization.
    # `nn.Module.apply` expects a function of a single module, so wrap
    # `_init_weights` to also pass the RWKV configurations.
    m.apply(lambda module: _init_weights(module, c.rwkv))

    return m


@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: Configs):
    # Start with all of the candidate parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
    # Filter out those that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # Create optim groups. Any parameter that is 2D will be weight decayed, otherwise not;
    # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # Create an AdamW optimizer and use the fused version if it is available
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and c.device.type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer
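

# An illustration (hypothetical helper, never called, not part of the original file)
# of the grouping rule used by `_configure_optimizers`: parameters with 2 or more
# dimensions (matmul weights, embeddings) receive weight decay, 1-D parameters
# (biases, LayerNorm gains) do not.
def _optim_groups_sketch():
    # A tiny stand-in model: a Linear layer has a 2-D weight and a 1-D bias,
    # and a LayerNorm has only 1-D parameters
    tiny = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    decay = [n for n, p in tiny.named_parameters() if p.dim() >= 2]
    no_decay = [n for n, p in tiny.named_parameters() if p.dim() < 2]
    # Expected: ['0.weight'] decays; ['0.bias', '1.weight', '1.bias'] do not
    print('decay:', decay)
    print('no decay:', no_decay)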


def main():
    # Create experiment
    experiment.create(name="RWKV")
    # Create configs
    conf = Configs()
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of $128$
        'seq_len': 128,
        # Train for $32$ epochs
        'epochs': 32,
        # Batch size $128$
        'batch_size': 128,
        # Switch between training and validation for $10$ times
        # per epoch
        'inner_iterations': 10,

        # Model configurations; dotted keys like 'rwkv.n_layer'
        # set fields of the nested `RWKVConfigs`
        'rwkv.block_size': 1024,
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768
    })

    print(conf.model)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()
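
# Usage note (an assumption, not stated in this file): given the path
# `labml_nn/rwkv/experiment.py`, the experiment can presumably be launched as a
# module, e.g. `python -m labml_nn.rwkv.experiment`, which runs `main()` above.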