optimum/tpu/fsdp_v2.py (43 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions to provide FSDPv2 configuration for TPU training.
"""
from typing import Any, Dict, List, Union


PreTrainedModel = Any
# NOTE: instead of the above, modeling_utils.PreTrainedModel should be used, but since the usage is only for type
# hinting, it is not imported here, to avoid pulling in torch_xla imports.


def use_fsdp_v2():
    """
    Enable FSDPv2 for TPU training.

    ``torch_xla`` is imported lazily inside this function, so importing this module does not require it.
    """
    import torch_xla.runtime as xr

    # FSDPv2 requires SPMD to be enabled.
    xr.use_spmd()


def get_fsdp_config(*cls_to_wrap: Union[str, List[str]]) -> Dict[str, Any]:
    """
    Returns the FSDPv2 configuration for a given class to wrap.

    Args:
        cls_to_wrap: One or more class names to wrap with FSDPv2. Each argument may be a single class name or a
            list of class names; lists are flattened so the result is always a flat list of strings.

    Returns:
        A dictionary with the FSDPv2 configuration.
    """
    layer_classes: List[str] = []
    for entry in cls_to_wrap:
        if isinstance(entry, str):
            layer_classes.append(entry)
        else:
            # A list of names: flatten it, otherwise the wrap list would contain a nested list that can never
            # match a module class name.
            layer_classes.extend(entry)
    return {
        "transformer_layer_cls_to_wrap": layer_classes,
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }


def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel:
    """
    Unwraps the model from the PeftModel wrapper.

    Only a ``PeftModel`` whose ``base_model`` is a ``LoraModel`` is unwrapped; any other model is returned as is.

    Args:
        model: The model to unwrap.

    Returns:
        The unwrapped model.
    """
    try:
        from peft.peft_model import LoraModel, PeftModel
    except ImportError:
        # peft is optional: without it the model cannot be a PeftModel, so there is nothing to unwrap.
        return model

    if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel):
        return model.base_model.model
    return model


def get_fsdp_training_args(model: PreTrainedModel) -> Dict[str, Any]:
    """
    Returns the default FSDPv2 training arguments for a model of a known class.

    Supported models are the Gemma and Llama causal LM classes from either ``transformers`` or this package
    (possibly wrapped in a LoRA ``PeftModel``).

    Args:
        model: The model to train with FSDPv2.

    Returns:
        A dictionary with the FSDPv2 training arguments.

    Raises:
        ValueError: If the model is not of a supported class.
    """
    model = _unwrap_model(model)
    model_type = model.config.model_type

    # Model-specific imports are done lazily so that only the relevant modeling code is loaded.
    cls_to_wrap = None
    if model_type == "gemma":
        from transformers import GemmaForCausalLM as HFGemmaForCausalLM

        from .modeling_gemma import GemmaForCausalLM

        if isinstance(model, (GemmaForCausalLM, HFGemmaForCausalLM)):
            cls_to_wrap = "GemmaDecoderLayer"
    elif model_type == "llama":
        from transformers import LlamaForCausalLM as HFLlamaForCausalLM

        from .modeling_llama import LlamaForCausalLM

        if isinstance(model, (LlamaForCausalLM, HFLlamaForCausalLM)):
            cls_to_wrap = "LlamaDecoderLayer"

    if cls_to_wrap is None:
        raise ValueError(f"Model {model} configuration cannot be auto-generated, use get_fsdp_config instead.")

    fsdp_training_args = {
        "fsdp": "full_shard",
        "fsdp_config": get_fsdp_config(cls_to_wrap),
    }
    return fsdp_training_args