def main()

in train.py [0:0]


def main(args):
    """Set up Accelerate, logging and experiment tracking, then launch training.

    Builds a descriptive run name from the command-line hyper-parameters and uses
    it as the checkpoint/log directory under ``./model_checkpoints_accelerate``.
    That directory is stored on ``args.output_dir`` as a ``Path``, and control
    then passes to ``train_model``.

    Args:
        args: Parsed command-line namespace. Reads ``model_id``, ``dataset_id``,
            ``batch_size``, ``use_8bit_adam``, ``use_lora``, ``lr``,
            ``mixed_precision``, ``freeze_vision_tower``,
            ``gradient_accumulation_steps`` and ``report_to``.
            Side effect: ``args.output_dir`` is overwritten with the run directory.
    """
    # Keep only the last path component of each id ("org/name" -> "name") for the run name.
    model_slug = args.model_id.split("/")[-1]
    ds_slug = args.dataset_id.split("/")[-1]
    run_name = f"model@{model_slug}-ds@{ds_slug}-bs@{args.batch_size}-8bit@{args.use_8bit_adam}-lora@{args.use_lora}-lr@{args.lr}-mp@{args.mixed_precision}-fve@{args.freeze_vision_tower}"
    output_dir = Path("./model_checkpoints_accelerate") / run_name
    args.output_dir = output_dir

    accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=output_dir / "logs")
    # find_unused_parameters=True makes DDP tolerate trainable parameters that get no
    # gradient in a forward pass; which parameters need this is not visible here.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    # Library log levels: transformers at WARNING and diffusers at INFO on the local
    # main process; errors only on every other process.
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if accelerator.is_main_process:
        # output_dir was assigned unconditionally above, so no None check is needed.
        os.makedirs(args.output_dir, exist_ok=True)

        # Give trackers a copy of the config with Path values converted to str.
        # TensorBoard's add_hparams rejects non-primitive values such as Path.
        # Copying also stops trackers from holding a live reference to args.__dict__.
        # args.output_dir itself stays a Path for train_model.
        tracker_config = {
            key: str(value) if isinstance(value, Path) else value
            for key, value in vars(args).items()
        }
        tracker_name = "shot-categorizer"
        accelerator.init_trackers(tracker_name, config=tracker_config, init_kwargs={"wandb": {"name": run_name}})

    train_model(accelerator, args)