GitHub Repository: labmlai/annotated_deep_learning_paper_implementations
Path: blob/master/labml_nn/neox/utils/finetune.py

from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule

class FineTuner:
    """
    Base class for fine-tuning helpers that train only a subset of the model's parameters.
    """

    def __init__(self, layers: List[NeoXModule]):
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        """
        Collect the trainable parameters from all layers, keyed by a unique name.
        """
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i :02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        """
        Return the trainable parameters of a single layer; implemented by subclasses.
        """
        raise NotImplementedError

    def set_trainable_params(self):
        """
        Freeze the entire model, then re-enable gradients only for the trainable parameters.
        """
        for layer in self.layers:
            # Set `requires_grad` to `False` for the entire layer.
            layer.requires_grad_(False)

        # Re-enable gradients for the parameters selected for fine-tuning.
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        """
        Return a state dict containing only the trainable parameters (copied to CPU).
        """
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        """
        Load fine-tuned parameter values in place, and check that every entry in the
        saved state dict matches a trainable parameter.
        """
        params = self.get_trainable_params()
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        for n in state_dict.keys():
            assert n in params, n


class FineTuneBiases(FineTuner):
    """
    Fine-tune only the bias parameters of the transformer layers.
    """

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # No need to train the final MLP projection bias: it is added to the residual
            # together with the attention output bias, so `attention.output.bias` covers it.
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Leave other (non-transformer) layers frozen.
            pass

        return params
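
A minimal usage sketch follows (not part of the original module). It assumes `layers` is a list of the model's `NeoXModule` layers loaded elsewhere, for example with the checkpoint-loading utilities in `labml_nn.neox.utils`; the optimizer choice, learning rate, and file name are illustrative only.

from torch.optim import Adam

# `layers`: list of NeoXModule layers, loaded elsewhere (assumed here).
fine_tuner = FineTuneBiases(layers)
# Freeze everything, then re-enable gradients only for the selected biases.
fine_tuner.set_trainable_params()
# Optimize only the trainable bias parameters.
optimizer = Adam(fine_tuner.get_trainable_params().values(), lr=1e-4)

# ... training loop ...

# Save only the fine-tuned biases, and load them back into a freshly built model later.
torch.save(fine_tuner.state_dict(), 'fine_tuned_biases.pt')
fine_tuner.load_state_dict(torch.load('fine_tuned_biases.pt'))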