def main()

in scripts/run_prompt_creation.py


def main():
    # 1. Parse input arguments
    parser = HfArgumentParser((ModelArguments, DataArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
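        # a hypothetical flat JSON config for this branch, keyed by dataclass field
        # names, e.g. {"model_name_or_path": "...", "dataset_name": "...", "output_dir": "..."}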
    else:
        model_args, data_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
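    # NB: `basicConfig` without `level=` leaves the root logger at WARNING, so the
    # `logger.info` calls below only print if the level is raised elsewhere.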
    
    if data_args.is_single_speaker and data_args.speaker_name is None:
        raise ValueError("`is_single_speaker=True` but `speaker_name` is not specified. Specify it or remove `is_single_speaker`.")

    if not data_args.is_single_speaker and data_args.speaker_name:
        raise ValueError(f"`is_single_speaker=False` but `speaker_name={data_args.speaker_name}` is specified. Add `--is_single_speaker` or remove `speaker_name`.")


    # Extend the distributed process-group timeout to 3 hours; a generation or
    # checkpointing step between syncs can exceed the default 30 minutes.
    process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=3600 * 3))
    accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])

    if data_args.overwrite_output_dir and os.path.isdir(data_args.output_dir):
        logger.info("Cleaning output dir from previous run...")
        shutil.rmtree(data_args.output_dir)

    # 3. Load annotated dataset
    logger.info("*** Load annotated dataset ***")
    if data_args.dataset_split_name is not None:
        raw_datasets = DatasetDict()
        data_splits = data_args.dataset_split_name.split("+")
        # load on a split-wise basis
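        # e.g. --dataset_split_name "train+test" loads and annotates both splits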
        for split in data_splits:
            with accelerator.local_main_process_first():
                raw_datasets[split] = load_dataset(
                    data_args.dataset_name,
                    data_args.dataset_config_name,
                    split=split,
                    cache_dir=model_args.cache_dir,
                    token=model_args.token,
                    num_proc=data_args.preprocessing_num_workers,
                )
    else:
        with accelerator.local_main_process_first():
            # load all splits for annotation
            raw_datasets = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                num_proc=data_args.preprocessing_num_workers,
            )

    raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            raw_datasets[split] = raw_datasets[split].select(range(min(len(raw_datasets[split]), data_args.max_eval_samples)))

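    # Each expected column also names a placeholder: `[column_name]` slots in the
    # prompt template are filled with the sample's value in `prepare_dataset` below.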
    EXPECTED_COLUMNS = {"gender", "pitch", "noise", "reverberation", "speech_monotony", "speaking_rate"}
    if data_args.is_single_speaker:
        EXPECTED_COLUMNS = {"noise", "reverberation", "speech_monotony", "speaking_rate"}
        
    if data_args.is_new_speaker_prompt:
        EXPECTED_COLUMNS.remove("noise")
        EXPECTED_COLUMNS.add("sdr_noise")
        
    speaker_ids_to_name = {}
    speaker_id_column = data_args.speaker_id_column
    if data_args.speaker_id_column and data_args.speaker_ids_to_name_json:
        import json
        if data_args.is_single_speaker:
            raise ValueError(f"`is_single_speaker=True` but `speaker_ids_to_name_json={data_args.speaker_ids_to_name_json}`. Specify one or another.")
        
        EXPECTED_COLUMNS.add(data_args.speaker_id_column)
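        # hypothetical shape of the JSON: speaker ids (as strings) mapped to display
        # names, e.g. {"1034": "Joe", "1037": "Anna"}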
        with open(data_args.speaker_ids_to_name_json, "r") as read_file:
            speaker_ids_to_name = json.load(read_file)

    if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
        missing_columns = EXPECTED_COLUMNS - raw_datasets_features
        raise ValueError(
            f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
        )

    # 4. Load pre-trained model
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        variant=model_args.model_variant,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        token=model_args.token,
    ).eval()

    if model_args.torch_compile:
        # torch compile only compatible with gemma and llama
        if not callable(getattr(model, "_setup_cache", None)):
            raise ValueError(
                f"Static k/v cache is not compatible with the model {model.__class__.__name__}. Set `--torch_compile=False` "
                "for a dynamic k/v cache."
            )
        model.generation_config.cache_implementation = "static"
        # compile the forward pass (but not the top-{p,k} sampling)
        model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
    )
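    # batched generation with decoder-only models needs left padding and a defined
    # pad token; fall back to BOS for inputs and EOS for `generate`'s outputs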
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.bos_token_id
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    speaker_name = data_args.speaker_name
    is_single_speaker = data_args.is_single_speaker
    is_new_speaker_prompt = data_args.is_new_speaker_prompt
    accent_column_name = data_args.accent_column

    def prepare_dataset(sample):
        sample_prompt = PROMPT
        if is_single_speaker:
            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT
            sample_prompt = sample_prompt.replace("[speaker_name]", speaker_name)
        elif speaker_id_column and speaker_ids_to_name.get(str(sample.get(speaker_id_column))):
            name = speaker_ids_to_name.get(str(sample.get(speaker_id_column)))
            sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT
            sample_prompt = sample_prompt.replace("[speaker_name]", name)
        elif is_new_speaker_prompt and accent_column_name is not None:
            # "Unindentified" (sic) is the sentinel this script uses for a missing accent
            sample_prompt = (
                NEW_PROMPT
                if sample.get(accent_column_name, "Unindentified") == "Unindentified"
                else NEW_PROMPT_WITH_ACCENT
            )
        elif is_new_speaker_prompt:
            sample_prompt = NEW_PROMPT
        for key in EXPECTED_COLUMNS:
            sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
        if accent_column_name is not None and sample.get(accent_column_name, "Unindentified") != "Unindentified":
            sample_prompt = sample_prompt.replace("[accent]", sample[accent_column_name])
            
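        # NOTE: `apply_chat_template` tokenizes by default; whether the model also
        # needs `add_generation_prompt=True` here depends on its chat template.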
        sample_prompt = [{"role": "user", "content": sample_prompt}]
        token_ids = tokenizer.apply_chat_template(sample_prompt)
        sample["input_ids"] = token_ids
        return sample

    with accelerator.local_main_process_first():
        vectorized_datasets = raw_datasets.map(
            prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
        )

    # Prepare everything with our `accelerator`
    model = accelerator.prepare(model)
    data_collator = DataCollatorWithPadding(tokenizer)

    def generate_step(batch):
        output_ids = accelerator.unwrap_model(model).generate(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            do_sample=model_args.do_sample,
            temperature=model_args.temperature,
            max_new_tokens=model_args.max_new_tokens,
        )
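        # generations can differ in length across processes; pad them to a common
        # length so they can be gathered later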
        output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
        return output_ids

    def postprocess_dataset(batch):
        prompt_texts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
        generated_texts = tokenizer.batch_decode(batch["generated_ids"], skip_special_tokens=True)
        
        batch["text_description"] = [generated_text[len(prompt_text) :] for (prompt_text, generated_text) in zip(prompt_texts, generated_texts)]
        return batch

    for split in vectorized_datasets:
        data_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=model_args.per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=data_args.dataloader_num_workers,
            pin_memory=True,
        )
        data_loader = accelerator.prepare(data_loader)
        total_inference_steps = len(data_loader)
        progress_bar = tqdm(
            range(total_inference_steps),
            desc=f"Generating {split}",
            position=0,
            disable=not accelerator.is_local_main_process,
        )

        split_output_dir = os.path.join(data_args.output_dir, split)
        all_generated_ids, cur_step = get_last_checkpoint(split_output_dir, accelerator.is_local_main_process)
        accelerator.wait_for_everyone()

        if cur_step > 0:
            logger.info(f"Resuming {split} from step {cur_step}")
            # efficiently skip the first n batches
            data_loader = skip_first_batches(data_loader, cur_step)
            progress_bar.update(cur_step)

        while cur_step < total_inference_steps:
            for batch in data_loader:
                generated_ids = generate_step(batch)
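                # gather this step's generations from all processes; only the local
                # main process accumulates them for checkpointing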
                generated_ids = accelerator.gather_for_metrics(generated_ids)
                if accelerator.is_local_main_process:
                    all_generated_ids.extend(generated_ids.cpu().numpy())

                cur_step += 1
                progress_bar.update(1)

                if (cur_step % data_args.save_steps == 0) or (cur_step == total_inference_steps):
                    if accelerator.is_main_process:
                        save_checkpoint(split_output_dir, all_generated_ids, cur_step)
                        rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)
                    accelerator.wait_for_everyone()

        if accelerator.is_local_main_process:
            vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)

        if accelerator.is_main_process:
            vectorized_datasets[split] = vectorized_datasets[split].map(
                postprocess_dataset,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                desc="Postprocessing dataset",
                remove_columns=["input_ids", "generated_ids"],
            )
        accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        vectorized_datasets.save_to_disk(data_args.output_dir)
        if data_args.push_to_hub:
            vectorized_datasets.push_to_hub(
                data_args.hub_dataset_id,
                config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
                token=model_args.token,
            )
    accelerator.wait_for_everyone()
    accelerator.end_training()
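

# Example launch (hypothetical values; the flags mirror the ModelArguments /
# DataArguments fields used above):
#
#   accelerate launch scripts/run_prompt_creation.py \
#       --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
#       --dataset_name <annotated_dataset> \
#       --output_dir <output_dir> \
#       --per_device_eval_batch_size 8 \
#       --max_new_tokens 256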