optimum/neuron/accelerate/utils/misc.py (121 lines of code) (raw):

# coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities of various sorts related to accelerate with Neuron.""" import functools import gc import inspect import itertools from typing import TYPE_CHECKING, Callable, Dict, Optional, Union import torch from ....utils import logging from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere from ...utils.patching import Patcher from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla logger = logging.get_logger(__name__) if TYPE_CHECKING: import os # Dummy class to avoid import errors in type checking. class NeuronPeftModel: def __init__(self, *args, **kwargs): pass from transformers import PreTrainedModel if is_torch_neuronx_available(): from neuronx_distributed.pipeline import NxDPPModel def patched_accelerate_is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): """ Fake `is_tpu_available` that returns `is_torch_xla_available` to patch `accelerate`. """ return is_torch_xla_available() def patch_accelerate_is_torch_xla_available(): if is_torch_xla_available(): import accelerate import torch_xla.core.xla_model as xm # Since `is_torch_xla_available` does not work properly for us, it does not import `xm`, which causes failure. # We set it manually. accelerate.accelerator.xm = xm accelerate.state.xm = xm accelerate.checkpointing.xm = xm patch_everywhere( "is_torch_xla_available", patched_accelerate_is_torch_xla_available, module_name_prefix="accelerate" ) _ORIG_TORCH_FINFO = torch.finfo @requires_neuronx_distributed @requires_safetensors def torch_xla_safe_save_file( tensors: Dict[str, torch.Tensor], filename: Union[str, "os.PathLike"], metadata: Optional[Dict[str, str]] = None, master_only: bool = True, global_master: bool = False, ): """ Torch XLA compatible implementation of `safetensors.torch.save_file`. """ from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu from safetensors.torch import save_file from torch_xla.core.xla_model import is_master_ordinal should_write_data = not master_only or is_master_ordinal(local=not global_master) cpu_data = move_all_tensor_to_cpu(tensors, convert=should_write_data) if should_write_data: save_file(cpu_data, filename, metadata=metadata) @requires_neuronx_distributed def create_patched_save_pretrained(orig_save_pretrained_function: Callable[["PreTrainedModel"], None]): """ Creates a wrapper around the `transformers.modeling_utils.PreTrainedModel.save_pretrained` method. This methods calls `tensor.data_ptr()` on the model parameters, which causes segmentation fault when the tensors are on the XLA device. To prevent that, this wrapper calls `save_pretrained` with the model on the CPU device. """ import torch_xla.core.xla_model as xm from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_rank, model_parallel_is_initialized, ) from neuronx_distributed.parallel_layers.utils import move_all_tensor_to_cpu orig_self = orig_save_pretrained_function.__self__ orig_func = orig_save_pretrained_function.__func__ patcher = Patcher([("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]) @functools.wraps(orig_func) def wrapper(*args, **kwargs): self = args[0] if model_parallel_is_initialized(): should_write_data = get_data_parallel_rank() == 0 else: should_write_data = xm.is_master_ordinal(local=True) orig_state_dict = self.state_dict() cpu_state_dict = move_all_tensor_to_cpu(self.state_dict(), convert=should_write_data) self.load_state_dict(cpu_state_dict, assign=True) output = None if should_write_data: with patcher: output = orig_func(*args, **kwargs) self.load_state_dict(orig_state_dict, assign=True) xm.mark_step() del cpu_state_dict gc.collect() return output return wrapper.__get__(orig_self) # TODO: @michaelbenayoun # Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`. @requires_torch_xla def patched_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): from torch_xla.utils.checkpoint import checkpoint if not self.supports_gradient_checkpointing: raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") if gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {"use_reentrant": True} gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) # For old GC format (transformers < 4.35.0) for models that live on the Hub # we will fall back to the overwritten `_set_gradient_checkpointing` method _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters if not _is_using_old_format: self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) else: self.apply(functools.partial(self._set_gradient_checkpointing, value=True)) logger.warning( "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." ) if getattr(self, "_hf_peft_config_loaded", False): # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate # the gradients to make sure the gradient flows. self.enable_input_require_grads() @requires_neuronx_distributed def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel", "NeuronPeftModel"]): from neuronx_distributed.pipeline import NxDPPModel from neuronx_distributed.utils.activation_checkpoint import ( apply_activation_checkpointing as nxd_apply_activation_checkpointing, ) from ...peft.peft_model import NeuronPeftModel if isinstance(model, NeuronPeftModel): model._prepare_model_for_gradient_checkpointing(model.get_base_model()) if isinstance(model, NxDPPModel): modules = itertools.chain(module.modules() for module in model.local_stage_modules) else: modules = model.modules() gradient_checkpointing_modules = set() for module in modules: if isinstance(module, torch.nn.ModuleList): for mod in module: # TODO: @michaelbenayoun. Need to find a better way to identify the blocks to apply gradient # checkpointing to. if "Layer" in mod.__class__.__name__ or "Block" in mod.__class__.__name__: gradient_checkpointing_modules.add(mod) def check_fn(m: torch.nn.Module) -> bool: return m in gradient_checkpointing_modules if gradient_checkpointing_modules: nxd_apply_activation_checkpointing(model, check_fn=check_fn)