Path: blob/master/deep_learning/llm/rlhf/dpo_train.py
"""Relies on liger kernel, flash attention 2, bfloat16,1requires A100 GPU and torch > 2.3"""2import os3import torch4import torch.distributed as dist5from datasets import load_dataset6from transformers import AutoTokenizer7from liger_kernel.transformers import AutoLigerKernelForCausalLM8from peft import LoraConfig9from trl import DPOTrainer, DPOConfig101112def create_preference_triplets(example):13"""14Create preference triplets:1516- `prompt`: prompt that is given to a model for text generation.17- `chosen`: preferred generated response for the corresponding prompt.18- `rejected`: response that is not preferred.19"""20chosen = extract_assistant_messages(example["chosen"], index=-1)21rejected = extract_assistant_messages(example["rejected"], index=-1)2223return {24"prompt": example["prompt"],25"chosen": chosen,26"rejected": rejected27}282930def extract_assistant_messages(messages, index=-1):31"""Recursively extract the last assistant messages from the end of the conversation."""32if messages[index]["role"] == "assistant":33return messages[index]["content"]34else:35extract_assistant_messages(messages, index - 1)363738if __name__ == "__main__":39dataset = load_dataset(40"argilla/ultrafeedback-binarized-preferences-cleaned",41split="train",42verification_mode="no_checks",43cache_dir="/data"44)45dataset_dict = dataset.train_test_split(test_size=0.01, seed=54321)46dataset_dict_preprocessed = dataset_dict.map(47create_preference_triplets,48num_proc=849)5051model_name_or_path = "Qwen/Qwen2.5-3B-Instruct"52tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir="/data")53tokenizer.padding_side = "left"54tokenizer.pad_token = tokenizer.eos_token5556model = AutoLigerKernelForCausalLM.from_pretrained(57model_name_or_path,58cache_dir="/data",59torch_dtype=torch.bfloat16,60attn_implementation="flash_attention_2",61)62model.train()63print(model)6465hf_save_tmp_dir = "dpo_model"66training_args = DPOConfig(67output_dir=hf_save_tmp_dir,68bf16=True,69gradient_accumulation_steps=1,70per_device_train_batch_size=2,71per_device_eval_batch_size=4,72gradient_checkpointing=True,73gradient_checkpointing_kwargs={"use_reentrant": False},74max_steps=5000,75logging_steps=50,76learning_rate=0.0001,77beta=0.1,78max_length=1024,79max_prompt_length=512,80remove_unused_columns=False,81eval_strategy="steps",82eval_steps=1000,83save_strategy="steps",84save_steps=1000,85)86print(training_args)87peft_config = LoraConfig(88r=16,89target_modules=[90"q_proj",91"k_proj",92"v_proj",93"o_proj",94"down_proj",95"up_proj",96"gate_proj"97],98modules_to_save=[99"embed_tokens",100"lm_head"101]102)103print(peft_config)104dpo_trainer = DPOTrainer(105model,106train_dataset=dataset_dict_preprocessed["train"],107eval_dataset=dataset_dict_preprocessed["test"],108tokenizer=tokenizer,109args=training_args,110peft_config=peft_config,111)112dpo_trainer.train()113114115