cookbook-efforts/domain-specific-datasets/distilabel_pipelines/domain_expert_pipeline.py (180 lines of code) (raw):

import json
import os
import sys
from typing import Any, Dict

import argilla as rg
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
    ExpandColumns,
    LoadDataFromDicts,
    TextGenerationToArgilla,
)
from distilabel.steps.tasks import (
    SelfInstruct,
    TextGeneration,
)
from distilabel.steps.tasks.typing import ChatType
from huggingface_hub import hf_hub_download

################################################################################
# Define custom Argilla Dataset
################################################################################


def create_argilla_dataset(
    api_url: str,
    api_key: str,
    dataset_name: str,
    workspace: str,
) -> None:
    """Create a dataset in Argilla for rating generated responses.

    The dataset exposes three text fields (``id``, ``instruction``,
    ``generation``), a binary good/bad quality label, and two optional
    free-text questions for suggesting improvements.

    Args:
        api_url: Base URL of the Argilla server.
        api_key: API key used to authenticate against Argilla.
        dataset_name: Name under which the dataset is pushed.
        workspace: Argilla workspace that will own the dataset.

    Note:
        If pushing fails (e.g. the dataset already exists), the
        ``RuntimeError`` is logged and swallowed so the pipeline can proceed.
    """
    rg.init(api_url, api_key)
    rg_dataset = rg.FeedbackDataset(
        fields=[
            rg.TextField(name="id", title="id"),  # type: ignore
            rg.TextField(name="instruction", title="instruction"),  # type: ignore
            rg.TextField(name="generation", title="generation"),  # type: ignore
        ],
        questions=[
            rg.LabelQuestion(  # type: ignore
                name="quality",
                # NOTE: plain string — the original used an f-string with no
                # placeholders; the runtime value is identical.
                title="What's the quality of the generation for the given instruction?",
                labels={"bad": "👎", "good": "👍"},
            ),
            rg.TextQuestion(
                name="improved_instruction",
                title="How would you improve the instruction?",
                required=False,
            ),
            rg.TextQuestion(
                name="improved_response",
                title="How would you improve the response?",
                required=False,
            ),
        ],
    )
    try:
        rg_dataset.push_to_argilla(name=dataset_name, workspace=workspace)
    except RuntimeError as e:
        # Best-effort setup: the dataset most likely already exists, so we
        # log and continue rather than abort the whole pipeline.
        print(f"Failed to create the dataset in Argilla: {e} Moving on...")


################################################################################
# Define our custom step for the domain expert
################################################################################


class DomainExpert(TextGeneration):
    """A customized task to generate text as a domain expert in the domain
    of farming and agriculture."""

    # System prompt injected per-run from the seed data (required, no default).
    system_prompt: str
    # User-message template; rendered with the row's "instruction" column.
    # (The doubled "the the" is preserved intentionally: changing the prompt
    # text would change the model's input and therefore its outputs.)
    template: str = """This is the the instruction: {instruction}"""

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """Format a pipeline row as a system + user chat message pair."""
        return [
            {
                "role": "system",
                "content": self.system_prompt,
            },
            {
                "role": "user",
                "content": self.template.format(**input),
            },
        ]


################################################################################
# Main script to run the pipeline
################################################################################

if __name__ == "__main__":
    # The Hub dataset repo id holding pipeline_params.json / seed_data.json
    # is the single required CLI argument.
    repo_id = sys.argv[1]

    # Get super secret tokens from the environment.
    hub_token = os.environ.get("HF_TOKEN")
    argilla_api_key = os.environ.get("ARGILLA_API_KEY", "owner.apikey")

    # Load pipeline parameters published alongside the seed data.
    with open(
        hf_hub_download(
            repo_id=repo_id, filename="pipeline_params.json", repo_type="dataset"
        ),
        "r",
    ) as f:
        params = json.load(f)

    argilla_api_url = params.get("argilla_api_url")
    argilla_dataset_name = params.get("argilla_dataset_name")
    self_instruct_base_url = params.get("self_instruct_base_url")
    domain_expert_base_url = params.get("domain_expert_base_url")
    self_instruct_num_generations = params.get("self_instruct_num_generations", 2)
    domain_expert_num_generations = params.get("domain_expert_num_generations", 2)
    self_instruct_temperature = params.get("self_instruct_temperature", 0.9)
    domain_expert_temperature = params.get("domain_expert_temperature", 0.9)
    self_instruct_max_new_tokens = params.get("self_instruct_max_new_tokens", 2048)
    domain_expert_max_new_tokens = params.get("domain_expert_max_new_tokens", 2048)

    if not all(
        [
            argilla_api_url,
            argilla_dataset_name,
            self_instruct_base_url,
            domain_expert_base_url,
        ]
    ):
        raise ValueError("Some of the pipeline parameters are missing")

    # Collect our seed prompts defined in the space.
    with open(
        hf_hub_download(
            repo_id=repo_id, filename="seed_data.json", repo_type="dataset"
        ),
        "r",
    ) as f:
        seed_data = json.load(f)

    application_instruction = seed_data.get("application_instruction")
    domain_expert_prompt = seed_data.get("domain_expert_prompt")
    domain_name = seed_data.get("domain")
    terms = seed_data.get("seed_terms")

    # Create the Argilla dataset (idempotent best-effort).
    create_argilla_dataset(
        api_url=argilla_api_url,
        api_key=argilla_api_key,
        dataset_name=argilla_dataset_name,
        workspace="admin",
    )

    # Define the distilabel pipeline:
    # load seed terms -> self-instruct -> expand instructions -> domain expert
    # -> push generations to Argilla.
    with Pipeline(domain_name) as pipeline:
        load_data = LoadDataFromDicts(
            name="load_data",
            batch_size=64,
            data=[{"input": term} for term in terms],
        )
        self_instruct = SelfInstruct(
            name="self_instruct",
            num_instructions=self_instruct_num_generations,
            input_batch_size=8,
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=self_instruct_base_url,
            ),
            application_description=application_instruction,
        )
        expand_columns = ExpandColumns(
            name="expand_columns",
            columns=["instructions"],
            output_mappings={"instructions": "instruction"},
        )
        domain_expert = DomainExpert(
            name="domain_expert",
            llm=InferenceEndpointsLLM(
                api_key=hub_token,
                base_url=domain_expert_base_url,
            ),
            input_batch_size=8,
            num_generations=domain_expert_num_generations,
            system_prompt=domain_expert_prompt,
        )

        # Push the generated dataset to Argilla.
        to_argilla = TextGenerationToArgilla(
            name="to_argilla",
            dataset_workspace="admin",
        )

        # Connect up the pipeline.
        load_data.connect(self_instruct)
        self_instruct.connect(expand_columns)
        expand_columns.connect(domain_expert)
        domain_expert.connect(to_argilla)

    # Run the pipeline.
    pipeline.run(
        parameters={
            "self_instruct": {
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": self_instruct_max_new_tokens,
                        "temperature": self_instruct_temperature,
                    },
                }
            },
            "domain_expert": {
                "llm": {
                    "generation_kwargs": {
                        # BUGFIX: previously passed self_instruct_max_new_tokens
                        # here, so the domain_expert_max_new_tokens parameter
                        # was read but never applied.
                        "max_new_tokens": domain_expert_max_new_tokens,
                        "temperature": domain_expert_temperature,
                    },
                }
            },
            "to_argilla": {
                "dataset_name": argilla_dataset_name,
                "api_key": argilla_api_key,
                "api_url": argilla_api_url,
            },
        },
        use_cache=False,
    )