in optimum/tpu/fsdp_v2.py
def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
    """
    Returns the default FSDPv2 training arguments for a model of a known class.

    Args:
        model: The model to train with FSDPv2.

    Returns:
        A dictionary with the FSDPv2 training arguments.
    """
    # Resolve the underlying transformers model before inspecting its config.
    model = _unwrap_model(model)
    model_type = model.config.model_type
    matched_model = False

    # Pick the decoder layer class to wrap based on the model architecture.
    if model_type == "gemma":
        from transformers import GemmaForCausalLM as HFGemmaForCausalLLM

        from .modeling_gemma import GemmaForCausalLM

        if isinstance(model, (GemmaForCausalLM, HFGemmaForCausalLLM)):
            cls_to_wrap = "GemmaDecoderLayer"
            matched_model = True
    elif model_type == "llama":
        from transformers import LlamaForCausalLM as HFLlamaForCausalLLM

        from .modeling_llama import LlamaForCausalLM

        if isinstance(model, (LlamaForCausalLM, HFLlamaForCausalLLM)):
            cls_to_wrap = "LlamaDecoderLayer"
            matched_model = True

    # Unsupported models must provide their wrapping class explicitly via get_fsdp_config.
    if not matched_model:
        raise ValueError(f"Model {model} configuration cannot be auto-generated, use get_fsdp_config instead.")

    fsdp_training_args = {
        "fsdp": "full_shard",
        "fsdp_config": get_fsdp_config(cls_to_wrap),
    }
    return fsdp_training_args
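
A minimal usage sketch, for context: the dictionary returned above maps directly onto transformers.TrainingArguments fields ("fsdp" and "fsdp_config"), so it can be unpacked into the trainer setup. This sketch is not part of fsdp_v2.py; the model id, output directory, and batch size are illustrative assumptions, and use_fsdp_v2() is assumed to be exposed by this same module.

from transformers import AutoModelForCausalLM, TrainingArguments

from optimum.tpu import fsdp_v2

# Enable the SPMD-based FSDPv2 integration before building the trainer
# (assumed helper from the same module).
fsdp_v2.use_fsdp_v2()

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")  # illustrative model id

# For a Gemma model this resolves to:
# {"fsdp": "full_shard", "fsdp_config": get_fsdp_config("GemmaDecoderLayer")}
fsdp_args = fsdp_v2.get_fsdp_training_args(model)

training_args = TrainingArguments(
    output_dir="./output",          # illustrative value
    per_device_train_batch_size=1,  # illustrative value
    **fsdp_args,                    # "fsdp" and "fsdp_config" from above
)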