Path: blob/master/deep_learning/llm/rlhf/generate.py
"""1Script for generating answers/completions on ultrafeedback dataset with2models or models + adapters3"""4import os5import torch6import numpy as np7import pandas as pd8import transformers9import pytorch_lightning as pl10from peft import PeftModel11from datasets import load_dataset12from dataclasses import dataclass13from torch.utils.data import DataLoader14from transformers import (15AutoTokenizer,16AutoModelForCausalLM,17GenerationConfig,18PreTrainedTokenizerBase,19)20from typing import Any, Dict, List, Optional, Union21from transformers.data.data_collator import DataCollatorMixin22from transformers.utils import PaddingStrategy232425@dataclass26class DataCollatorForGeneration(DataCollatorMixin):27"""28tokenize raw text (prompt) as well as padding while forming a batch for data loader.29"""3031tokenizer: PreTrainedTokenizerBase32max_seq_len: int = 51233padding: Union[bool, str, PaddingStrategy] = True34return_tensors: str = "pt"35prompt_col_name: str = "prompt"3637def __post_init__(self):38self.tokenizer.padding_side = "left"39self.tokenizer.pad_token = self.tokenizer.eos_token4041def __call__(42self, features: List[Dict[str, Any]], return_tensors=None43) -> Dict[str, Any]:4445prompts = [feature[self.prompt_col_name] for feature in features]46tokenized_text = self.tokenizer(47prompts,48padding=self.padding,49max_length=self.max_seq_len,50truncation=True,51return_attention_mask=True,52return_tensors=self.return_tensors,53)5455batch = {56"prompts": prompts,57"input_ids": tokenized_text["input_ids"],58"attention_mask": tokenized_text["attention_mask"],59}60return batch616263class LLMGenerateLightningModule(pl.LightningModule):64"""65Generate responses from LLM. Expects input prompts, tokenized input_ids, attention_mask66"""6768def __init__(69self,70pretrained_model_name_or_path,71generation_config,72prediction_config,73adapter_path=None,74cache_dir="/data",75):76super().__init__()77self.model = AutoModelForCausalLM.from_pretrained(78pretrained_model_name_or_path, cache_dir=cache_dir79)80if adapter_path:81peft_model = PeftModel.from_pretrained(self.model, adapter_path)82self.model = peft_model.merge_and_unload()8384self.tokenizer = AutoTokenizer.from_pretrained(85pretrained_model_name_or_path, padding_side="left", cache_dir=cache_dir86)87self.tokenizer.pad_token = self.tokenizer.eos_token8889self.generation_config = generation_config90self._setup_prediction(prediction_config)9192def predict_step(self, batch, batch_idx, dataloader_idx=None):93prompts = batch["prompts"]94input_ids = batch["input_ids"]95attention_mask = batch["attention_mask"]9697responses = self.generate(input_ids, attention_mask)9899prediction_output = {100"prompts": prompts,101"responses": responses,102}103self.prediction_outputs.append(prediction_output)104return prediction_output105106def generate(self, input_ids, attention_mask):107model_output = self.model.generate(108input_ids,109attention_mask=attention_mask,110generation_config=self.generation_config111)112# crop input prompt from generated response113input_seq_length = input_ids.shape[-1]114model_output_answer_only = model_output[:, input_seq_length:]115responses = self.tokenizer.batch_decode(model_output_answer_only, skip_special_tokens=True)116return responses117118def _setup_prediction(self, prediction_config):119if prediction_config:120self.prediction_outputs = []121self._prediction_partition_idx = 0122self.prediction_partition_format = prediction_config["prediction_partition_format"]123self.prediction_output_path = 
prediction_config["prediction_output_path"]124self.prediction_accumulation_steps = prediction_config.get("prediction_accumulation_steps", 100)125126def _save_prediction_outputs(self):127if self.prediction_output_path:128data = {field: [] for field in self.prediction_outputs[0]}129for prediction_output in self.prediction_outputs:130for field in data:131data[field].extend(prediction_output[field])132133partition_file_name = self.prediction_partition_format.format(134rank=self.global_rank, partition=self._prediction_partition_idx135)136formatted_output_path = os.path.join(137self.prediction_output_path, partition_file_name138)139140# saves prediction batch locally via pandas data frame141df_prediction_outputs = pd.DataFrame.from_dict(data)142os.makedirs(self.prediction_output_path, exist_ok=True)143df_prediction_outputs.to_parquet(formatted_output_path, index=False)144145self._prediction_partition_idx += 1146self.prediction_outputs.clear()147148def on_predict_batch_end(self, outputs, batch, batch_idx):149if len(self.prediction_outputs) == self.prediction_accumulation_steps:150self._save_prediction_outputs()151152def on_predict_epoch_end(self):153if len(self.prediction_outputs) > 0:154self._save_prediction_outputs()155156157158if __name__ == "__main__":159pretrained_model_name_or_path = "Qwen/Qwen2.5-3B-Instruct"160# generate response from a instruct model, versus a model trained161# on ultra feedback dataset using LoRA162adapter_path = None163prediction_output_path = "prediction_instruction_3B_model"164# adapter_path = "dpo_model_v7"165# prediction_output_path = "prediction_dpo_model_v7"166167dataset = load_dataset(168"argilla/ultrafeedback-binarized-preferences-cleaned",169split="train",170verification_mode="no_checks",171cache_dir="/data"172)173dataset_dict = dataset.train_test_split(test_size=0.001, seed=54321)174examples = dataset_dict["test"]175176tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)177data_collator = DataCollatorForGeneration(tokenizer)178data_loader = DataLoader(examples, batch_size=2, num_workers=2, collate_fn=data_collator)179180generation_config = GenerationConfig(181max_new_tokens=250182)183llm_generate_module = LLMGenerateLightningModule(184pretrained_model_name_or_path=pretrained_model_name_or_path,185adapter_path=adapter_path,186generation_config=generation_config,187prediction_config={188"prediction_output_path": prediction_output_path,189"prediction_partition_format": "rank-{rank:02d}-partition-{partition:06d}.parquet"190}191)192trainer = pl.Trainer()193trainer.predict(llm_generate_module, data_loader)194195
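
    # Read-back sketch (not part of the pipeline above; assumes a
    # single-process run, so only rank-00 partitions exist): concatenate the
    # parquet partitions written by _save_prediction_outputs and spot-check a
    # few prompt/response pairs.
    import glob

    partition_files = sorted(
        glob.glob(os.path.join(prediction_output_path, "rank-*-partition-*.parquet"))
    )
    df_generated = pd.concat(
        (pd.read_parquet(f) for f in partition_files), ignore_index=True
    )
    print(df_generated[["prompts", "responses"]].head())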