# optimum/tpu/fsdp_v2.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions to provide FSDPv2 configuration for TPU training.
"""

from typing import Any, Dict, List, Union


PreTrainedModel = Any
# NOTE: instead of the alias above, modeling_utils.PreTrainedModel should be used, but since it is only needed
# for type hinting, it is not imported here, to avoid pulling in torch_xla imports.


def use_fsdp_v2():
"""
Enable FSDPv2 for TPU training.
"""
import torch_xla.runtime as xr
# FSDPv2 requires SPMD to be enabled.
xr.use_spmd()
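
# Illustrative usage sketch (comment only, not executed): FSDPv2 builds on SPMD, so `use_fsdp_v2()` should be
# called once at the start of a training script, before the model and the Trainer are created. This assumes the
# module is importable as `optimum.tpu.fsdp_v2` on a host with torch_xla installed.
#
#   from optimum.tpu import fsdp_v2
#
#   fsdp_v2.use_fsdp_v2()
#   # ... build the model, dataset and Trainer afterwards ...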


def get_fsdp_config(*cls_to_wrap: Union[str, List[str]]) -> Dict:
    """
    Returns the FSDPv2 configuration for the given class(es) to wrap.

    Args:
        cls_to_wrap: One or more class names to wrap with FSDPv2.

    Returns:
        A dictionary with the FSDPv2 configuration.
    """
return {
"transformer_layer_cls_to_wrap": [*cls_to_wrap],
"xla": True,
"xla_fsdp_v2": True,
"xla_fsdp_grad_ckpt": True,
}
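
# Illustrative sketch (comment only): the returned dictionary is intended to be passed as `fsdp_config` to
# `transformers.TrainingArguments`, together with `fsdp="full_shard"`. "GemmaDecoderLayer" below is just an
# example class name, borrowed from get_fsdp_training_args; the output directory is a placeholder.
#
#   from transformers import TrainingArguments
#
#   args = TrainingArguments(
#       output_dir="./output",
#       fsdp="full_shard",
#       fsdp_config=get_fsdp_config("GemmaDecoderLayer"),
#   )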


def _unwrap_model(model: PreTrainedModel) -> PreTrainedModel:
    """
    Unwraps the model from the PeftModel wrapper.

    Args:
        model: The model to unwrap.

    Returns:
        The unwrapped model.
    """
try:
from peft.peft_model import LoraModel, PeftModel
if isinstance(model, PeftModel) and isinstance(model.base_model, LoraModel):
return model.base_model.model
return model
except ImportError:
return model
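
# Illustrative sketch (comment only; assumes peft is installed and `base_model` is an already-loaded transformers
# model): a LoRA model produced by `peft.get_peft_model` is a PeftModel wrapping a LoraModel, so `_unwrap_model`
# returns the original transformers model, which is what get_fsdp_training_args inspects below.
#
#   from peft import LoraConfig, get_peft_model
#
#   peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))
#   assert _unwrap_model(peft_model) is base_model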


def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
    """
    Returns the default FSDPv2 training arguments for a model of a known class.

    Args:
        model: The model to train with FSDPv2.

    Returns:
        A dictionary with the FSDPv2 training arguments.
    """
model = _unwrap_model(model)
model_type = model.config.model_type
matched_model = False
if model_type == "gemma":
from transformers import GemmaForCausalLM as HFGemmaForCausalLLM
from .modeling_gemma import GemmaForCausalLM
if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
cls_to_wrap = "GemmaDecoderLayer"
matched_model = True
elif model_type == "llama":
from transformers import LlamaForCausalLM as HFLlamaForCausalLLM
from .modeling_llama import LlamaForCausalLM
if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM):
cls_to_wrap = "LlamaDecoderLayer"
matched_model = True
if not matched_model:
raise ValueError(f"Model {model} configuration cannot be auto-generated, use get_fsdp_config instead.")
fsdp_training_args = {
"fsdp": "full_shard",
"fsdp_config": get_fsdp_config(cls_to_wrap),
}
return fsdp_training_args
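
# Illustrative end-to-end sketch (comment only; assumes a TPU host with optimum-tpu, torch_xla and transformers
# installed; the model id, output directory and `my_dataset` are placeholders):
#
#   from optimum.tpu import fsdp_v2
#
#   fsdp_v2.use_fsdp_v2()  # enable SPMD before instantiating the model
#
#   from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
#
#   model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
#   fsdp_args = fsdp_v2.get_fsdp_training_args(model)  # {"fsdp": "full_shard", "fsdp_config": {...}}
#   trainer = Trainer(
#       model=model,
#       args=TrainingArguments(output_dir="./output", **fsdp_args),
#       train_dataset=my_dataset,
#   )
#   trainer.train()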