def add_sd_args_to_parser()

in assets/training/finetune_acft_image/src/finetune/finetune.py [0:0]


def add_sd_args_to_parser(parser):
    """Add Stable Diffusion related args to parser.

    Registers the data, tokenizer, text encoder, noise scheduler, loss,
    prior-preservation and validation options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument`` method) to extend in place.

    Returns:
        The same ``parser`` object, to allow call chaining.
    """

    def _bool_arg(arg_name):
        # Build the converter used for boolean flags. `str2bool` is looked up
        # lazily (only when a value is actually parsed); the argument name is
        # forwarded to it as its second positional argument.
        return lambda x: bool(str2bool(str(x), arg_name))

    # Data inputs
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default="class_data_dir",
        required=False,
        help="A folder containing the training data of class images."
    )

    # Instance prompt
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=False,
        help="The prompt with identifier specifying the instance"
    )

    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        required=False,
        help="The image resolution for training."
    )

    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        required=False,
        help="Batch size (per device) for sampling class images for prior preservation."
    )

    # Tokenizer
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default="openai/clip-vit-large-patch14",
        # The trailing comma matters: without it `choices` is a plain string and
        # argparse validates against its characters/substrings instead of the name.
        choices=("openai/clip-vit-large-patch14",),
        help="Pretrained tokenizer name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_max_length",
        type=int,
        default=None,
        required=False,
        help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length."
    )

    # Text Encoder:
    parser.add_argument(
        "--text_encoder_type",
        type=str,
        default="CLIPTextModel",
        choices=("CLIPTextModel", "T5EncoderModel"),
        help="Type of the text encoder to use."
    )
    parser.add_argument(
        "--text_encoder_name",
        type=str,
        required=False,
        help="Text Encoder or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_text_encoder",
        type=_bool_arg("train_text_encoder"),
        default=False,
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision."
    )
    parser.add_argument(
        "--pre_compute_text_embeddings",
        type=_bool_arg("pre_compute_text_embeddings"),
        default=True,
        help=(
            "Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, "
            "the text encoder will not be kept in memory during training and will leave more GPU memory "
            "available for training the rest of the model. This is not compatible with `--train_text_encoder`."
        )
    )
    parser.add_argument(
        "--text_encoder_use_attention_mask",
        type=_bool_arg("text_encoder_use_attention_mask"),
        default=False,
        required=False,
        help="Whether to use attention mask for the text encoder"
    )
    parser.add_argument(
        "--skip_save_text_encoder",
        type=_bool_arg("skip_save_text_encoder"),
        default=False,
        required=False,
        help="Set to not save text encoder"
    )

    # Residual noise prediction using UNET - decide whether to use timesteps as labels or None
    parser.add_argument(
        "--class_labels_conditioning",
        type=str,
        required=False,
        default=None,
        # `None` cannot be supplied on the command line (values are parsed as str);
        # it is listed only to mirror the default of "no conditioning".
        choices=("timesteps", None),
        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`."
    )

    # Noise Scheduler
    parser.add_argument(
        "--noise_scheduler_name",
        type=str,
        required=False,
        choices=("DPMSolverMultistepScheduler", "DDPMScheduler", "PNDMScheduler"),
        help="The noise scheduler name to use for the diffusion process."
    )
    parser.add_argument(
        "--noise_scheduler_num_train_timesteps",
        type=int,
        required=False,
        help="The number of diffusion steps to train the model."
    )
    parser.add_argument(
        "--noise_scheduler_variance_type",
        type=str,
        choices=("fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"),
        required=False,
        help="Clip the variance when adding noise to the denoised sample."
    )
    parser.add_argument(
        "--noise_scheduler_prediction_type",
        type=str,
        choices=("epsilon", "sample", "v_prediction"),
        required=False,
        help=(
            "Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion "
            "process), `sample` (directly predicts the noisy sample) or `v_prediction` "
            "(see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) paper)."
        )
    )
    parser.add_argument(
        "--noise_scheduler_timestep_spacing",
        type=str,
        required=False,
        help=(
            "The way the timesteps should be scaled. Refer to Table 2 of the "
            "[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) "
            "for more information."
        )
    )
    parser.add_argument(
        "--noise_scheduler_steps_offset",
        type=int,
        default=0,
        required=False,
        help=(
            "An offset added to the inference steps. You can use a combination of `offset=1` and "
            "`set_alpha_to_one=False` to make the last step use step 0 for the previous "
            "alpha product like in Stable Diffusion."
        ),
    )
    parser.add_argument(
        "--extra_noise_scheduler_args",
        type=str,
        required=False,
        help=(
            "Optional additional arguments that are supplied to noise scheduler. The arguments should be semi-colon "
            "separated key value pairs and should be enclosed in double quotes. "
            "For example, 'clip_sample_range=1.0; clip_sample=True' for DDPMScheduler."
        )
    )

    # Offset Noise (no default is set, so the parsed value is None when the flag is omitted)
    parser.add_argument(
        "--offset_noise",
        type=_bool_arg("offset_noise"),
        required=False,
        help=(
            "Fine-tuning against a modified noise."
            " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
        )
    )

    # Rebalance the loss
    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,
        help=(
            "SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
            "More details here: https://arxiv.org/abs/2303.09556."
        )
    )

    # Prior preservation loss
    parser.add_argument(
        "--with_prior_preservation",
        type=_bool_arg("with_prior_preservation"),
        default=True,
        help="Set to True for enabling prior preservation loss."
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        )
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default="fp32",
        choices=["fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to fp32."
        )
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )

    # Validation
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning."
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=0,
        help="Number of images that should be generated during validation with `instance prompt`."
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=100,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        )
    )
    parser.add_argument(
        "--validation_scheduler",
        type=str,
        default="DPMSolverMultistepScheduler",
        choices=("DPMSolverMultistepScheduler", "DDPMScheduler"),
        help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF."
    )

    return parser