GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/llm/rlhf/dpo_train.py
"""Relies on Liger Kernel, FlashAttention 2, and bfloat16;
requires an A100 GPU and torch > 2.3."""
import os
import torch
import torch.distributed as dist
from datasets import load_dataset
from transformers import AutoTokenizer
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from peft import LoraConfig
from trl import DPOTrainer, DPOConfig


def create_preference_triplets(example):
    """
    Create preference triplets:

    - `prompt`: prompt that is given to a model for text generation.
    - `chosen`: preferred generated response for the corresponding prompt.
    - `rejected`: response that is not preferred.
    """
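    # Each "chosen"/"rejected" entry is a full conversation (a list of role/content
    # messages); keep only the final assistant reply as the preference target.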
    chosen = extract_assistant_messages(example["chosen"], index=-1)
    rejected = extract_assistant_messages(example["rejected"], index=-1)

    return {
        "prompt": example["prompt"],
        "chosen": chosen,
        "rejected": rejected
    }


def extract_assistant_messages(messages, index=-1):
    """Recursively extract the last assistant message from the end of the conversation."""
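    # Walk backwards from the end of the message list until an assistant turn is found.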
    if messages[index]["role"] == "assistant":
        return messages[index]["content"]
    else:
        return extract_assistant_messages(messages, index - 1)


if __name__ == "__main__":
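    # Load the pre-binarized UltraFeedback preference dataset; each row holds a
    # prompt together with a chosen and a rejected conversation.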
    dataset = load_dataset(
        "argilla/ultrafeedback-binarized-preferences-cleaned",
        split="train",
        verification_mode="no_checks",
        cache_dir="/data"
    )
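    # Hold out a small evaluation split and map every row to a (prompt, chosen,
    # rejected) string triplet using multiple worker processes.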
    dataset_dict = dataset.train_test_split(test_size=0.01, seed=54321)
    dataset_dict_preprocessed = dataset_dict.map(
        create_preference_triplets,
        num_proc=8
    )

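    # Pad on the left (decoder-only model) and reuse the EOS token for padding.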
    model_name_or_path = "Qwen/Qwen2.5-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir="/data")
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

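    # Load the base model with Liger's fused Triton kernels, bfloat16 weights,
    # and FlashAttention 2 to cut memory usage and speed up training.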
    model = AutoLigerKernelForCausalLM.from_pretrained(
        model_name_or_path,
        cache_dir="/data",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    model.train()
    print(model)

    hf_save_tmp_dir = "dpo_model"
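    # beta controls how strongly the DPO loss penalizes drifting away from the
    # reference policy; max_length/max_prompt_length truncate long pairs, and
    # gradient checkpointing trades extra compute for lower memory.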
    training_args = DPOConfig(
        output_dir=hf_save_tmp_dir,
        bf16=True,
        gradient_accumulation_steps=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_steps=5000,
        logging_steps=50,
        learning_rate=0.0001,
        beta=0.1,
        max_length=1024,
        max_prompt_length=512,
        remove_unused_columns=False,
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
    )
    print(training_args)
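    # LoRA adapters (rank 16) on every attention and MLP projection; the token
    # embeddings and lm_head are trained in full via modules_to_save.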
    peft_config = LoraConfig(
        r=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "down_proj",
            "up_proj",
            "gate_proj"
        ],
        modules_to_save=[
            "embed_tokens",
            "lm_head"
        ]
    )
    print(peft_config)
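    # With a peft_config and no explicit ref_model, DPOTrainer uses the base
    # model with adapters disabled as the implicit reference policy.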
    dpo_trainer = DPOTrainer(
        model,
        train_dataset=dataset_dict_preprocessed["train"],
        eval_dataset=dataset_dict_preprocessed["test"],
        tokenizer=tokenizer,
        args=training_args,
        peft_config=peft_config,
    )
    dpo_trainer.train()
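    # Checkpoints are written to `output_dir` every `save_steps` steps. A typical
    # multi-GPU launch (assuming an accelerate config has been set up) would be:
    #   accelerate launch dpo_train.py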