def get_fsdp_training_args(model)

in optimum/tpu/fsdp_v2.py


from typing import Dict

from transformers import PreTrainedModel

# `_unwrap_model` and `get_fsdp_config` are module-level helpers in fsdp_v2.py (not shown here).


def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
    """
    Returns the default FSDPv2 training arguments for a model of a known class.

    Args:
        model: The model to train with FSDPv2.

    Returns:
        A dictionary with the FSDPv2 training arguments.
    """
    model = _unwrap_model(model)
    model_type = model.config.model_type
    matched_model = False
    if model_type == "gemma":
        from transformers import GemmaForCausalLM as HFGemmaForCausalLLM

        from .modeling_gemma import GemmaForCausalLM

        if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
            cls_to_wrap = "GemmaDecoderLayer"
            matched_model = True
    elif model_type == "llama":
        # Accept both the Optimum TPU Llama implementation and the stock transformers one.
        from transformers import LlamaForCausalLM as HFLlamaForCausalLM

        from .modeling_llama import LlamaForCausalLM

        if isinstance(model, (LlamaForCausalLM, HFLlamaForCausalLM)):
            cls_to_wrap = "LlamaDecoderLayer"
            matched_model = True

    if not matched_model:
        raise ValueError(f"Model {model} configuration cannot be auto-generated, use get_fsdp_config instead.")

    fsdp_training_args = {
        "fsdp": "full_shard",  # shard parameters, gradients and optimizer state across devices
        "fsdp_config": get_fsdp_config(cls_to_wrap),  # FSDPv2 config wrapping the matched decoder layer class
    }
    return fsdp_training_args
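
A minimal usage sketch follows. The model id and training hyperparameters are illustrative only; the sketch assumes the standard transformers TrainingArguments API and that FSDPv2 is first enabled through the module's accompanying use_fsdp_v2() helper.

from transformers import AutoModelForCausalLM, TrainingArguments

from optimum.tpu import fsdp_v2

# Enable the SPMD-based FSDPv2 integration before instantiating the model.
fsdp_v2.use_fsdp_v2()

# Illustrative checkpoint; any supported Gemma or Llama model works here.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Auto-generate the FSDPv2 training arguments for the matched model class.
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    **fsdp_training_args,  # expands to fsdp="full_shard" plus the generated fsdp_config
)

# trainer = transformers.Trainer(model=model, args=training_args, train_dataset=...)
# trainer.train()

For model types not matched above, the raised error points to get_fsdp_config: the fsdp_config entry can be built manually by passing the decoder layer class name of the architecture, as the function itself does with cls_to_wrap.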