Path: blob/master/labml_nn/neox/utils/finetune.py
from typing import List, Dict

import torch
from torch import nn

from labml_nn.neox.model import TransformerLayer, NeoXModule


class FineTuner:
    def __init__(self, layers: List[NeoXModule]):
        self.layers = layers

    def get_trainable_params(self) -> Dict[str, nn.Parameter]:
        # Collect the trainable parameters of every layer, keyed by a
        # `layer_XX.`-prefixed name so they can be saved and restored.
        params = {}
        for i, layer in enumerate(self.layers):
            params.update(self.get_layer_trainable_params(layer, prefix=f'layer_{i:02d}'))

        return params

    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        # Subclasses decide which parameters of a layer are trainable
        raise NotImplementedError

    def set_trainable_params(self):
        for layer in self.layers:
            # Set `requires_grad` to `False` for the entire layer.
            layer.requires_grad_(False)
        # Then enable gradients only for the selected trainable parameters
        for p in self.get_trainable_params().values():
            p.requires_grad_(True)

    def state_dict(self):
        # Only the trainable parameters need to be checkpointed
        return {n: p.data.cpu() for n, p in self.get_trainable_params().items()}

    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
        params = self.get_trainable_params()
        for n, p in params.items():
            p.data[:] = state_dict[n].to(p.data.device)

        # Make sure the checkpoint does not contain unexpected parameters
        for n in state_dict.keys():
            assert n in params, n


class FineTuneBiases(FineTuner):
    def get_layer_trainable_params(self, layer: NeoXModule, prefix: str) -> Dict[str, nn.Parameter]:
        params = {}

        if isinstance(layer, TransformerLayer):
            # No need to train the MLP output bias because it is added
            # together with the attention output.
            params[f'{prefix}.attention.output.bias'] = layer.attention.output.bias
            params[f'{prefix}.attention.qkv_lin.bias'] = layer.attention.qkv_lin.bias
            params[f'{prefix}.ffn.dense_h_h4.bias'] = layer.ffn.dense_h_h4.bias
        else:
            # Other module types have no trainable parameters here
            pass

        return params
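For context, here is a minimal sketch of how these classes might be wired into a training loop. It is not part of the original file: `build_bias_only_optimizer` is a hypothetical helper, the optimizer choice and learning rate are illustrative, and `layers` is assumed to be a list of `NeoXModule`s that has already been loaded elsewhere (for example with the model-loading utilities in `labml_nn.neox`).

import torch

from labml_nn.neox.utils.finetune import FineTuneBiases


def build_bias_only_optimizer(layers, lr: float = 1e-4):
    # Select only the bias parameters of the transformer layers
    fine_tuner = FineTuneBiases(layers)
    # Freeze everything, then re-enable gradients for the selected biases
    fine_tuner.set_trainable_params()
    # Optimize only the trainable (bias) parameters
    optimizer = torch.optim.Adam(fine_tuner.get_trainable_params().values(), lr=lr)
    return fine_tuner, optimizer


# After training, only the fine-tuned biases need to be checkpointed:
#
#     fine_tuner, optimizer = build_bias_only_optimizer(layers)
#     ... training steps ...
#     torch.save(fine_tuner.state_dict(), 'biases.pt')
#
# and they can be restored later with:
#
#     fine_tuner.load_state_dict(torch.load('biases.pt'))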