in assets/training/finetune_acft_hf_nlp/src/finetune/finetune.py
def set_flash_attention(args: Namespace):
    """Set Flash Attention related parameters."""
    flash_attention_load_model_kwargs = {}
    if (
        hasattr(args, "model_type")
        and args.model_type in FORCE_FLASH_ATTENTION_2_MODEL_TYPES
    ):
        # Only Ampere or higher architectures support Flash Attention 2.
        # Flash Attention 2 is supported with 16-bit, 8-bit and 4-bit precision.
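        # `torch.cuda.is_bf16_supported()` is used below as a proxy for an Ampere-or-newer GPU,
        # since bf16 support generally implies compute capability >= 8.0.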
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and args.precision in [16, 8, 4]:
            # `use_flash_attention_2=True` will be deprecated; use `attn_implementation="flash_attention_2"`.
            flash_attention_load_model_kwargs.update({"attn_implementation": "flash_attention_2"})
            setattr(args, "apply_flash_attention", True)
            setattr(args, "flash_attention_version", 2)
        # elif args.precision == 16:
        #     # Flash attention is supported with only 16-bit
        #     setattr(args, "apply_flash_attention", True)
        #     setattr(args, "flash_attention_version", 1)
        # else:
        #     # unable to use Flash attention as precision is not supported
        #     logger.warning(f"{args.precision}-bit precision is not supported for Flash attention.")
        #     logger.warning("Disabling Flash attention.")
        #     setattr(args, "apply_flash_attention", False)
        #     setattr(args, "flash_attention_version", -1)
        else:
            logger.warning("Flash Attention is not supported on current compute.")
            setattr(args, "apply_flash_attention", False)
            setattr(args, "flash_attention_version", -1)
        if args.flash_attention_version != -1:
            # Set 16-bit precision in the quantization case for Flash Attention to work; otherwise model
            # loading fails with `RuntimeError: FlashAttention only support fp16 and bf16 data type`.
            # When fp16/bf16 is set, the attention q, k, v layers are autocast to that precision from `uint8`.
            if (args.finetune_in_4bit or args.finetune_in_8bit) and not (args.fp16 or args.bf16):
                set_16bit_precision(args)
            # Flash Attention is supported only when the model is loaded in a supported precision.
            if args.bf16:
                flash_attention_load_model_kwargs.update({"torch_dtype": torch.bfloat16})
            elif args.fp16:
                flash_attention_load_model_kwargs.update({"torch_dtype": torch.float16})
            # Update finetune_config to load the model with flash_attention_2 / torch_dtype.
            args.finetune_config = deep_update(
                args.finetune_config,
                {
                    "load_model_kwargs": flash_attention_load_model_kwargs,
                },
            )
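            # Note (sketch, assumption about downstream code): the merged `load_model_kwargs` are
            # expected to be forwarded to the Hugging Face `from_pretrained` call that loads the
            # model, roughly equivalent to:
            #     AutoModelForCausalLM.from_pretrained(
            #         model_name_or_path,
            #         attn_implementation="flash_attention_2",
            #         torch_dtype=torch.bfloat16,
            #     )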
    else:
        setattr(args, "apply_flash_attention", False)
        setattr(args, "flash_attention_version", -1)
        if args.precision == 32 and (
            args.finetune_config
            .get("load_model_kwargs", {})
            .get("use_flash_attention_2", False) is True
        ):
            # Flash Attention is not supported with 32-bit precision.
            logger.warning("Flash Attention is not supported with 32-bit precision.")
            raise ACFTValidationException._with_error(
                AzureMLError.create(
                    ACFTUserError,
                    pii_safe_message=(
                        "Flash Attention is not supported with 32-bit precision."
                    ),
                )
            )
logger.info(f"enable Flash attention: {getattr(args, 'apply_flash_attention', None)}")
logger.info(f"Using Flash Attention version: {getattr(args, 'flash_attention_version', None)}")
logger.info(f"Flash Attention model load kwargs: {flash_attention_load_model_kwargs}")