cookbook-efforts/kto-preference/preference_gen.py (62 lines of code) (raw):

"""Generate haiku for a preference dataset and push the results to the Hub.

The script loads haiku prompts from ``davanstrien/haiku_prompts``, asks a pool
of hosted LLMs for several haiku per prompt via distilabel, and pushes the
generated dataset to ``<HF_USER_NAME>/haiku_dpo``.

Configuration:
    HF_TOKEN      read from the environment (a ``.env`` file is loaded first).
    HF_USER_NAME  must be edited below before running.
    SAMPLE_SIZE   optional number of prompts to sample; ``None`` uses them all.
"""

import os
import random

from datasets import load_dataset
from distilabel.llm import LLM, InferenceEndpointsLLM, LLMPool, ProcessLLM
from distilabel.pipeline import Pipeline
from distilabel.tasks import Task, TextGenerationTask
from dotenv import load_dotenv

load_dotenv()

# You need to set the HF_TOKEN environment variable to your Hugging Face API token
HF_TOKEN = os.getenv("HF_TOKEN")

# Set this to your Hugging Face username; the results are pushed to
# f"{HF_USER_NAME}/haiku_dpo".
HF_USER_NAME = None

# if you want to sample from the dataset, set this to the number of samples you want
# if the size of your sample is larger than the dataset the full dataset will be used
SAMPLE_SIZE = None

# Written with explicit "\n" escapes so the exact prompt text (including the
# trailing spaces and the three trailing newlines) does not depend on editor
# whitespace handling.
SYSTEM_PROMPT = (
    "You are a poet specialising in creating Haiku. \n"
    "Your haiku consist of three lines, with five syllables in the first line, "
    "seven in the second, and five in the third.\n"
    "Beyond being technically correct, your haiku should also be beautiful "
    "and meaningful. \n"
    "You respond only with a haiku. You do not add anything else to your "
    "responses. \n"
    "\n\n"
)

## Define the task
task = TextGenerationTask(system_prompt=SYSTEM_PROMPT)


def check_config() -> None:
    """Fail fast when the required configuration is missing.

    Uses explicit exceptions instead of ``assert`` so the checks still run
    under ``python -O``.

    Raises:
        ValueError: if ``HF_TOKEN`` is unset or ``HF_USER_NAME`` is empty.
    """
    if HF_TOKEN is None:
        raise ValueError("Please set HF_TOKEN to your Hugging Face API token")
    if not HF_USER_NAME:
        raise ValueError("Please set HF_USER_NAME to your Hugging Face username")


## Load the dataset of prompts
def prepare_data():
    """Load the haiku prompts and rename ``instructions`` to ``input``.

    Returns:
        The ``train`` split of ``davanstrien/haiku_prompts`` with its
        ``instructions`` column renamed to ``input``.
    """
    prompts = load_dataset("davanstrien/haiku_prompts", split="train")
    print(f"Loaded {len(prompts)} prompts")
    return prompts.rename_column("instructions", "input")


# load llms
def load_llama2(task: Task) -> LLM:
    """Build the Llama-2-70b-chat inference-endpoint LLM for ``task``."""
    return InferenceEndpointsLLM(
        "meta-llama/Llama-2-70b-chat-hf",
        token=HF_TOKEN,
        task=task,
        max_new_tokens=512,
        prompt_format="llama2",
    )


def load_mistral(task: Task) -> LLM:
    """Build the Mistral-7B-Instruct inference-endpoint LLM for ``task``."""
    checkpoint = "mistralai/Mistral-7B-Instruct-v0.2"
    return InferenceEndpointsLLM(
        checkpoint,
        token=HF_TOKEN,
        task=task,
        max_new_tokens=512,
        # NOTE(review): reuses the "llama2" prompt format for Mistral --
        # confirm this is the intended template for this checkpoint.
        prompt_format="llama2",
    )


# uncomment to use nous-hermes-2-yi-34b-aug
# def load_nous_yi(task: Task) -> LLM:
#     checkpoint = "nous-hermes-2-yi-34b-aug"
#     return InferenceEndpointsLLM(
#         checkpoint,
#         token=HF_TOKEN,
#         task=task,
#         max_new_tokens=488,
#         prompt_format="chatml",
#     )


def sample_prompts(dataset, sample_size):
    """Return a random subset of ``dataset`` with at most ``sample_size`` rows.

    Args:
        dataset: the prompt dataset; must support ``len()`` and ``select()``.
        sample_size: number of rows to keep, or ``None`` to keep everything.
            Values larger than the dataset are capped at the dataset size.

    Returns:
        ``dataset`` unchanged when ``sample_size`` is ``None``, otherwise the
        result of ``dataset.select`` on randomly chosen row indices.
    """
    if sample_size is None:
        return dataset
    sample_idx = random.sample(range(len(dataset)), min(sample_size, len(dataset)))
    return dataset.select(sample_idx)


def main() -> None:
    """Run the full generation pipeline and push the result to the Hub."""
    check_config()

    dataset = prepare_data()
    print(task.system_prompt)

    mistral = ProcessLLM(task=task, load_llm_fn=load_mistral)
    llama2 = ProcessLLM(task=task, load_llm_fn=load_llama2)
    # uncomment to use nous-hermes-2-yi-34b-aug
    # nous_yi = ProcessLLM(task=task, load_llm_fn=load_nous_yi)

    llms = [
        mistral,
        llama2,
    ]  # nous_yi]  # uncomment to use nous-hermes-2-yi-34b-aug

    pool = LLMPool(llms=llms)
    pipeline = Pipeline(generator=pool)

    dataset = sample_prompts(dataset, SAMPLE_SIZE)
    print(f"Using {len(dataset)} prompts")

    print("Generating haiku...")
    haiku = pipeline.generate(
        dataset,
        num_generations=3,
        batch_size=1,
        display_progress_bar=True,
        shuffle_before_labelling=False,
    )
    print(haiku)

    print("Pushing to hub...")
    haiku.push_to_hub(f"{HF_USER_NAME}/haiku_dpo", "aesthetic-preference", token=HF_TOKEN)


# The pipeline is started only when the file is executed as a script.
# ProcessLLM is expected to run each LLM in a separate process (see the
# distilabel docs); with a spawn-style start method the main module is
# re-imported in the child, so unguarded top-level code would run again there.
if __name__ == "__main__":
    main()