rlhf/trlx/train_reward_model_bloom.py (151 lines of code) (raw):

import json import os import torch from torch.utils.data import Dataset from reward_model_bloom import BLOOMRewardModel from tqdm import tqdm from transformers import AutoTokenizer, Trainer, TrainingArguments os.environ['WANDB_DISABLED'] = 'true' os.environ['TOKENIZERS_PARALLELISM'] = 'false' os.environ['CUDA_LAUNCH_BLOCKING'] = '1' def read_json(data_path): res = [] with open(data_path, 'r') as f: data = f.readlines() for line in tqdm(data): line = json.loads(line) res.append(line) return res def create_comparison_dataset(path): dataset = read_json(path) print('dataset_size: ', len(dataset)) print('dataset_size case: ', dataset[0]) pairs = [] for sample in tqdm(dataset): pair = {} prompt = sample['prompt'] chosen_answer = sample['choosen'] rejected_answer = sample['rejected'] if chosen_answer == rejected_answer: continue if len(chosen_answer) < 1 or len(rejected_answer) < 1: continue pair['choosen'] = prompt + '\n' + chosen_answer pair['rejected'] = prompt + '\n' + rejected_answer pairs.append(pair) print('pairs_nums: ', len(pairs)) return pairs class PairwiseDataset(Dataset): def __init__(self, pairs, tokenizer, max_length): self.chosen_input_ids = [] self.chosen_attn_masks = [] self.rejected_input_ids = [] self.rejected_attn_masks = [] for pair in tqdm(pairs): chosen, rejected = pair['choosen'], pair['rejected'] # print("chosen: ", chosen) # print("rejected: ", rejected) chosen_encodings_dict = tokenizer( # "<|startoftext|>" + chosen + "<|endoftext|>", chosen + '</s>', truncation=True, max_length=max_length, padding='max_length', return_tensors='pt', ) rejected_encodings_dict = tokenizer( # "<|startoftext|>" + rejected + "<|endoftext|>", rejected + '</s>', truncation=True, max_length=max_length, padding='max_length', return_tensors='pt', ) # print("chosen_input_ids_shape: ", chosen_encodings_dict["input_ids"].size()) # print("chosen_input_ids: ", chosen_encodings_dict["input_ids"]) # print("rejected_input_ids: ", rejected_encodings_dict["input_ids"]) # chosen_ids = chosen_encodings_dict["input_ids"] # rejected_ids = rejected_encodings_dict["input_ids"] # print("dengyu: ", (chosen_ids == rejected_ids)) # print("equal: ", torch.eq(chosen_ids, rejected_ids)) # print("all: ", torch.all(torch.eq(chosen_ids, rejected_ids))) if torch.all( torch.eq(chosen_encodings_dict['input_ids'], rejected_encodings_dict['input_ids'])).item(): # print("chosen_input: ", tokenizer.decode(chosen_encodings_dict["input_ids"][0])) # print("rejected_input: ", tokenizer.decode(rejected_encodings_dict["input_ids"][0])) # print("chosen_input_ids: ", chosen_encodings_dict["input_ids"]) # print("rejected_input_ids: ", rejected_encodings_dict["input_ids"]) pass else: self.chosen_input_ids.append( chosen_encodings_dict['input_ids']) self.chosen_attn_masks.append( chosen_encodings_dict['attention_mask']) self.rejected_input_ids.append( rejected_encodings_dict['input_ids']) self.rejected_attn_masks.append( rejected_encodings_dict['attention_mask']) print('chosen_input_size: ', len(self.chosen_input_ids)) print('rejected_input_size: ', len(self.rejected_input_ids)) def __len__(self): return len(self.chosen_input_ids) def __getitem__(self, idx): return ( self.chosen_input_ids[idx], self.chosen_attn_masks[idx], self.rejected_input_ids[idx], self.rejected_attn_masks[idx], ) class DataCollatorReward: def __call__(self, data): # tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b1") batch = {} batch['input_ids'] = torch.cat([f[0] for f in data] + [f[2] for f in data]) # print("????input_ids: ", batch["input_ids"]) # print("????input: ", tokenizer.decode(batch["input_ids"][0])) batch['attention_mask'] = torch.cat([f[1] for f in data] + [f[3] for f in data]) batch['labels'] = torch.tensor([0] * len(data) + [1] * len(data)) return batch def compute_metrics(eval_preds): chosen_end_scores = eval_preds.predictions[0] # chosen scores rejected_end_scores = eval_preds.predictions[1] # rejected scores result = {} acc = sum( chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) result['accuracy'] = acc return result if __name__ == '__main__': # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") # tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B") tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-1b1') tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = 'right' # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!tokenizer.pad_token: ", tokenizer.pad_token) if not os.path.exists('rm_checkpoint'): os.mkdir('rm_checkpoint') training_args = TrainingArguments( output_dir='rm_checkpoint/', num_train_epochs=2, logging_steps=10, gradient_accumulation_steps=4, save_strategy='steps', evaluation_strategy='steps', per_device_train_batch_size=16, per_device_eval_batch_size=1, eval_accumulation_steps=1, eval_steps=1000, save_steps=1000, warmup_steps=100, logging_dir='./logs', fp16=False, bf16=True, learning_rate=1e-5, deepspeed='ds_config_bloom.json', save_total_limit=3, ) # Initialize the reward model from the (supervised) fine-tuned GPT-J model = BLOOMRewardModel('bigscience/bloom-1b1') # Freeze the first 70% of the hidden layers of the reward model backbone layers = model.transformer.h num_layers = len(layers) num_unfrozen = int(0.3 * num_layers) for layer in layers[:-num_unfrozen]: layer.requires_grad_(False) # Create the comparisons datasets data_path = './ranking_data/' train_pairs = create_comparison_dataset( os.path.join(data_path, 'ranking_train.json')) val_pairs = create_comparison_dataset( os.path.join(data_path, 'ranking_val.json')) # Make pairwise datasets for training max_length = 550 train_dataset = PairwiseDataset(train_pairs, tokenizer, max_length=max_length) val_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) # Create the collator to gather batches of pairwise comparisons data_collator = DataCollatorReward() Trainer( model=model, args=training_args, train_dataset=train_dataset, compute_metrics=compute_metrics, eval_dataset=val_dataset, data_collator=data_collator, ).train()