def finetune_model()

in scripts/launcher_distributed.py


def finetune_model() -> None:
    """
    Fine-tune a model using distributed training.

    Returns:
        None
    """
    print("***** Starting model fine-tuning *****")

    # Set custom environment variables.
    # Enabling NCCL_DEBUG=INFO dumps verbose NCCL debug information; if any
    # problems are reported you can search for them online, or share the log
    # file in an Issue if you are unsure how to interpret the output.
    custom_env: Dict[str, str] = {
        "HF_DATASETS_TRUST_REMOTE_CODE": "TRUE",
        "HF_TOKEN": args.hf_token,
        # "NCCL_DEBUG": "INFO",
        "WANDB_API_KEY": args.wandb_api_key,
        "WANDB_PROJECT": args.wandb_project,
        "WANDB_WATCH": args.wandb_watch,
        "WANDB_DIR": args.log_dir,
    }

    set_custom_env(custom_env)
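
    # Create the working directories for the base model, logs, and
    # fine-tuned output up front.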
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_output_dir, exist_ok=True)

    with torch_distributed_zero_first(LOCAL_RANK):
        # Download the model on local rank 0 first; the other ranks wait at
        # a barrier inside the context manager and then reuse the cached copy.
        download_model(args.model_id, args.model_dir, args.ignore_patterns)

    # Construct the fine-tuning command
    if "single" in args.tune_recipe:
        print("***** Single Device Training *****")
        full_command = f"tune run {args.tune_recipe} --config {args.tune_finetune_yaml}"
        # Run the fine-tuning command
        run_command(full_command)
    else:
        print("***** Distributed Training *****")
        if dist.is_initialized():
            # `tune run` spawns its own worker processes, so tear down the
            # process group created by this launcher first.
            print("Destroying current process group before launching tune run...")
            dist.destroy_process_group()

        if GLOBAL_RANK in {-1, 0}:
            # Only global rank 0 (or a non-distributed run, rank -1) builds
            # and launches the fine-tuning command.
            full_command = (
                f"tune run --master-addr {MASTER_ADDR} --master-port {MASTER_PORT} --nnodes {NUM_NODES} --nproc_per_node {NUM_GPUS_PER_NODE} "
                f"{args.tune_recipe} "
                f"--config {args.tune_finetune_yaml}"
            )
            run_command(full_command)
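
The helpers used above (set_custom_env, download_model, run_command,
torch_distributed_zero_first) are defined elsewhere in
scripts/launcher_distributed.py. As a reference, here are minimal sketches of
what they could look like; the actual implementations may differ.

A sketch of set_custom_env, assuming it simply merges the mapping into the
process environment:

import os
from typing import Dict

def set_custom_env(env_vars: Dict[str, str]) -> None:
    """Merge the given variables into the current process environment."""
    os.environ.update(env_vars)
    print(f"Set environment variables: {list(env_vars.keys())}")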
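
torch_distributed_zero_first is a common barrier pattern (popularized by
YOLOv5): ranks other than local rank 0 wait until rank 0 has finished the
body, so each node downloads the model exactly once. One typical
implementation looks like this:

from contextlib import contextmanager

import torch.distributed as dist

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Have all other local ranks wait while local rank 0 runs the body."""
    if local_rank not in (-1, 0):
        dist.barrier()
    yield
    if local_rank == 0:
        dist.barrier()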
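
download_model presumably pulls the weights from the Hugging Face Hub; a
sketch assuming it wraps huggingface_hub.snapshot_download, which accepts the
same ignore_patterns argument:

from typing import List, Optional

from huggingface_hub import snapshot_download

def download_model(
    model_id: str, model_dir: str, ignore_patterns: Optional[List[str]]
) -> None:
    """Download the model repository from the Hub into model_dir."""
    snapshot_download(
        repo_id=model_id,
        local_dir=model_dir,
        ignore_patterns=ignore_patterns,
    )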
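
run_command is most likely a thin subprocess wrapper; a sketch that echoes
the command and raises if it exits non-zero:

import subprocess

def run_command(command: str) -> None:
    """Run a shell command and fail loudly on a non-zero exit code."""
    print(f"Running command: {command}")
    subprocess.run(command, shell=True, check=True)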