def launch()

in generation/llm_swarm_script.py [0:0]
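
Launches eight single-GPU vLLM instances via LLMSwarm, then streams webdataset tar shards through them in batches of STEP_SIZE files: each batch's prompts are sent to the endpoint concurrently, the completions are written to one HDF5 file per batch, and a checkpoint records progress so an interrupted run can resume.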


def launch():
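    # Spin up 8 single-GPU vLLM instances on Slurm, fronted by an nginx load balancer.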
    with LLMSwarm(
        LLMSwarmConfig(
            instances=8,
            inference_engine="vllm",
            gpus=1,
            model=model_id,
            slurm_template_path="templates/vllm_h100.template.slurm",
            load_balancer_template_path="templates/nginx.template.conf",
            trust_remote_code=True,
            per_instance_max_parallel_requests=200,
        )
    ) as llm_swarm:
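        # Cap in-flight requests at the swarm's suggested limit and reuse one async client for the endpoint.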
        semaphore = asyncio.Semaphore(llm_swarm.suggested_max_parallel_requests)
        client = AsyncInferenceClient(model=llm_swarm.endpoint)

        async def process_text(prompt):
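            # Send one prompt to the swarm endpoint; the response's "text"[0] holds the prompt followed by
            # the completion, so the slice below keeps only the newly generated text.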
            async with semaphore:
                response = await client.post(
                    json={
                        "prompt": prompt,
                        "max_tokens": 2000,
                    }
                )
                res = json.loads(response.decode("utf-8"))["text"][0][len(prompt):]
                return res

        def load_and_process_dataset(tar_file):
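            # Load one tar shard as a webdataset into pandas and turn its rows into generation tasks;
            # on failure, skip the shard by returning an empty list.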
            try:
                print(f"Loading dataset from: {tar_file}")
                dataset = load_dataset('webdataset', split='train', data_files=tar_file).to_pandas()
                tasks = create_tasks(dataset, prompt_id=None, n_overlap=1)
                return tasks
            except Exception as e:
                print(f"Error loading dataset from: {tar_file}")
                print(e)
                return []

        def get_future_tasks(tar_index, executor):
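            # Submit STEP_SIZE consecutive tar shards to the executor so they load in the background.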
            futures = []
            for inner_idx in range(STEP_SIZE):
                tar_file = os.path.join(DATA_PATH, TAR_FILE_PATTERN.format(tar_index + inner_idx))
                futures.append(executor.submit(load_and_process_dataset, tar_file))
            return futures

        async def process_dataset(tar_index, total_examples):
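            # Walk over the tar shards STEP_SIZE at a time: wait for the prefetched shards, generate
            # completions for their tasks, save the batch, and checkpoint progress. The first batch is
            # prefetched with STEP_SIZE workers since no generation is running yet.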
            next_future_tasks = get_future_tasks(tar_index, ThreadPoolExecutor(max_workers=STEP_SIZE))
            for idx in trange(tar_index, NUM_TAR_FILES + STEP_SIZE, STEP_SIZE, desc="Creating Dataset"):
                print(f"Processing tar file {idx}")
                tasks = []
                future_tasks = next_future_tasks

                results = [f.result() for f in future_tasks]
                for result in results:
                    tasks.extend(result)
                # Once the tasks for this batch are built, start loading the next batch in the background;
                # otherwise loading the next batch would compete with this batch's inference for resources.
                # A single worker keeps the prefetch from clogging the CPU.
                next_future_tasks = get_future_tasks(idx + STEP_SIZE, ThreadPoolExecutor(max_workers=1))

                results = await tqdm_asyncio.gather(*(process_text(task['messages']) for task in tasks))
                df = pd.DataFrame({"Task": tasks, "Completion": results})
                df_new = process_outputs_to_df(df)

                # Save the batch to HDF5
                df_new.to_hdf(f'synthetic_dataset_batch_{idx}.h5', key='df', mode='w')

                unique_keys = df_new['__key__'].nunique()
                total_examples += unique_keys

                save_checkpoint(idx, total_examples)

        async def main():
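            # Resume from the last checkpoint, then process the remaining tar shards.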
            checkpoint = load_checkpoint()
            tar_index = checkpoint['tar_index']

            if tar_index != 0:
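                # The saved index belongs to a batch that already finished, so skip to the next one.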
                tar_index += STEP_SIZE
                print(f"Resuming from tar file {tar_index}")

            total_examples = checkpoint['total_examples']
            processor = asyncio.create_task(process_dataset(tar_index, total_examples))
            await processor

            print("All batches processed.")

        asyncio.run(main())
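
The body above depends on imports and module-level definitions from elsewhere in generation/llm_swarm_script.py. The imports implied by the calls shown are roughly the following (the exact import block in the script may differ):

import asyncio
import json
import os
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
from datasets import load_dataset
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm, LLMSwarmConfig
from tqdm import trange
from tqdm.asyncio import tqdm_asyncio

It also assumes module-level constants (model_id, DATA_PATH, TAR_FILE_PATTERN, NUM_TAR_FILES, STEP_SIZE) and helpers (create_tasks, process_outputs_to_df, load_checkpoint, save_checkpoint) defined in the same script. As an illustration only, the checkpoint helpers could be JSON-backed like the sketch below; the actual implementations may differ:

CHECKPOINT_FILE = "checkpoint.json"  # hypothetical path, not taken from the script

def load_checkpoint():
    # Return the last saved progress, or a fresh state on the first run.
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE) as f:
            return json.load(f)
    return {"tar_index": 0, "total_examples": 0}

def save_checkpoint(tar_index, total_examples):
    # Persist progress after each batch so launch() can resume where it stopped.
    with open(CHECKPOINT_FILE, "w") as f:
        json.dump({"tar_index": tar_index, "total_examples": total_examples}, f)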