Path: blob/master/labml_nn/neox/samples/finetune.py
"""1---2title: Fine Tune GPT-NeoX3summary: >4Fine tune GPT-NeoX biases with Fairscale pipeline parallel module5---67# Fine Tune GPT-NeoX89This shows how to fine tune GPT-NeoX with pipeline parallelism.10"""1112import fairscale13import torch14import torch.nn as nn15import torch.utils.data16import torch.utils.data17import typing18from torch.utils.data import DataLoader, RandomSampler1920from labml import experiment, monit, tracker, lab21from labml.configs import option22from labml.logger import inspect23from labml_nn.neox.utils.text_dataset import get_training_data24from labml_nn.neox.utils.finetune import FineTuneBiases25from labml_nn.neox.model import LayerGenerator, NeoXModule26from labml_nn.neox.utils import balance_layers_simple27from labml_nn.neox.utils.trainer import PipelineParallelTrainerConf282930@option(PipelineParallelTrainerConf.layers, 'PipelineBiases')31def neox_layers(c: PipelineParallelTrainerConf):32"""33### Load GPT-NeoX layers34"""35return list(LayerGenerator(is_clone_layers=c.is_clone_layers,36filter_layers=c.filter_layers,37dtype=c.dtype,38).load())394041@option(PipelineParallelTrainerConf.fine_tuner, 'PipelineBiases')42def fine_tune_biases(c: PipelineParallelTrainerConf):43"""44### Create fine tuner for biases45"""4647fine_tuner = FineTuneBiases(typing.cast(typing.List[NeoXModule], c.layers))48# Mark biases as trainable49fine_tuner.set_trainable_params()5051#52return fine_tuner535455@option(PipelineParallelTrainerConf.model, 'PipelineBiases')56def pipe_model(c: PipelineParallelTrainerConf):57"""58### Create pipeline parallel model59"""6061if c.is_checkpointing:62raise NotImplementedError()63else:64layers = c.layers6566# Make sure the finetuner is initialized67_ = c.fine_tuner6869# Create the Pipe module70with monit.section('Pipe'):71# Get the layer distribution across GPUs72balance = balance_layers_simple(len(layers), c.n_gpus)73inspect(balance=balance)74# Devices for each GPU75devices = [torch.device(f'cuda:{i}') for i in range(c.n_gpus)]76# Create Fairscale Pipe module77pipe_model = fairscale.nn.Pipe(nn.Sequential(*layers),78balance=balance,79devices=devices,80chunks=c.chunks)8182#83return pipe_model848586@option(PipelineParallelTrainerConf.train_loader)87def tiny_shakespeare(c: PipelineParallelTrainerConf):88"""89#### Tiny Shakespeare dataset90"""91dataset = get_training_data(c.max_seq_len)9293return DataLoader(dataset,94batch_size=c.batch_size,95sampler=RandomSampler(dataset, replacement=True))969798def main():99# Create experiment100experiment.create(name='pipe_neox_biases',101writers={'screen', 'web_api'})102103# Initialize configs104conf = PipelineParallelTrainerConf()105experiment.configs(conf, {106'learning_rate': 3e-4,107'is_checkpointing': False,108'max_seq_len': 128,109'batch_size': 64,110'chunks': 8,111})112113# Start the experiment114with experiment.start():115# Initialize the model. Do this before the loop for cleaner logs.116_ = conf.model117118# Train119for epoch in monit.loop(conf.epochs):120conf.train_epoch()121tracker.new_line()122torch.save(conf.fine_tuner.state_dict(), str(lab.get_data_path() / 'fine_tune.pt'))123124125#126if __name__ == '__main__':127main()128129130