in generation/llm_swarm_script.py
# Imports this excerpt relies on (assumed to sit at the top of the script);
# model_id, STEP_SIZE, DATA_PATH, TAR_FILE_PATTERN, NUM_TAR_FILES and the
# helpers create_tasks, process_outputs_to_df, load_checkpoint and
# save_checkpoint are defined elsewhere in the file.
import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from datasets import load_dataset
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm, LLMSwarmConfig
from tqdm import trange
from tqdm.asyncio import tqdm_asyncio


def launch():
    with LLMSwarm(
        LLMSwarmConfig(
            instances=8,
            inference_engine="vllm",
            gpus=1,
            model=model_id,
            slurm_template_path="templates/vllm_h100.template.slurm",
            load_balancer_template_path="templates/nginx.template.conf",
            trust_remote_code=True,
            per_instance_max_parallel_requests=200,
        )
    ) as llm_swarm:
        semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)
        client = AsyncInferenceClient(model=llm_swarm.endpoint)
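
        # Query the swarm for one prompt. The semaphore keeps the number of
        # in-flight requests within the suggested limit, and the echoed prompt
        # is stripped from the returned text.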
        async def process_text(prompt):
            async with semaphore:
                response = await client.post(
                    json={
                        "prompt": prompt,
                        "max_tokens": 2000,
                    }
                )
                res = json.loads(response.decode("utf-8"))["text"][0][len(prompt):]
                return res
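
        # Load one webdataset tar shard into a pandas DataFrame and turn it
        # into generation tasks; shards that fail to load are skipped.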
        def load_and_process_dataset(tar_file):
            try:
                print(f"Loading dataset from: {tar_file}")
                dataset = load_dataset('webdataset', split='train', data_files=tar_file).to_pandas()
                tasks = create_tasks(dataset, prompt_id=None, n_overlap=1)
                return tasks
            except Exception as e:
                print(f"Error loading dataset from: {tar_file}")
                print(e)
                return []
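
        # Submit the next STEP_SIZE tar shards to a thread pool so that data
        # loading overlaps with generation.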
        def get_future_tasks(tar_index, executor):
            futures = []
            for inner_idx in range(STEP_SIZE):
                tar_file = os.path.join(DATA_PATH, TAR_FILE_PATTERN.format(tar_index + inner_idx))
                futures.append(executor.submit(load_and_process_dataset, tar_file))
            return futures
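
        # Main loop: for each batch of STEP_SIZE shards, collect the prefetched
        # tasks, start prefetching the following batch, run the generations
        # concurrently, then save the batch to HDF5 and checkpoint progress.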
        async def process_dataset(tar_index, total_examples):
            next_future_tasks = get_future_tasks(tar_index, ThreadPoolExecutor(max_workers=STEP_SIZE))
            for idx in trange(tar_index, NUM_TAR_FILES + STEP_SIZE, STEP_SIZE, desc="Creating Dataset"):
                print(f"Processing tar file {idx}")
                tasks = []
                future_tasks = next_future_tasks
                results = [f.result() for f in future_tasks]
                for result in results:
                    tasks.extend(result)
                # Once the tasks for this batch are created, load the next batch in parallel;
                # otherwise the tasks for this batch compete with those of the next batch for resources
                next_future_tasks = get_future_tasks(idx + STEP_SIZE, ThreadPoolExecutor(max_workers=1))  # a single thread to avoid clogging the CPU
                results = await tqdm_asyncio.gather(*(process_text(task['messages']) for task in tasks))
                df = pd.DataFrame({"Task": tasks, "Completion": results})
                df_new = process_outputs_to_df(df)
                # Save the batch to HDF5
                df_new.to_hdf(f'synthetic_dataset_batch_{idx}.h5', key='df', mode='w')
                unique_keys = df_new['__key__'].nunique()
                total_examples += unique_keys
                save_checkpoint(idx, total_examples)
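
        # Resume from the last checkpoint (if any) and drive the processing loop.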
        async def main():
            checkpoint = load_checkpoint()
            tar_index = checkpoint['tar_index']
            if tar_index != 0:
                tar_index += STEP_SIZE
                print(f"Resuming from tar file {tar_index}")
            total_examples = checkpoint['total_examples']
            processor = asyncio.create_task(process_dataset(tar_index, total_examples))
            await processor
            print("All batches processed.")

        asyncio.run(main())
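
# launch() is expected to be called from the script's entry point (e.g. a
# module-level call or an `if __name__ == "__main__":` guard), which lies
# outside this excerpt.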