Path: blob/master/labml_nn/rwkv/experiment.py
import inspect
import math

import torch
import torch.nn as nn
from labml_nn.rwkv.configs import RWKVConfigs

from labml_nn.rwkv import RWKV
from labml_nn.rwkv import TimeMixing
from labml import experiment
from labml.configs import option
from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs


class Configs(NLPAutoRegressionConfigs):
    """
    ## Configurations

    This inherits from
    [`NLPAutoRegressionConfigs`](../../experiments/nlp_autoregression.html#NLPAutoRegressionConfigs)
    """

    # RWKV model
    model: RWKV
    # RWKV configurations
    rwkv: RWKVConfigs
    # Number of warmup iterations
    warmup_iters: int = 2000
    # Total number of training iterations
    max_iters: int = 600000
    # Weight decay
    weight_decay: float = 1e-1
    # Custom optimizer betas
    beta1: float = 0.9
    beta2: float = 0.95
    # Use the custom optimizer defined below
    optimizer = 'rwkv_optimizer'


@option(Configs.rwkv, 'RWKV')
def _rwkv_configs(c: Configs):
    """
    ### RWKV configurations
    """

    # We use our
    # [configurable RWKV implementation](../configs.html#RWKVConfigs)
    conf = RWKVConfigs()
    # Set the vocabulary sizes for embeddings and generating logits
    conf.n_src_vocab = c.n_tokens
    conf.n_tgt_vocab = c.n_tokens

    return conf


def _init_weights(module, rwkv: RWKVConfigs):
    # Initialize the vector parameters in `TimeMixing` layers
    if isinstance(module, TimeMixing):
        layer_id = module.layer_id
        n_layer = module.n_layer
        n_embd = module.n_embd
        attn_sz = n_embd

        with torch.no_grad():
            # Per-layer interpolation ratios
            ratio_0_to_1 = layer_id / (n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 to ~0
            # Per-channel ramp from 0 to ~1
            ddd = torch.ones(1, 1, n_embd)
            for i in range(n_embd):
                ddd[0, 0, i] = i / n_embd

            # Per-channel decay speeds from -5 to 3
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            module.time_decay = nn.Parameter(decay_speed)

            zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
            module.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            module.time_mix_key = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
            module.time_mix_value = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            module.time_mix_receptance = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
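

# The `_init_weights` recipe above spreads the per-channel decay speeds between
# $-5$ and $3$, using an exponent that grows with layer depth. The helper below is
# an illustrative sketch only (the name `_decay_speed_example` is made up here and
# is not part of the experiment); it recomputes the same curve for a small channel
# count so the effect of `ratio_0_to_1` can be inspected directly.
def _decay_speed_example(layer_id: int = 0, n_layer: int = 12, attn_sz: int = 8):
    # Relative depth of this layer: 0 for the first layer, 1 for the last
    ratio_0_to_1 = layer_id / (n_layer - 1)
    decay_speed = torch.ones(attn_sz)
    for h in range(attn_sz):
        # Same formula as `_init_weights`: -5 at the first channel, 3 at the last,
        # with a sharper curve in deeper layers
        decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
    return decay_speed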


@option(Configs.model)
def _model(c: Configs):
    """
    Create RWKV model and initialize weights
    """
    m = RWKV(c.rwkv).to(c.device)

    # Apply the custom weight initialization.
    # `nn.Module.apply` expects a callable that takes a single module,
    # so wrap `_init_weights` to pass the RWKV configurations along.
    m.apply(lambda module: _init_weights(module, c.rwkv))

    return m


@option(NLPAutoRegressionConfigs.optimizer, 'rwkv_optimizer')
def _configure_optimizers(c: Configs):
    # Start with all of the candidate parameters
    param_dict = {pn: p for pn, p in c.model.named_parameters()}
    # Filter out those that do not require grad
    param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
    # Create optimizer groups. Any parameter that is 2D or higher is weight-decayed,
    # the rest are not; i.e. all weight tensors in matmuls and embeddings decay,
    # while biases and layer-norm parameters don't.
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': c.weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0}
    ]
    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
    # Create the AdamW optimizer, using the fused version if it is available
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and c.device_type == 'cuda'
    extra_args = dict(fused=True) if use_fused else dict()
    optimizer = torch.optim.AdamW(optim_groups, lr=c.learning_rate, betas=(c.beta1, c.beta2), **extra_args)
    print(f"using fused AdamW: {use_fused}")

    return optimizer


def main():
    # Create experiment
    experiment.create(name="RWKV")
    # Create configs
    conf = Configs()
    print(conf.model)
    # Override configurations
    experiment.configs(conf, {
        # Use character level tokenizer
        'tokenizer': 'character',
        # Prompt separator is blank
        'prompt_separator': '',
        # Starting prompt for sampling
        'prompt': 'It is ',
        # Use Tiny Shakespeare dataset
        'text': 'tiny_shakespeare',

        # Use a context size of $128$
        'seq_len': 128,
        # Train for $32$ epochs
        'epochs': 32,
        # Batch size $128$
        'batch_size': 128,
        # Switch between training and validation $10$ times per epoch
        'inner_iterations': 10,

        # Model size
        'rwkv.block_size': 1024,
        'rwkv.n_layer': 12,
        'rwkv.n_heads': 12,
        'rwkv.n_embd': 768
    })

    print(conf.model)
    # Set models for saving and loading
    experiment.add_pytorch_models({'model': conf.model})

    # Start the experiment
    with experiment.start():
        # Run training
        conf.run()


#
if __name__ == '__main__':
    main()
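

# The grouping rule in `_configure_optimizers` depends only on tensor dimensionality:
# parameters with two or more dimensions (matmul weights, embeddings) get weight decay,
# while 1D parameters (biases, layer-norm scales) do not. The sketch below is
# illustrative only (the tiny model and the name `_grouping_example` are made up here,
# not part of the experiment) and applies the same split to a minimal module.
def _grouping_example():
    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
    params = [p for p in model.parameters() if p.requires_grad]
    # The linear weight is 2D and gets decayed; the linear bias and the
    # layer-norm weight and bias are 1D and do not
    decay = [p for p in params if p.dim() >= 2]
    no_decay = [p for p in params if p.dim() < 2]
    return torch.optim.AdamW([
        {'params': decay, 'weight_decay': 0.1},
        {'params': no_decay, 'weight_decay': 0.0},
    ], lr=3e-4, betas=(0.9, 0.95))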