# generation/llm_swarm_script.py
import asyncio
import multiprocessing
import os
import time
from dataclasses import asdict, dataclass
from datasets import Dataset, load_dataset
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm, LLMSwarmConfig
from tqdm.asyncio import tqdm_asyncio
from transformers import AutoTokenizer, HfArgumentParser
import wandb
HF_TOKEN = os.environ.get("HF_TOKEN", None)
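
# Pipeline overview: load a prompts dataset, generate one completion per prompt
# through the llm_swarm endpoint with bounded concurrency, checkpoint results
# every `checkpoint_interval` samples, log throughput to wandb, and finally
# filter out short or failed generations before pushing to the Hub.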
@dataclass
class Args:
    # generation parameters
    max_new_tokens: int = 2500
    """Max new tokens"""
    temperature: float = 0.6
    """Generation temperature"""
    top_p: float = 0.95
    """Generation top_p"""
    top_k: int = 50
    """Generation top_k"""
    repetition_penalty: float = 1.2
    """Generation repetition_penalty"""
    # prompts dataset parameters
    prompts_dataset: str = "HuggingFaceTB/cosmopedia-100k"
    """Dataset containing the prompts"""
    max_samples: int = 5000
    """The maximum number of samples to generate (use -1 for all)"""
    start_sample: int = -1
    """First sample to process"""
    end_sample: int = -1
    """Last sample to process"""
    seed: int = 42
    """Seed for shuffling"""
    prompt_column: str = "prompt"
    """Name of the column containing the prompt"""
    shuffle_dataset: bool = False
    """Whether to shuffle the prompts"""
    debug: bool = False
    """Debugging mode"""
    # logging parameters
    repo_id: str = "HuggingFaceTB/synthetic_data_test"
    """The repo id to push to"""
    checkpoint_path: str = "./synthetic_data"
    """Path for saving intermediate generations"""
    checkpoint_interval: int = 1_000
    """Interval (in samples) for saving intermediate generations"""
    wandb_username: str = "loubnabnl"
    """Wandb username"""
    min_token_length: int = 150
    """Minimum number of tokens in a generation to be kept in the final dataset"""
    push_to_hub: bool = True
    """Whether to push to hub"""
parser = HfArgumentParser((Args, LLMSwarmConfig))
args, isc = parser.parse_args_into_dataclasses()
# args used in wandb
args_dict = asdict(args)
args_dict.update(
    {
        "per_instance_max_parallel_requests": isc.per_instance_max_parallel_requests,
        "instances": isc.instances,
        "inference_engine": isc.inference_engine,
        "model": isc.model,
    }
)
print(args_dict)
tokenizer = AutoTokenizer.from_pretrained(isc.model)
num_proc = 1 if args.debug else multiprocessing.cpu_count()
ds = load_dataset(
    args.prompts_dataset, token=HF_TOKEN, split="train", num_proc=num_proc
)
if args.shuffle_dataset:
    ds = ds.shuffle(seed=args.seed)
if args.start_sample >= 0:
    end_sample = len(ds) if args.end_sample < 0 else args.end_sample
    print(f"Loading a defined range of samples: ({args.start_sample}, {end_sample})...")
    ds = ds.select(range(args.start_sample, end_sample))
elif args.max_samples > 0:
    print(f"Loading the first {args.max_samples} samples...")
    ds = ds.select(range(args.max_samples))
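
# LLMSwarm manages the inference endpoint(s) for the duration of the `with`
# block and suggests a parallel-request budget, which is enforced below with
# an asyncio.Semaphore so the endpoint is never flooded.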
with LLMSwarm(isc) as llm_swarm:
    semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)
    client = AsyncInferenceClient(model=llm_swarm.endpoint)
    STOP_SEQ = ["<|endoftext|>"]
    MAX_RETRIES = 6  # maximum number of retries
    RETRY_DELAY = 4  # delay in seconds between retries
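
    # Each request is retried up to MAX_RETRIES times with a fixed delay;
    # samples that still fail get an empty completion and token_length 0 so
    # they can be filtered out (and optionally inspected) downstream.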
    async def process_text(sample):
        token_length = 0
        attempt = 0
        while attempt < MAX_RETRIES:
            try:
                async with semaphore:
                    completion = await client.text_generation(
                        prompt=tokenizer.apply_chat_template(
                            [{"role": "user", "content": sample[args.prompt_column]}],
                            tokenize=False,
                        ),
                        max_new_tokens=args.max_new_tokens,
                        stop_sequences=STOP_SEQ,
                        temperature=args.temperature,
                        top_p=args.top_p,
                        top_k=args.top_k,
                        repetition_penalty=args.repetition_penalty,
                    )
                    # strip a trailing stop sequence from the completion
                    for stop_seq in STOP_SEQ:
                        if completion.endswith(stop_seq):
                            completion = completion[: -len(stop_seq)].rstrip()
                    token_length += len(tokenizer.encode(completion))
                    sample["completion"] = completion
                    sample["token_length"] = token_length
                    return sample
            except Exception as e:
                attempt += 1
                if attempt < MAX_RETRIES:
                    print(
                        f"Request failed, retrying in {RETRY_DELAY} seconds... (Attempt {attempt}/{MAX_RETRIES})"
                    )
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    print(
                        f"Max retries reached. Failed to process the request with error {str(e)}."
                    )
                    sample["completion"] = ""
                    sample["token_length"] = 0
                    return sample
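
    # Generation runs in chunks of checkpoint_interval samples, so a crash
    # loses at most one chunk of completions.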
    async def main():
        start_time = time.time()
        total_tokens = 0
        saving_time = 0
        # suffix the repo id with the prompt column if it is not already there
        repo_id = (
            f"{args.repo_id}_{args.prompt_column}"
            if args.prompt_column not in args.repo_id
            else args.repo_id
        )
        wandb.init(
            project="synthetic_data",
            entity=args.wandb_username,
            name=repo_id.split("/")[1],
        )
        wandb.config.update(args_dict)
        checkpoint_dir = f"{args.checkpoint_path}/{repo_id.split('/')[1]}/data"
        os.makedirs(checkpoint_dir, exist_ok=True)
        print(f"Will be saving at {checkpoint_dir}")
        total_samples = len(ds)
        for i in range(0, total_samples, args.checkpoint_interval):
            batch_time = time.time()
            # Processing a chunk
            print(
                f"Processing chunk {int(i / args.checkpoint_interval)}/{int(total_samples / args.checkpoint_interval)}"
            )
            end_index = min(i + args.checkpoint_interval, total_samples)
            chunk = ds.select(range(i, end_index))
            chunk_results = await tqdm_asyncio.gather(
                *(process_text(sample) for sample in chunk)
            )
            # Save the chunk results and log throughput
            temp_time = time.time()
            time_per_chunk = temp_time - batch_time
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{i}.json")
            intermediate_ds = Dataset.from_list(chunk_results)
            intermediate_ds.to_json(checkpoint_path)
            batch_tokens = sum(intermediate_ds["token_length"])
            total_tokens += batch_tokens
            saving_time += time.time() - temp_time
            print(f"💾 Checkpoint (samples {i}-{end_index}) saved at {checkpoint_path}.")
            wandb.log(
                {
                    "sample": end_index,
                    "batch": int(i / args.checkpoint_interval),
                    "total_tokens (M)": total_tokens / 1e6,
                    "tokens_per_batch": batch_tokens,
                    "time_per_batch (s)": time_per_chunk,
                    "generated_tokens_per_sec": int(batch_tokens / time_per_chunk),
                    "generated_tokens_per_sec_per_node": int(
                        batch_tokens / (time_per_chunk * isc.instances)
                    ),
                }
            )
        end_time = time.time()
        print(
            "Done processing and saving all chunks 🎉! Let's get some stats and push to hub..."
        )
        total_duration = end_time - start_time
        overall_tokens_per_second = (
            total_tokens / total_duration if total_duration > 0 else 0
        )
        print(
            f"🏎️💨 Overall Tokens per Second: {overall_tokens_per_second:.2f}, per instance: {overall_tokens_per_second / isc.instances:.2f}"
        )
        print(f"Generated {total_tokens / 1e6:.2f}M tokens")
        print(
            f"Total duration: {int(total_duration // 3600)}h{int((total_duration % 3600) // 60)}min"
        )
        print(f"Saving time: {saving_time:.0f}s = {saving_time / 60:.1f}min")
        # load all checkpoints into a single dataset
        print("Load checkpoints...")
        output_ds = load_dataset(checkpoint_dir, split="train")
        # keep only generations with at least min_token_length tokens;
        # failed requests have token_length 0 and are dropped here
        final_data = output_ds.filter(
            lambda x: x["token_length"] >= args.min_token_length
        )
        print(final_data)
        failed = output_ds.filter(lambda x: x["token_length"] < args.min_token_length)
        if args.push_to_hub:
            print(f"📨 Pushing dataset to {repo_id}")
            final_data.push_to_hub(repo_id, private=True)
            print("Dataset pushed!")
            if len(failed) > 0:
                print(f"{len(failed)} generations failed")
                # push at most 1000 failed samples for inspection
                size = min(len(failed), 1000)
                failed = failed.select(range(size))
                failed.push_to_hub(f"{repo_id}_failed", private=True)
    asyncio.run(main())

wandb.finish()
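
# Example invocation (a sketch: the Args flags come from the dataclass above;
# LLMSwarmConfig flags such as --model and --instances are assumed to match
# the installed llm_swarm version, and the model id is only illustrative):
#
#   python generation/llm_swarm_script.py \
#       --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
#       --instances 2 \
#       --prompts_dataset HuggingFaceTB/cosmopedia-100k \
#       --prompt_column prompt \
#       --max_samples 5000 \
#       --repo_id HuggingFaceTB/synthetic_data_test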