generation/llm_swarm_script.py

import asyncio
import json
import os
import random
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import pandas as pd
from datasets import load_dataset
from huggingface_hub import AsyncInferenceClient
from tqdm import trange
from tqdm.asyncio import tqdm_asyncio
from transformers import AutoTokenizer

from examples.question_answer_pairs.phase_1.base_prompts import (
    BASE_PROMPT,
    BASE_USER_CONTENT,
    PROMPTS,
)
from llm_swarm import LLMSwarm, LLMSwarmConfig

CHECKPOINT_FILE = "checkpoint.json"
DATA_PATH = "/fsx/andi/pdfa_data/"
TAR_FILE_PATTERN = "pdfa-eng-train-{:06d}.tar"
NUM_TAR_FILES = 1800  # Total number of tar files
MAX_PAGES_PER_PDF = 4
STEP_SIZE = 10  # Number of tar files processed per batch

model_id = "microsoft/Phi-3-small-8k-instruct"


def create_llm_prompt(prompt, text):
    """Builds the chat messages (system + user) for one chunk of page text."""
    system_content = BASE_PROMPT.format(
        role_description=prompt["role_description"], examples=prompt["examples"]
    )
    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": BASE_USER_CONTENT.format(text=text)},
    ]


def extract_text_per_page_from_sample(sample: Dict[str, Any]) -> List[str]:
    """
    Extracts text from each page of a given sample and returns it as a list of strings.

    Args:
        sample (Dict[str, Any]): The sample containing page data in JSON format.

    Returns:
        List[str]: A list of strings, where each string represents the text of a page.
    """
    texts = []
    for page in sample["json"]["pages"]:
        pages_text = " \n ".join(page["lines"]["text"])
        texts.append(pages_text)
    return texts


def extract_chunks(
    pages: List[Any], max_tokens_per_group: int, max_pages_per_group: int, n_overlap: int
) -> List[str]:
    """
    Splits a list of pages into chunks with a specified maximum number of tokens per chunk,
    a maximum number of pages per chunk, and an overlap between consecutive chunks.

    Args:
        pages (List[Any]): The list of pages to be chunked.
        max_tokens_per_group (int): The maximum number of tokens allowed per chunk.
        max_pages_per_group (int): The maximum number of pages allowed per chunk.
        n_overlap (int): The number of overlapping pages between consecutive chunks.

    Returns:
        List[str]: A list of chunked text, each chunk containing text from multiple pages.
    """
    chunks = []
    current_chunk = []
    current_chunk_tokens = 0
    current_chunk_pages = 0

    page_token_counts = [len(tokenizer.encode(page, add_special_tokens=False)) for page in pages]

    for i, page in enumerate(pages):
        page_tokens = page_token_counts[i]

        # A single page that exceeds the token budget can never fit into any
        # chunk, so the whole document is skipped.
        if page_tokens > max_tokens_per_group:
            print(f"Skipping document where page nr {i} has {page_tokens} tokens.")
            return []

        if (current_chunk_tokens + page_tokens > max_tokens_per_group) or (
            current_chunk_pages + 1 > max_pages_per_group
        ):
            if current_chunk:
                chunks.append("\nNEW PAGE\n".join(current_chunk))
            # Start the next chunk with the last n_overlap pages of the previous one.
            current_chunk = current_chunk[-n_overlap:] if n_overlap > 0 else []
            current_chunk_tokens = sum(page_token_counts[max(0, i - n_overlap):i])
            current_chunk_pages = len(current_chunk)

        current_chunk.append(page)
        current_chunk_tokens += page_tokens
        current_chunk_pages += 1

    if current_chunk:
        chunks.append("\nNEW PAGE\n".join(current_chunk))

    return chunks
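
# Illustrative walkthrough of the chunking above (hypothetical numbers, not
# part of the original pipeline): with max_tokens_per_group=8,
# max_pages_per_group=2, n_overlap=1 and four pages whose token counts are
# [4, 3, 4, 2], the loop produces
#
#   [p0, p1]  (4 + 3 = 7 tokens)
#   [p1, p2]  (p1 carried over as the one-page overlap)
#   [p2, p3]  (flushed after the loop ends)
#
# with the pages of each chunk joined by the '\nNEW PAGE\n' separator.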
def create_tasks(
    dataset: pd.DataFrame, prompt_id: Optional[int] = None, n_overlap: int = 2
) -> List[Dict[str, Any]]:
    """
    Processes a dataset into the prompts used to generate Q/A pairs for each sample.

    Args:
        dataset (pd.DataFrame): The dataset containing samples.
        prompt_id (Optional[int]): The ID of the prompt template to use for generating
            questions. If None, a prompt is selected at random for every chunk.
        n_overlap (int): The number of overlapping pages between consecutive chunks.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the sample key, page count,
        chat-formatted prompt, and prompt ID.
    """
    if prompt_id is not None:
        selected_id_prompt = prompt_id

    tasks = []

    for _, sample in dataset.iterrows():
        text_per_page = extract_text_per_page_from_sample(sample)
        if len(text_per_page) > MAX_PAGES_PER_PDF:
            continue
        page_chunks = extract_chunks(
            text_per_page, max_tokens_per_group=5000, max_pages_per_group=5, n_overlap=n_overlap
        )

        for chunk in page_chunks:
            if prompt_id is None:
                selected_id_prompt = random.randint(0, 4)
            prompt = PROMPTS[selected_id_prompt]
            messages = create_llm_prompt(prompt, chunk)
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            tasks_dict = {
                "__key__": sample["__key__"],
                "Page count": len(text_per_page),
                "messages": prompt,
                "Prompt ID": selected_id_prompt,
            }

            tasks.append(tasks_dict)

    return tasks


# Extract Q&A pairs of the form "Qn: ... An: ..." from a completion string.
def extract_qa_pairs(text):
    qa_pattern = re.compile(r"(Q\d+:\s*.*?)(A\d+:\s*.*?)(?=(Q\d+:)|$)", re.DOTALL)
    matches = qa_pattern.findall(text)
    qa_pairs = [(q.strip(), a.strip()) for match in matches for q, a in [match[:2]]]
    return qa_pairs
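
# A minimal sketch of the completion format the regex above expects
# (hypothetical text, not taken from a real model output):
#
#   >>> extract_qa_pairs("Q1: What is the title? A1: Annual Report. Q2: Which year? A2: 2019.")
#   [('Q1: What is the title?', 'A1: Annual Report.'), ('Q2: Which year?', 'A2: 2019.')]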
def process_outputs_to_df(df):
    all_data = []

    for _, row in df.iterrows():
        task = row["Task"]
        completion = row["Completion"]

        sample_key = task["__key__"]
        page_count = task["Page count"]
        prompt_id = task["Prompt ID"]

        qa_pairs = extract_qa_pairs(completion)

        if len(qa_pairs) == 0:
            print("No Q&A pairs found for sample:", sample_key)

        for question, answer in qa_pairs:
            all_data.append(
                {
                    "__key__": sample_key,
                    "Page count": page_count,
                    "Prompt ID": prompt_id,
                    "question": question,
                    "answer": answer,
                }
            )

    qa_df = pd.DataFrame(all_data)
    return qa_df


def save_checkpoint(tar_index, total_examples):
    checkpoint_data = {"tar_index": tar_index, "total_examples": total_examples}
    with open(CHECKPOINT_FILE, "w") as f:
        json.dump(checkpoint_data, f)


def load_checkpoint():
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r") as f:
            return json.load(f)
    return {"tar_index": 0, "total_examples": 0}


tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)


def launch():
    with LLMSwarm(
        LLMSwarmConfig(
            instances=8,
            inference_engine="vllm",
            gpus=1,
            model=model_id,
            slurm_template_path="templates/vllm_h100.template.slurm",
            load_balancer_template_path="templates/nginx.template.conf",
            trust_remote_code=True,
            per_instance_max_parallel_requests=200,
        )
    ) as llm_swarm:
        semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)
        client = AsyncInferenceClient(model=llm_swarm.endpoint)

        async def process_text(prompt):
            async with semaphore:
                response = await client.post(
                    json={
                        "prompt": prompt,
                        "max_tokens": 2000,
                    }
                )
                # The endpoint returns the prompt followed by the completion,
                # so the prompt prefix is stripped from the generated text.
                res = json.loads(response.decode("utf-8"))["text"][0][len(prompt):]
                return res

        def load_and_process_dataset(tar_file):
            try:
                print(f"Loading dataset from: {tar_file}")
                dataset = load_dataset("webdataset", split="train", data_files=tar_file).to_pandas()
                tasks = create_tasks(dataset, prompt_id=None, n_overlap=1)
                return tasks
            except Exception as e:
                print(f"Error loading dataset from: {tar_file}")
                print(e)
                return []

        def get_future_tasks(tar_index, executor):
            futures = []
            for inner_idx in range(STEP_SIZE):
                tar_file = os.path.join(DATA_PATH, TAR_FILE_PATTERN.format(tar_index + inner_idx))
                futures.append(executor.submit(load_and_process_dataset, tar_file))
            return futures

        async def process_dataset(tar_index, total_examples):
            next_future_tasks = get_future_tasks(tar_index, ThreadPoolExecutor(max_workers=STEP_SIZE))
            for idx in trange(tar_index, NUM_TAR_FILES + STEP_SIZE, STEP_SIZE, desc="Creating Dataset"):
                print(f"Processing tar file {idx}")
                tasks = []
                future_tasks = next_future_tasks
                results = [f.result() for f in future_tasks]
                for result in results:
                    tasks.extend(result)
                # Once the tasks for this batch are created, load the next batch in parallel;
                # otherwise the tasks of this batch compete with the next batch for resources.
                next_future_tasks = get_future_tasks(
                    idx + STEP_SIZE, ThreadPoolExecutor(max_workers=1)  # One thread to avoid clogging the CPU
                )
                results = await tqdm_asyncio.gather(*(process_text(task["messages"]) for task in tasks))
                df = pd.DataFrame({"Task": tasks, "Completion": results})
                df_new = process_outputs_to_df(df)
                # Save the batch to HDF5
                df_new.to_hdf(f"synthetic_dataset_batch_{idx}.h5", key="df", mode="w")
                unique_keys = df_new["__key__"].nunique()
                total_examples += unique_keys
                save_checkpoint(idx, total_examples)

        async def main():
            checkpoint = load_checkpoint()
            tar_index = checkpoint["tar_index"]
            if tar_index != 0:
                tar_index += STEP_SIZE
                print(f"Resuming from tar file {tar_index}")
            total_examples = checkpoint["total_examples"]

            processor = asyncio.create_task(process_dataset(tar_index, total_examples))
            await processor
            print("All batches processed.")

        asyncio.run(main())


launch()
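
# Resumption: save_checkpoint() writes checkpoint.json after every batch, e.g.
#   {"tar_index": 40, "total_examples": 12345}
# and on restart main() advances to tar_index + STEP_SIZE, so each batch of
# STEP_SIZE tar files is processed exactly once.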