GitHub Repository: ethen8181/machine-learning
Path: blob/master/deep_learning/llm/rlhf/generate.py
"""
Script for generating answers/completions on the ultrafeedback dataset with
models or models + adapters.
"""
import os
import torch
import numpy as np
import pandas as pd
import transformers
import pytorch_lightning as pl
from peft import PeftModel
from datasets import load_dataset
from dataclasses import dataclass
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedTokenizerBase,
)
from typing import Any, Dict, List, Optional, Union
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy


@dataclass
class DataCollatorForGeneration(DataCollatorMixin):
    """
    Tokenizes and pads the raw prompt text while forming a batch for the data loader.
    """

    tokenizer: PreTrainedTokenizerBase
    max_seq_len: int = 512
    padding: Union[bool, str, PaddingStrategy] = True
    return_tensors: str = "pt"
    prompt_col_name: str = "prompt"

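    # decoder-only models append new tokens on the right, so prompts are padded on the
    # left; this keeps each prompt contiguous with its continuation and lets the caller
    # slice the completion off by prompt length. The eos token doubles as the pad token
    # in case the tokenizer does not define one.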
    def __post_init__(self):
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(
        self, features: List[Dict[str, Any]], return_tensors=None
    ) -> Dict[str, Any]:
        prompts = [feature[self.prompt_col_name] for feature in features]
        tokenized_text = self.tokenizer(
            prompts,
            padding=self.padding,
            max_length=self.max_seq_len,
            truncation=True,
            return_attention_mask=True,
            return_tensors=self.return_tensors,
        )

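        # the raw prompt strings are carried through the batch so they can later be
        # saved alongside the generated responses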
        batch = {
            "prompts": prompts,
            "input_ids": tokenized_text["input_ids"],
            "attention_mask": tokenized_text["attention_mask"],
        }
        return batch


class LLMGenerateLightningModule(pl.LightningModule):
    """
    Generate responses from an LLM. Expects input prompts, tokenized input_ids, and attention_mask.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        generation_config,
        prediction_config,
        adapter_path=None,
        cache_dir="/data",
    ):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, cache_dir=cache_dir
        )
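        # if a LoRA adapter is provided, merge its weights into the base model so
        # generation runs on a plain transformers model without the PEFT wrapper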
        if adapter_path:
            peft_model = PeftModel.from_pretrained(self.model, adapter_path)
            self.model = peft_model.merge_and_unload()

        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, padding_side="left", cache_dir=cache_dir
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.generation_config = generation_config
        self._setup_prediction(prediction_config)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        prompts = batch["prompts"]
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        responses = self.generate(input_ids, attention_mask)

        prediction_output = {
            "prompts": prompts,
            "responses": responses,
        }
        self.prediction_outputs.append(prediction_output)
        return prediction_output

    def generate(self, input_ids, attention_mask):
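        # model.generate returns the (left-padded) prompt tokens followed by the newly
        # generated tokens; decoding happens after the prompt portion is cropped below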
        model_output = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            generation_config=self.generation_config,
        )
        # crop input prompt from generated response
        input_seq_length = input_ids.shape[-1]
        model_output_answer_only = model_output[:, input_seq_length:]
        responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)
        return responses

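    # NOTE: the attributes below are only set when a prediction_config is passed in;
    # predict_step and the hooks further down assume that is the case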
    def _setup_prediction(self, prediction_config):
        if prediction_config:
            self.prediction_outputs = []
            self._prediction_partition_idx = 0
            self.prediction_partition_format = prediction_config["prediction_partition_format"]
            self.prediction_output_path = prediction_config["prediction_output_path"]
            self.prediction_accumulation_steps = prediction_config.get("prediction_accumulation_steps", 100)

    def _save_prediction_outputs(self):
        if self.prediction_output_path:
            data = {field: [] for field in self.prediction_outputs[0]}
            for prediction_output in self.prediction_outputs:
                for field in data:
                    data[field].extend(prediction_output[field])

            partition_file_name = self.prediction_partition_format.format(
                rank=self.global_rank, partition=self._prediction_partition_idx
            )
            formatted_output_path = os.path.join(
                self.prediction_output_path, partition_file_name
            )

            # saves the accumulated prediction batches locally as a pandas data frame
            df_prediction_outputs = pd.DataFrame.from_dict(data)
            os.makedirs(self.prediction_output_path, exist_ok=True)
            df_prediction_outputs.to_parquet(formatted_output_path, index=False)

            self._prediction_partition_idx += 1
            self.prediction_outputs.clear()

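    # flush to a new parquet partition every `prediction_accumulation_steps` batches to
    # keep memory bounded; the file name includes the process rank so parallel workers
    # do not overwrite each other's partitions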
    def on_predict_batch_end(self, outputs, batch, batch_idx):
        if len(self.prediction_outputs) == self.prediction_accumulation_steps:
            self._save_prediction_outputs()

    def on_predict_epoch_end(self):
        if len(self.prediction_outputs) > 0:
            self._save_prediction_outputs()


if __name__ == "__main__":
    pretrained_model_name_or_path = "Qwen/Qwen2.5-3B-Instruct"
    # generate responses from an instruct model, versus a model trained
    # on the ultrafeedback dataset using LoRA
    adapter_path = None
    prediction_output_path = "prediction_instruction_3B_model"
    # adapter_path = "dpo_model_v7"
    # prediction_output_path = "prediction_dpo_model_v7"

    dataset = load_dataset(
        "argilla/ultrafeedback-binarized-preferences-cleaned",
        split="train",
        verification_mode="no_checks",
        cache_dir="/data",
    )
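    # hold out a tiny, fixed-seed slice of the training split to use as the prompt set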
    dataset_dict = dataset.train_test_split(test_size=0.001, seed=54321)
    examples = dataset_dict["test"]

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    data_collator = DataCollatorForGeneration(tokenizer)
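    # tokenization/padding happens inside the collator, so the data loader can iterate
    # over the raw dataset rows directly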
    data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)

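    # cap each completion at 250 new tokens; remaining decoding settings fall back to
    # the library defaults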
    generation_config = GenerationConfig(
        max_new_tokens=250,
    )
    llm_generate_module = LLMGenerateLightningModule(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        adapter_path=adapter_path,
        generation_config=generation_config,
        prediction_config={
            "prediction_output_path": prediction_output_path,
            "prediction_partition_format": "rank-{rank:02d}-partition-{partition:06d}.parquet",
        },
    )
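    # Trainer defaults are used here; accelerator, devices, and precision can be
    # configured as needed for the available hardware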
    trainer = pl.Trainer()
    trainer.predict(llm_generate_module, data_loader)