def set_flash_attention()

in assets/training/finetune_acft_hf_nlp/src/finetune/finetune.py


def set_flash_attention(args: Namespace):
    """Set Flash Attention related parameters."""
    flash_attention_load_model_kwargs = {}
    if (
        hasattr(args, "model_type")
        and args.model_type in FORCE_FLASH_ATTENTION_2_MODEL_TYPES
    ):
        # only Ampere or newer GPU architectures support Flash Attention 2
        # Flash Attention 2 is supported with 16-bit, 8-bit and 4-bit precision
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and args.precision in [16, 8, 4]:
            # `use_flash_attention_2=True` will be deprecated, use `attn_implementation="flash_attention_2"`
            flash_attention_load_model_kwargs.update({"attn_implementation": "flash_attention_2"})
            setattr(args, "apply_flash_attention", True)
            setattr(args, "flash_attention_version", 2)
        # elif args.precision == 16:
        #     # Flash attention is supported with only 16-bit
        #     setattr(args, "apply_flash_attention", True)
        #     setattr(args, "flash_attention_version", 1)
        # else:
        #     # unable to use Flash attention as precision is not supported
        #     logger.warning(f"{args.precision}-bit precision is not supported for Flash attention.")
        #     logger.warning("Disabling Flash attention.")
        #     setattr(args, "apply_flash_attention", False)
        #     setattr(args, "flash_attention_version", -1)
        else:
            logger.warning("Flash Attention is not supported on current compute.")
            setattr(args, "apply_flash_attention", False)
            setattr(args, "flash_attention_version", -1)
        if args.flash_attention_version != -1:
            # Set 16-bit precision value in Quantization case for Flash Attention to work.
            # Currently will fail with error `RuntimeError: FlashAttention only support fp16 and bf16 data type`.
            # When fp16/bf16 is set the attention q,k,v layers are autocasted to respective precision from `uint8`.
            if (args.finetune_in_4bit or args.finetune_in_8bit) and not (args.fp16 or args.bf16):
                set_16bit_precision(args)
            # Flash Attention works only when the model is loaded in a supported precision (bf16/fp16)
            if args.bf16:
                flash_attention_load_model_kwargs.update({"torch_dtype": torch.bfloat16})
            elif args.fp16:
                flash_attention_load_model_kwargs.update({"torch_dtype": torch.float16})
            # update finetune_config to load model with flash_attention_2/torch_dtype
            args.finetune_config = deep_update(
                args.finetune_config,
                {
                    "load_model_kwargs": flash_attention_load_model_kwargs,
                }
            )
    else:
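        # model_type is missing or not in FORCE_FLASH_ATTENTION_2_MODEL_TYPES; leave Flash Attention disabled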
        setattr(args, "apply_flash_attention", False)
        setattr(args, "flash_attention_version", -1)
    # check the user-provided finetune_config for the (deprecated) `use_flash_attention_2` load kwarg
    user_load_model_kwargs = args.finetune_config.get("load_model_kwargs", {})
    if args.precision == 32 and user_load_model_kwargs.get("use_flash_attention_2", False) is True:
        # Flash attention is not supported with 32-bit precision
        logger.warning("Flash Attention is not supported with 32-bit precision.")
        raise ACFTValidationException._with_error(
            AzureMLError.create(
                ACFTUserError,
                pii_safe_message=(
                    "Flash Attention is not supported with 32-bit precision."
                )
            )
        )

    logger.info(f"enable Flash attention: {getattr(args, 'apply_flash_attention', None)}")
    logger.info(f"Using Flash Attention version: {getattr(args, 'flash_attention_version', None)}")
    logger.info(f"Flash Attention model load kwargs: {flash_attention_load_model_kwargs}")