# optimum/habana/transformers/trainer.py
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Gaudi Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
"""
import contextlib
import copy
import functools
import inspect
import json
import math
import os
import random
import shutil
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
from accelerate import DistributedType, skip_first_batches
from accelerate.data_loader import SeedableRandomSampler
from accelerate.utils import (
DistributedDataParallelKwargs,
TorchTensorParallelPlugin,
load_fsdp_model,
load_fsdp_optimizer,
save_fsdp_model,
save_fsdp_optimizer,
)
from huggingface_hub import upload_folder
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import (
deepspeed_load_checkpoint,
is_deepspeed_available,
is_deepspeed_zero3_enabled,
)
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import _get_fsdp_ckpt_kwargs, _is_peft_model, safe_globals
from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerState
from transformers.trainer_pt_utils import (
DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard,
LengthGroupedSampler,
SequentialDistributedSampler,
find_batch_size,
get_model_param_count,
nested_concat,
nested_detach,
reissue_pt_warnings,
remove_dummy_checkpoint,
)
from transformers.trainer_utils import (
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
HubStrategy,
PredictionOutput,
SaveStrategy,
TrainOutput,
denumpify_detensorize,
enable_full_determinism,
find_executable_batch_size,
get_last_checkpoint,
has_length,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
PushInProgress,
is_accelerate_available,
is_datasets_available,
is_peft_available,
is_safetensors_available,
)
from transformers.utils.deprecation import deprecate_kwarg
from optimum.utils import logging
from ..accelerate import GaudiAccelerator
from ..accelerate.utils import FP8ContextWrapper
from ..utils import (
HabanaGenerationTime,
HabanaProfile,
get_hpu_memory_stats,
set_seed,
speed_metrics,
to_device_dtype,
)
from .gaudi_configuration import GAUDI_CONFIG_NAME, GaudiConfig
from .integrations.deepspeed import deepspeed_init
from .trainer_utils import convert_into_dtypes, get_dtype
from .training_args import GaudiTrainingArguments
if is_datasets_available():
import datasets
if is_safetensors_available():
import safetensors.torch
if is_peft_available():
from peft import PeftModel
from peft.utils import PeftType
if is_deepspeed_available():
from accelerate.utils import DeepSpeedSchedulerWrapper
from accelerate.utils import DataLoaderConfiguration, is_torch_version
def _get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]:
"""
Determines whether the input settings need to be updated.
Currently, `attn_softmax_bf16`, `use_flash_attention`, `flash_attention_recompute` and
`flash_attention_causal_mask` are enabled only for llama, qwen2, starcoder2, gemma, baichuan,
chatglm and deepseek_v2, while `lazy_mode` is enabled only for llama, qwen2, starcoder2 and mistral.
Args:
model: The model instance for which the input update settings are being evaluated
lazy_mode[Optional[bool]]: Whether to use lazy mode for the model (defaults to `None`)
Returns:
Tuple[bool, Dict]: A flag indicating whether the input settings should be updated.
A dictionary containing the specific input settings that need to be updated, if any
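Example:
A minimal sketch of how this helper is typically consumed (here `model` is any loaded
🤗 Transformers model and `inputs` a batch dictionary; both are assumed to already exist):
>>> should_update, updates = _get_input_update_settings(model, lazy_mode=True)
>>> if should_update:
...     inputs.update(updates)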
"""
inputs_update: Dict = {}
should_update_inputs = (getattr(model, "generation_config", None) is not None) and (
model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm", "deepseek_v2")
)
if should_update_inputs:
if model.generation_config.attn_softmax_bf16:
inputs_update["attn_softmax_bf16"] = True
if model.generation_config.use_flash_attention:
inputs_update["use_flash_attention"] = True
if model.generation_config.flash_attention_recompute:
inputs_update["flash_attention_recompute"] = True
if model.generation_config.flash_attention_causal_mask:
inputs_update["flash_attention_causal_mask"] = True
should_update_inputs = (
(getattr(model, "generation_config", None) is not None)
and (model.config.model_type in ("llama", "qwen2", "starcoder2", "mistral"))
and (lazy_mode is not None)
)
if should_update_inputs:
if _is_peft_model(model):
forward_method = getattr(model.get_base_model(), "forward")
else:
forward_method = getattr(model, "forward")
signature = inspect.signature(forward_method)
if "lazy_mode" in signature.parameters:
inputs_update["lazy_mode"] = lazy_mode
should_update_inputs: bool = len(inputs_update) > 0
return should_update_inputs, inputs_update
if TYPE_CHECKING:
import optuna
DATA_SAMPLERS = [RandomSampler, SeedableRandomSampler]
logger = logging.get_logger(__name__)
# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCALER_NAME = "scaler.pt"
SCHEDULER_NAME = "scheduler.pt"
class GaudiTrainer(Trainer):
"""
GaudiTrainer is built on top of the 🤗 Transformers `Trainer` to enable
deployment on Habana's Gaudi devices.
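Example:
A minimal sketch, not a full recipe; `model`, `train_dataset` and `eval_dataset` are assumed
to be already defined, and the Gaudi configuration name below is only illustrative:
>>> from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
>>> training_args = GaudiTrainingArguments(
...     output_dir="./output",
...     use_habana=True,
...     use_lazy_mode=True,
...     gaudi_config_name="Habana/bert-base-uncased",  # illustrative Gaudi config repo
... )
>>> trainer = GaudiTrainer(
...     model=model,
...     args=training_args,
...     train_dataset=train_dataset,
...     eval_dataset=eval_dataset,
... )
>>> trainer.train()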
"""
@deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
def __init__(
self,
model: Union[PreTrainedModel, torch.nn.Module, None] = None,
gaudi_config: Optional[GaudiConfig] = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
if args is None:
output_dir = "tmp_trainer"
logger.info(f"No `GaudiTrainingArguments` passed, using `output_dir={output_dir}`.")
args = GaudiTrainingArguments(output_dir=output_dir)
self.use_hpu_amp = False
self.use_cpu_amp = False
if args.bf16 and not args.deepspeed:
if args.half_precision_backend == "hpu_amp":
self.use_hpu_amp = True
else:
self.use_cpu_amp = True
# Workaround to not set amp backend again when calling super().__init__(...)
# args.bf16 is not used after the __init__ anyway
args.bf16 = False
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
processing_class,
model_init,
compute_loss_func,
compute_metrics,
callbacks,
optimizers,
optimizer_cls_and_kwargs,
preprocess_logits_for_metrics,
)
if gaudi_config is None:
self.gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name)
else:
self.gaudi_config = copy.deepcopy(gaudi_config)
if self.args.use_habana:
if self.args.use_hpu_graphs_for_inference:
self.already_wrapped_for_hpu_graphs = False
if self.args.deepspeed:
# Mixed-precision backends are turned off when using DeepSpeed since it manages this itself
self.gaudi_config.use_torch_autocast = False
self.use_hpu_amp = False
if self.args.deepspeed or self.args.dataloader_num_workers >= 1:
# To avoid warnings about parallelism in tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if self.gaudi_config.use_torch_autocast:
if not self.use_hpu_amp and not self.use_cpu_amp:
self.use_hpu_amp = True
logger.warning(
"The argument `--bf16` was not given but `use_torch_autocast` is True in the Gaudi configuration so mixed-precision training with Torch Autocast is enabled."
)
if self.use_hpu_amp and "PT_HPU_AUTOCAST_LOWER_PRECISION_OPS_LIST" not in os.environ:
self.gaudi_config.declare_autocast_bf16_fp32_ops()
if self.args.use_lazy_mode:
try:
import habana_frameworks.torch.core as htcore
except ImportError as error:
error.msg = f"Could not import habana_frameworks.torch.core. {error.msg}."
raise error
self.htcore = htcore
try:
import habana_frameworks.torch.hpu as hthpu
except ImportError as error:
error.msg = f"Could not import habana_frameworks.torch.hpu. {error.msg}."
raise error
if self.gaudi_config.use_dynamic_shapes:
hthpu.enable_dynamic_shape()
else:
hthpu.disable_dynamic_shape()
try:
from habana_frameworks.torch.hpu import random as hpu_random
except ImportError as error:
error.msg = f"Could not import habana_frameworks.torch.hpu.random. {error.msg}."
raise error
self.hpu_random = hpu_random
# Set the correct log level depending on the node
# Already done in super().init() but we have to do it again
# because we use optimum.utils.logging here and not
# transformers.utils.logging
log_level = args.get_process_log_level()
logging.set_verbosity(log_level)
logging.enable_default_handler()
logging.enable_explicit_format()
# Suppress PyTorch autocast warnings with Wav2Vec2
# This is a bug in PyTorch
warnings.filterwarnings(
"ignore", message="User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
def _move_model_to_device(self, model, device):
model = model.to(device)
# Moving a model to HPU disconnects the tied weights, so we have to retie them.
if self.args.use_habana and hasattr(model, "tie_weights"):
model.tie_weights()
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
lengths = (
self.train_dataset[self.args.length_column_name]
if self.args.length_column_name in self.train_dataset.column_names
else None
)
else:
lengths = None
model_input_name = (
self.processing_class.model_input_names[0] if self.processing_class is not None else None
)
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)
else:
num_samples = len(self.train_dataset)
if (
not self.args.dataloader_drop_last
and num_samples % self.args.per_device_train_batch_size != 0
and self.args.parallel_mode != ParallelMode.DISTRIBUTED
):
# Make the total number of samples divisible by the batch size in lazy mode if needed
num_samples += (
self.args.per_device_train_batch_size - num_samples % self.args.per_device_train_batch_size
)
return RandomSampler(self.train_dataset, num_samples=num_samples)
def create_optimizer(self):
"""
Set up the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
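Example:
A hedged sketch of supplying a custom optimizer/scheduler pair instead of the default;
`model`, `gaudi_config`, `training_args` and `train_dataset` are assumed to be already defined:
>>> import torch
>>> optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
>>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
>>> trainer = GaudiTrainer(
...     model=model,
...     gaudi_config=gaudi_config,
...     args=training_args,
...     train_dataset=train_dataset,
...     optimizers=(optimizer, scheduler),
... )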
"""
if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(self.model)
optimizer_grouped_parameters = []
for t_params, t_weight_decay in zip(
[
[p for n, p in self.model.named_parameters() if n in decay_parameters and p.requires_grad],
[p for n, p in self.model.named_parameters() if n not in decay_parameters and p.requires_grad],
],
[self.args.weight_decay, 0.0],
):
# Empty groups of parameters are filtered because they make FusedAdamW crash
if t_params:
optimizer_grouped_parameters.append(
{
"params": t_params,
"weight_decay": t_weight_decay,
}
)
if self.gaudi_config.use_fused_adam and self.args.use_habana:
try:
from habana_frameworks.torch.hpex.optimizers import FusedAdamW
except ImportError as error:
error.msg = (
f"Could not import 'FusedAdamW' from 'habana_frameworks.torch.hpex.optimizers'. {error.msg}."
)
raise error
optimizer_cls = FusedAdamW
optimizer_kwargs = {
"lr": self.args.learning_rate,
"betas": (self.args.adam_beta1, self.args.adam_beta2),
"eps": self.args.adam_epsilon,
}
elif self.optimizer_cls_and_kwargs is not None:
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
else:
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, self.model)
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for GaLore optimizer.
if "params" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
# e.g. for LOMO optimizer.
if "model" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
# to avoid arguments conflicts.
if "optimizer_dict" in optimizer_kwargs:
optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return self.optimizer
def _tune_save_checkpoint(self, checkpoint_dir: str):
output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
self.save_model(output_dir, _internal_call=True)
if self.args.should_save:
# TODO
# Update the `TrainerControl` state to where we are currently
# self.state.stateful_callbacks["TrainerControl"] = self.control.state()
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
def _wrap_model(self, model, training=True, dataloader=None):
# train/eval could be run multiple-times - if already wrapped, don't re-wrap it again
if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model:
return model
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
if not training:
return model
if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "ddp":
kwargs = {}
if self.args.ddp_find_unused_parameters is not None:
kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
if self.args.ddp_find_unused_parameters and self.args.gradient_checkpointing:
logger.warning(
"ddp_find_unused_parameters and gradient_checkpointing are both True, which may lead to an error:"
" https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021"
)
elif isinstance(model, PreTrainedModel):
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
else:
kwargs["find_unused_parameters"] = True
if self.args.ddp_bucket_cap_mb is not None:
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
if self.args.use_habana:
kwargs["gradient_as_bucket_view"] = True
if self.args.ddp_broadcast_buffers is not None:
kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
if self.args.use_hpu_graphs_for_training:
import habana_frameworks.torch as ht
ht.hpu.ModuleCacher()(model=model, inplace=True)
return model
def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial: Union["optuna.Trial", dict[str, Any], None] = None,
ignore_keys_for_eval: Optional[list[str]] = None,
**kwargs,
):
"""
Main training entry point.
Args:
resume_from_checkpoint (`str` or `bool`, *optional*):
If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
`bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
The trial run or the hyperparameter dictionary for hyperparameter search.
ignore_keys_for_eval (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions for evaluation during the training.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments used to hide deprecated arguments
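Example:
A sketch; `trainer` is an already instantiated `GaudiTrainer` and the checkpoint path is illustrative:
>>> trainer.train()  # train from scratch
>>> trainer.train(resume_from_checkpoint=True)  # resume from the last checkpoint in `args.output_dir`
>>> trainer.train(resume_from_checkpoint="path/to/checkpoint-500")  # resume from a specific checkpoint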
"""
if resume_from_checkpoint is False:
resume_from_checkpoint = None
# memory metrics - must set up as early as possible
self._memory_tracker.start()
args = self.args
self.is_in_train = True
# Attach NEFTune hooks if necessary
if self.neftune_noise_alpha is not None:
self.model = self._activate_neftune(self.model)
# do_train is not a reliable argument, as it might not be set and .train() still called, so
# the following is a workaround:
if args.bf16_full_eval and not args.do_train and not self.is_model_parallel and self.model_init is None:
self._move_model_to_device(self.model, args.device)
if "model_path" in kwargs:
resume_from_checkpoint = kwargs.pop("model_path")
warnings.warn(
(
"`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
"instead."
),
FutureWarning,
)
if len(kwargs) > 0:
raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
# This might change the seed so needs to run first.
self._hp_search_setup(trial)
self._train_batch_size = self.args.train_batch_size
# Model re-init
model_reloaded = False
if self.model_init is not None:
# Seed must be set before instantiating the model when using model_init.
if self.args.full_determinism:
enable_full_determinism(self.args.seed)
else:
set_seed(self.args.seed)
self.model = self.call_model_init(trial)
model_reloaded = True
# Reinitializes optimizer and scheduler
self.optimizer, self.lr_scheduler = None, None
# Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
if resume_from_checkpoint is not None:
if not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint)
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
if state.train_batch_size is not None:
self._train_batch_size = state.train_batch_size
# If model was re-initialized, put it on the right device and update self.model_wrapped
if model_reloaded:
if self.place_model_on_device:
self._move_model_to_device(self.model, args.device)
self.model_wrapped = self.model
inner_training_loop = find_executable_batch_size(
self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
)
if args.push_to_hub:
try:
# Disable progress bars when uploading models during checkpoints to avoid polluting stdout
hf_hub_utils.disable_progress_bars()
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
finally:
hf_hub_utils.enable_progress_bars()
else:
return inner_training_loop(
args=args,
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
)
def _inner_training_loop(
self,
batch_size=None,
args=None,
resume_from_checkpoint=None,
trial=None,
ignore_keys_for_eval=None,
):
self.accelerator.free_memory()
self._train_batch_size = batch_size
if self.args.auto_find_batch_size:
if self.state.train_batch_size != self._train_batch_size:
from accelerate.utils import release_memory
(self.model_wrapped,) = release_memory(self.model_wrapped)
self.model_wrapped = self.model
# Check for DeepSpeed *after* the initial pass and modify the config
if self.is_deepspeed_enabled:
# Temporarily unset `self.args.train_batch_size`
original_bs = self.args.per_device_train_batch_size
self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
self.propagate_args_to_deepspeed(True)
self.args.per_device_train_batch_size = original_bs
self.state.train_batch_size = self._train_batch_size
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()
# Setting up training control variables:
# number of training epochs: num_train_epochs
# number of training steps per epoch: num_update_steps_per_epoch
# total number of training steps to execute: max_steps
total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
(
num_train_epochs,
num_update_steps_per_epoch,
num_examples,
num_train_samples,
epoch_based,
len_dataloader,
max_steps,
) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
if (
self.accelerator.mpu.sequence_parallel_is_initialized()
and self.accelerator.mpu.get_sequence_parallel_world_size() > 1
):
total_train_batch_size = total_train_batch_size / self.accelerator.mpu.get_sequence_parallel_world_size()
num_train_tokens = None
if self.args.include_tokens_per_second:
num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
# If going by epochs, multiply tokens linearly
if len_dataloader is not None and epoch_based:
num_train_tokens *= args.num_train_epochs
# Otherwise, since it's steps, we just multiply by grad accum
else:
num_train_tokens *= args.gradient_accumulation_steps
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
delay_optimizer_creation = self.is_fsdp_enabled
# Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
if is_fsdp2:
delay_optimizer_creation = False
# We need to reset the scheduler, as its parameters may be different on subsequent calls
if self._created_lr_scheduler:
self.lr_scheduler = None
self._created_lr_scheduler = False
if self.is_deepspeed_enabled:
self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
if not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState(
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]
)
self.state.is_hyper_param_search = trial is not None
self.state.train_batch_size = self._train_batch_size
# Compute absolute values for logging, eval, and save if given as ratio
self.state.compute_steps(args, max_steps)
# Activate gradient checkpointing if needed
if args.gradient_checkpointing:
import transformers.modeling_utils
if args.deepspeed:
from deepspeed.runtime.activation_checkpointing.checkpointing import (
CheckpointFunction,
non_reentrant_checkpoint,
)
# HACK because outputs should always be tuples
def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optional[bool] = None):
"""DeepSpeed activation checkpointing."""
if use_reentrant is None:
use_reentrant = True
if use_reentrant:
all_outputs = []
CheckpointFunction.apply(function, all_outputs, *checkpoint_args)
else:
logger.info("DeepSpeed activation checkpointing=non_reentrant_checkpoint")
all_outputs = non_reentrant_checkpoint(function, *checkpoint_args)
# Always return a tuple
# When all_outputs contains only one element, DeepSpeed returns this element instead of a tuple
# which is not consistent with some models. See https://github.com/microsoft/DeepSpeed/issues/1057.
return tuple(all_outputs)
torch.utils.checkpoint.checkpoint = hpu_deepspeed_checkpointing
transformers.modeling_utils.checkpoint = hpu_deepspeed_checkpointing
elif args.use_lazy_mode:
from .gradient_checkpointing import checkpoint as lazy_mode_checkpointing
torch.utils.checkpoint.checkpoint = lazy_mode_checkpointing
transformers.modeling_utils.checkpoint = lazy_mode_checkpointing
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
# Wrap `_gradient_checkpointing_func` in the model with `transformer_engine` `activation_checkpointing` context.
if self.accelerator.fp8_enabled:
FP8ContextWrapper.gradient_checkpointing_wrap(self.model)
else:
# Hack because `RegressionModel` in test_trainer.py doesn't have `gradient_checkpointing_disable`
if hasattr(self.model, "gradient_checkpointing_disable"):
self.model.gradient_checkpointing_disable()
model = self._wrap_model(self.model_wrapped)
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = model is self.model
if use_accelerator_prepare and self.is_fsdp_enabled:
# In case of auto_find_batch_size=True
# Remove FSDP wrapping from sub-models.
self.model = unwrap_model(self.model, recursive=True)
if delay_optimizer_creation:
if use_accelerator_prepare:
# configure fsdp plugin for qlora if any
self._fsdp_qlora_plugin_updates()
if self.accelerator.mixed_precision != "fp8":
self.model = self.accelerator.prepare(self.model)
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# prepare using `accelerator` prepare
if use_accelerator_prepare:
self.model.train()
if hasattr(self.lr_scheduler, "step"):
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
else:
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
# In this case we are in DDP + LOMO, which should be supported
self.optimizer = self.accelerator.prepare(self.optimizer)
if self.is_fsdp_enabled:
self.model = self.model_wrapped = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
# ckpt loading
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(
self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
)
elif self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
self._load_scaler(resume_from_checkpoint)
if self.gaudi_config.use_fused_clip_norm and self.args.use_habana:
try:
from habana_frameworks.torch.hpex.normalization import FusedClipNorm
except ImportError as error:
error.msg = f"Could not import habana_frameworks.torch.hpex.normalization. {error.msg}."
raise error
self.FusedNorm = FusedClipNorm(model.parameters(), args.max_grad_norm)
else:
self.FusedNorm = None
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
# FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {num_examples:,}")
logger.info(f" Num Epochs = {num_train_epochs:,}")
logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
if self.args.per_device_train_batch_size != self._train_batch_size:
logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps:,}")
logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
self.state.epoch = 0
start_time = time.time()
start_time_after_warmup = None
epochs_trained = 0
steps_trained_in_current_epoch = 0
steps_trained_progress_bar = None
# Check if continuing training from a checkpoint
if resume_from_checkpoint is not None and os.path.isfile(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
self._load_callback_state()
epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
steps_trained_in_current_epoch *= args.gradient_accumulation_steps
else:
steps_trained_in_current_epoch = 0
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(f" Continuing training from epoch {epochs_trained}")
logger.info(f" Continuing training from global step {self.state.global_step}")
if not args.ignore_data_skip:
logger.info(
f" Will skip the first {epochs_trained} epochs then the first"
f" {steps_trained_in_current_epoch} batches in the first epoch."
)
# In multi-worker training: broadcast model parameters from worker:0 to all the others.
# This must be done manually unless DistributedDataParallel is used.
if self.args.parallel_mode == ParallelMode.DISTRIBUTED and self.args.distribution_strategy == "fast_ddp":
from ..distributed import all_reduce_gradients
logger.debug(
f"Broadcasting the model parameters to assure that each of {self.args.world_size} workers start the training from the same point."
)
for param in model.parameters():
torch.distributed.broadcast(param.data, src=0)
# Update the references
for attr in ("model", "optimizer", "lr_scheduler"):
setattr(self.callback_handler, attr, getattr(self, attr))
self.callback_handler.train_dataloader = train_dataloader
self.state.init_training_references(self, max_steps, num_train_epochs, trial)
# tr_loss is a tensor to avoid synchronization of TPUs through .item()
tr_loss = torch.tensor(0.0, device=args.device)
# _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step
self._zero_model_grad(model)
# Gradient clipping
grad_norm: Optional[float] = None
_should_compute_grad_norm: bool = self.accelerator.distributed_type != DistributedType.DEEPSPEED and (
args.max_grad_norm is not None and args.max_grad_norm > 0
)
learning_rate = None
# attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma, baichuan, chatglm and deepseek_v2
# lazy_mode for llama, qwen2, starcoder2 and mistral
_should_update_inputs, _inputs_update = _get_input_update_settings(self.model, lazy_mode=args.use_lazy_mode)
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
if args.eval_on_start:
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
if self.args.adjust_throughput:
self.log_evaluate_save_time = 0
else:
self.log_evaluate_save_time = None
hb_profiler = HabanaProfile(
warmup=self.args.profiling_warmup_steps,
active=self.args.profiling_steps,
record_shapes=self.args.profiling_record_shapes,
with_stack=self.args.profiling_with_stack,
)
hb_profiler.start()
if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
self.model.base_model.peft_config[self.model.trainable_adapter_name].total_step = max_steps
if max_steps < self.model.base_model.peft_config[self.model.trainable_adapter_name].tfinal:
self.model.base_model.peft_config[self.model.trainable_adapter_name].tfinal = 0
for epoch in range(epochs_trained, num_train_epochs):
epoch_dataloader = train_dataloader
if hasattr(epoch_dataloader, "set_epoch"):
epoch_dataloader.set_epoch(epoch)
# Reset the past mems state at the beginning of each epoch if necessary.
if args.past_index >= 0:
self._past = None
steps_in_epoch = (
len(epoch_dataloader)
if len_dataloader is not None
else args.max_steps * args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
steps_skipped = 0
if steps_trained_in_current_epoch > 0:
epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
steps_skipped = steps_trained_in_current_epoch
steps_trained_in_current_epoch = 0
rng_to_sync = True
step = -1
epoch_iterator = iter(epoch_dataloader)
# We chunkify the epoch iterator into gradient accumulation steps `n` batches
remainder = num_examples % args.gradient_accumulation_steps
if remainder == 0:
remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
if args.gradient_accumulation_steps == 1:
total_updates -= 1
for _ in range(total_updates):
update_step += 1
num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
batch_samples, num_items_in_batch = self.get_batch_samples_transformers(
epoch_iterator, num_batches, args.device
)
for i, inputs in enumerate(batch_samples):
step += 1
if self.args.compile_from_sec_iteration and is_torch_version(">=", "2.6.0"):
torch.compiler.set_stance("force_eager" if step == 0 else "default")
if (
args.throughput_warmup_steps > 0
and (args.throughput_warmup_steps * args.gradient_accumulation_steps)
== epoch * steps_in_epoch + step
):
start_time_after_warmup = time.time()
do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
# Since we perform prefetching, we need to manually set sync_gradients
self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if main_input_name not in inputs:
logger.warning(
"Tried to track the number of tokens seen, however the current model is "
"not configured properly to know what item is the input. To fix this, add "
"a `main_input_name` attribute to the model class you are using."
)
else:
input_tokens = inputs[main_input_name].numel()
input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
# attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma, baichuan, chatglm and deepseek_v2
# lazy_mode for llama, qwen2, starcoder2 and mistral
if _should_update_inputs:
inputs.update(_inputs_update)
# TODO: keep syncs for fast DDP?
# We explicitly want to avoid relying on `accelerator.accumulate` for generation training
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
and self.accelerator.distributed_type != DistributedType.DEEPSPEED
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if (
args.parallel_mode == ParallelMode.DISTRIBUTED
and args.distribution_strategy == "fast_ddp"
and do_sync_step
):
all_reduce_gradients(
model, use_hpu_graphs=True
) # use HPU graphs for gradient fusion regardless of args.use_hpu_graphs_for_training setting
if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss = tr_loss + tr_loss_step
self.current_flos += float(self.floating_point_ops(inputs))
if args.use_lazy_mode:
self.htcore.mark_step()
if do_sync_step:
# Since we perform prefetching, we need to manually set sync_gradients to True
self.accelerator.gradient_state._set_sync_gradients(True)
# If the condition is true, we need to compute grad_norm, deepspeed does its own clipping
if _should_compute_grad_norm:
# Gradient clipping
if self.FusedNorm is not None:
# TODO: merge with self.accelerator.clip_grad_norm_ when HMP is removed
grad_norm = self.FusedNorm.clip_norm(model.parameters())
else:
# Revert to normal clipping otherwise
grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
self.optimizer.step()
self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
# get learning rate before update
learning_rate = self._get_learning_rate()
if not self.accelerator.optimizer_step_was_skipped:
# Delay optimizer scheduling until metrics are generated
if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.lr_scheduler.step()
self._zero_model_grad(model)
self.state.global_step += 1
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
if args.use_lazy_mode:
self.htcore.mark_step()
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss,
grad_norm,
model,
trial,
epoch,
ignore_keys_for_eval,
start_time,
learning_rate=learning_rate,
)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
hb_profiler.step()
if self.control.should_epoch_stop or self.control.should_training_stop:
break
# We also need to break out of the nested loop
if self.control.should_epoch_stop or self.control.should_training_stop:
break
if step < 0:
logger.warning(
"There seems not to be a single sample in your epoch_iterator, stopping training at step"
f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
f" num_steps ({max_steps}) higher than the number of available samples."
)
self.control.should_training_stop = True
self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(
tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
)
if self.control.should_training_stop:
break
hb_profiler.stop()
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
# Wait for everyone to get here so we are sure the model has been saved by process 0.
if args.parallel_mode == ParallelMode.DISTRIBUTED:
torch.distributed.barrier()
self._load_best_model()
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
train_loss = self._total_loss_scalar / effective_global_step
# Warmup steps are removed from the calculation of speed metrics
num_samples_for_speed_metrics = num_train_samples - args.throughput_warmup_steps * total_train_batch_size
num_steps_for_speed_metrics = self.state.max_steps - args.throughput_warmup_steps
metrics = speed_metrics(
"train",
start_time,
num_samples=num_samples_for_speed_metrics,
num_steps=num_steps_for_speed_metrics,
num_tokens=num_train_tokens,
start_time_after_warmup=start_time_after_warmup,
log_evaluate_save_time=self.log_evaluate_save_time,
)
self.store_flos()
metrics["total_flos"] = self.state.total_flos
metrics["train_loss"] = train_loss
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
self.log(metrics)
run_dir = self._get_output_dir(trial)
checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
# Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and the process is allowed to save.
if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
for checkpoint in checkpoints_sorted:
if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
shutil.rmtree(checkpoint, ignore_errors=True)
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
# Wait for the checkpoint to be uploaded.
self._finish_current_push()
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
self._deactivate_neftune(self.model)
return TrainOutput(self.state.global_step, train_loss, metrics)
def _load_best_model(self):
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
model = self.model
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(
self.model_wrapped,
self.state.best_model_checkpoint,
load_module_strict=not _is_peft_model(self.model),
)
elif self.is_fsdp_enabled:
load_result = load_fsdp_model(
self.accelerator.state.fsdp_plugin,
self.accelerator,
model,
self.state.best_model_checkpoint,
**_get_fsdp_ckpt_kwargs(),
)
elif (
os.path.exists(best_model_path)
or os.path.exists(best_safe_model_path)
or os.path.exists(best_adapter_model_path)
or os.path.exists(best_safe_adapter_model_path)
):
has_been_loaded = True
if _is_peft_model(model):
# If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
# TODO: in the future support only specific min PEFT versions
if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
model, "load_adapter"
):
# For BC for older PEFT versions
if hasattr(model, "active_adapters"):
active_adapter = model.active_adapters[0]
if len(model.active_adapters) > 1:
logger.warning("Detected multiple active adapters, will only consider the first one")
else:
active_adapter = model.active_adapter
if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
try:
model.load_adapter(self.state.best_model_checkpoint, active_adapter)
except RuntimeError as exc:
if model.peft_config[active_adapter].is_prompt_learning:
# for context: https://github.com/huggingface/peft/issues/2256
msg = (
"When using prompt learning PEFT methods such as "
f"{model.peft_config[active_adapter].peft_type.value}, setting "
"load_best_model_at_end=True can lead to errors, it is recommended "
"to set this to False and to load the model manually from the checkpoint "
"directory using PeftModel.from_pretrained(base_model, <path>) after training "
"has finished."
)
raise RuntimeError(msg) from exc
else:
raise
# `load_adapter` has no return value, so create an empty one; update this when appropriate.
from torch.nn.modules.module import _IncompatibleKeys
load_result = _IncompatibleKeys([], [])
else:
logger.warning(
"The intermediate checkpoints of PEFT may not be saved correctly, "
f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
"Check some examples here: https://github.com/huggingface/peft/issues/96"
)
has_been_loaded = False
else:
logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
has_been_loaded = False
else:
# We load the model state dict on the CPU to avoid an OOM error.
if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
else:
state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
# If the model is on the GPU, it still works!
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
# which takes *args instead of **kwargs
load_result = model.load_state_dict(state_dict, False)
if has_been_loaded:
self._issue_warnings_after_load(load_result)
elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
):
load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=False)
self._issue_warnings_after_load(load_result)
else:
logger.warning(
f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
"on multiple nodes, you should activate `--save_on_each_node`."
)
def _maybe_log_save_evaluate(
self, tr_loss, _grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
):
timer = HabanaGenerationTime()
timer.start()
if self.args.adjust_throughput:
timer.step()
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
logs: dict[str, float] = {}
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# reset tr_loss to zero
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
# This grad_norm block used to be outside of the _maybe_log_save_evaluate method, causing a perf degradation.
# Moving it here so the grad tensor is only copied when it's needed.
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
grad_norm = model.get_global_grad_norm()
# In some cases the grad norm may not return a float
if hasattr(grad_norm, "item"):
grad_norm = grad_norm.item()
else:
if (
_grad_norm is not None
and self.accelerator.distributed_type != DistributedType.FSDP
and _grad_norm.size() == torch.Size([1])
):
grad_norm = _grad_norm.detach().item()
else:
grad_norm = None
if grad_norm is not None:
logs["grad_norm"] = grad_norm
if learning_rate is not None:
logs["learning_rate"] = learning_rate
else:
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs, start_time=start_time)
metrics = None
if self.control.should_evaluate:
metrics = self._evaluate(trial, ignore_keys_for_eval)
is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
if self.args.save_strategy == SaveStrategy.BEST:
self.control.should_save = is_new_best_metric
if self.control.should_save:
self._save_checkpoint(model, trial)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
if self.args.adjust_throughput:
timer.step()
self.log_evaluate_save_time += timer.last_duration
def _load_rng_state(self, checkpoint):
# Load RNG states from `checkpoint`
if checkpoint is None:
return
if self.args.world_size > 1:
process_index = self.args.process_index
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
if not os.path.isfile(rng_file):
logger.info(
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
)
return
else:
rng_file = os.path.join(checkpoint, "rng_state.pth")
if not os.path.isfile(rng_file):
logger.info(
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
"fashion, reproducibility is not guaranteed."
)
return
with safe_globals():
checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
if self.args.use_habana:
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
self.hpu_random.set_rng_state_all(checkpoint_rng_state["hpu"])
else:
try:
self.hpu_random.set_rng_state(checkpoint_rng_state["hpu"])
except Exception as e:
logger.info(
f"Didn't manage to set back the RNG states of the HPU because of the following error:\n {e}"
"\nThis won't yield the same results as if the training had not been interrupted."
)
def _save_rng_state(self, output_dir):
# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}
if self.args.use_habana:
if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
# In distributed mode, save the RNG state of all HPU devices
rng_states["hpu"] = self.hpu_random.get_rng_state_all()
else:
rng_states["hpu"] = self.hpu_random.get_rng_state()
# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)
if self.args.world_size <= 1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
def _save_optimizer_and_scheduler(self, output_dir):
if self.is_deepspeed_enabled:
# Under ZeRO-3, the model file itself does not get saved (it would be bogus) unless the DeepSpeed
# config `stage3_gather_16bit_weights_on_model_save` is True
accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
)
if accept_exclude_frozen_parameters and _is_peft_model(self.model):
self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
else:
self.model_wrapped.save_checkpoint(output_dir)
elif self.is_fsdp_enabled:
if isinstance(self.model, torch._dynamo.eval_frame.OptimizedModule):
# TODO: for some reason the FSDP model is not unwrapped correctly here, self.model
# shouldn't be an OptimizedModule at this point.
model = self.model._orig_mod
else:
model = self.model
# save fsdp specific ckpt for resuming from ckpt
save_fsdp_model(
self.accelerator.state.fsdp_plugin, self.accelerator, model, output_dir, **_get_fsdp_ckpt_kwargs()
)
save_fsdp_optimizer(
self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, model, output_dir
)
elif self.args.should_save:
# deepspeed.save_checkpoint above saves model/optim/sched
# This block is executed by the main process only
optim_dict = self.optimizer.state_dict()
if self.args.use_habana:
# Move the state dict from HPU to CPU before saving
optim_dict = to_device_dtype(optim_dict, target_device=torch.device("cpu"))
torch.save(optim_dict, os.path.join(output_dir, OPTIMIZER_NAME))
# Save SCHEDULER & SCALER
is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
self.lr_scheduler, DeepSpeedSchedulerWrapper
)
if self.args.should_save and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler):
if self.args.use_habana:
# Move the state dict from HPU to CPU before saving
scheduler_dict = self.lr_scheduler.state_dict()
scheduler_dict = to_device_dtype(scheduler_dict, target_device=torch.device("cpu"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(scheduler_dict if self.args.use_habana else self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)
def _load_optimizer_and_scheduler(self, checkpoint):
"""If optimizer and scheduler states exist, load them."""
if checkpoint is None:
return
if self.is_deepspeed_enabled:
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
)
reissue_pt_warnings(caught_warnings)
return
checkpoint_file_exists = (
os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
or (
os.path.isdir(checkpoint)
and any(
OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
)
)
if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
# We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
# In distributed training however, we load directly on each GPU and risk GPU OOM as it's more
# likely to get OOM on CPU (since we would load the optimizer state num_gpu times).
map_location = "cpu" if self.args.use_habana else self.args.device
if self.is_fsdp_enabled:
load_fsdp_optimizer(
self.accelerator.state.fsdp_plugin,
self.accelerator,
self.optimizer,
self.model,
checkpoint,
**_get_fsdp_ckpt_kwargs(),
)
else:
self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True)
)
with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(
torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location=map_location, weights_only=True)
)
reissue_pt_warnings(caught_warnings)
# Move optimizer state to HPU
if self.args.use_habana:
to_device_dtype(self.optimizer.state.values(), target_device=torch.device("hpu"))
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training.
Subclass and override this method to inject custom behavior.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
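Example:
An illustrative sketch of a subclass injecting custom behavior; `MyGaudiTrainer` and the
extra metric are hypothetical:
>>> class MyGaudiTrainer(GaudiTrainer):
...     def log(self, logs, start_time=None):
...         logs["my_extra_metric"] = 0.0  # hypothetical value added by the subclass
...         super().log(logs, start_time=start_time)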
"""
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self.args.include_num_input_tokens_seen:
logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
if start_time is not None:
speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen)
mem_stats = get_hpu_memory_stats(self.args.device)
logs.update(mem_stats)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
"""
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
Compared to Transformers, it is also possible to enable non-blocking data copy.
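Example:
A sketch; the non-blocking copy is controlled by the `non_blocking_data_copy` training argument
(read below as `self.args.non_blocking_data_copy`), and the other argument values are only illustrative:
>>> training_args = GaudiTrainingArguments(
...     output_dir="./output",
...     use_habana=True,
...     use_lazy_mode=True,
...     non_blocking_data_copy=True,
... )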
"""
if isinstance(data, Mapping):
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
if (
self.accelerator.mpu.sequence_parallel_is_initialized()
and self.accelerator.mpu.get_sequence_parallel_world_size() > 1
):
seq_parallel_world_rank = self.accelerator.mpu.get_sequence_parallel_rank()
sub_seq_length = int(data.size()[1] / self.accelerator.mpu.get_sequence_parallel_world_size())
data = data[
:, seq_parallel_world_rank * sub_seq_length : (seq_parallel_world_rank + 1) * sub_seq_length
]
kwargs = {"device": self.args.device}
if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
# NLP model inputs are int/uint and get adjusted to the right dtype of the
# embedding. Inputs of other models, such as wav2vec2, are already float and thus
# may need special handling to match the dtype of the model.
kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
if self.args.non_blocking_data_copy:
return data.to(**kwargs, non_blocking=True)
else:
return data.to(**kwargs)
return data
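# Editorial note: `non_blocking_data_copy` used above is read from the training arguments; a hedged
# sketch of enabling it (assuming it is exposed as a GaudiTrainingArguments field):
#
#     training_args = GaudiTrainingArguments(
#         output_dir="out",
#         use_habana=True,
#         non_blocking_data_copy=True,   # overlap host-to-HPU input copies with compute
#         dataloader_pin_memory=True,    # recommended together with non-blocking copies
#     )
#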
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
Modified by Habana to enable using `autocast` on Gaudi devices.
"""
if self.use_cpu_amp:
ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled)
elif self.use_hpu_amp:
ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
else:
ctx_manager = contextlib.nullcontext()
# Merge autocast context and `fp8_autocast` context if FP8 is enabled.
# Currently FP8 is enabled only for training.
if self.accelerator.fp8_enabled and self.model.training:
ctx_manager = FP8ContextWrapper(ctx_manager, fp8_recipe=self.accelerator.fp8_recipe)
return ctx_manager
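# Editorial usage sketch (not part of the original class): this context manager is what
# `compute_loss_context_manager` returns, so the forward pass runs under bf16 autocast
# (plus the FP8 wrapper when enabled):
#
#     with trainer.autocast_smart_context_manager():
#         loss = trainer.compute_loss(model, inputs)
#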
def training_step(
self, model: torch.nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`torch.nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
# TODO
# if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
# self.optimizer.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
kwargs = {}
# For LOMO optimizers you need to explicitly use the learning rate
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
kwargs["learning_rate"] = self._get_learning_rate()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.args.use_lazy_mode and self.args.pipelining_fwd_bwd:
self.htcore.mark_step()
# Finally we need to normalize the loss for reporting
if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
# temporary fix to calculate loss correctly
loss = loss / self.args.gradient_accumulation_steps
# Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
# https://github.com/huggingface/transformers/pull/35808
if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs["scale_wrt_gas"] = False
if _is_peft_model(self.model) and self.model.peft_type == PeftType.ADALORA:
assert not (self.accelerator.fp8_enabled and self.args.gradient_checkpointing), (
"FP8 precision with gradient_checkpointing is currently not supported with PeftType.ADALORA"
)
if self.is_deepspeed_enabled and not is_deepspeed_zero3_enabled():
self.accelerator.deepspeed_engine_wrapped.engine.backward(loss)
self.model.base_model.update_and_allocate(self.state.global_step)
self.accelerator.deepspeed_engine_wrapped.engine.step()
else:
self.accelerator.backward(loss, **kwargs)
self.model.base_model.update_and_allocate(self.state.global_step)
else:
if self.accelerator.fp8_enabled and self.args.gradient_checkpointing:
# The precision used in the backward pass should be the same as in the forward pass.
# However, when training with gradient_checkpointing and FP8 precision, the forward
# recomputation performed during backward does not automatically run in FP8. To handle
# this, the backward pass is run inside the `fp8_autocast` context.
with FP8ContextWrapper.create_fp8_context(fp8_recipe=self.accelerator.fp8_recipe):
self.accelerator.backward(loss, **kwargs)
else:
self.accelerator.backward(loss, **kwargs)
return loss.detach()
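# Editorial note: the lazy-mode behaviour above is controlled by flags read from the training
# arguments; a hedged sketch (assuming they are exposed as GaudiTrainingArguments fields):
#
#     training_args = GaudiTrainingArguments(
#         output_dir="out",
#         use_habana=True,
#         use_lazy_mode=True,        # enables the htcore.mark_step() calls
#         pipelining_fwd_bwd=True,   # extra mark_step between forward and backward
#     )
#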
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
"""
Will save the model, so you can reload it using `from_pretrained()`.
Will only save from the main process.
"""
if output_dir is None:
output_dir = self.args.output_dir
if self.is_fsdp_enabled:
if "FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type):
state_dict = self.accelerator.get_state_dict(self.model)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.is_deepspeed_enabled:
try:
state_dict = self.accelerator.get_state_dict(self.deepspeed)
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
except ValueError:
logger.warning(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
" zero_to_fp32.py to recover weights"
)
if self.args.should_save:
self._save(output_dir, state_dict={})
# remove the dummy state_dict
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
)
if accept_exclude_frozen_parameters and _is_peft_model(self.model):
self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
else:
self.model_wrapped.save_checkpoint(output_dir)
elif self.args.should_save:
self._save(output_dir)
# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
self.push_to_hub(commit_message="Model save")
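# Editorial usage sketch (not part of the original class): saving and reloading; the model class
# used for reloading (`AutoModelForCausalLM`) is an assumption, not mandated by this file.
#
#     trainer.save_model("saved_model")   # also writes the Gaudi config and training arguments
#     reloaded = AutoModelForCausalLM.from_pretrained("saved_model")
#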
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
if state_dict is None:
state_dict = self.model.state_dict()
if state_dict and self.args.use_habana:
# state_dict items have to be saved on the CPU
state_dict = to_device_dtype(state_dict, target_device=torch.device("cpu"))
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
self.data_collator.tokenizer.save_pretrained(output_dir)
self.gaudi_config.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
def evaluate(
self,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> dict[str, float]:
"""
From https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/trainer.py#L3162 with the following modification:
1. use `throughput_warmup_steps` in the evaluation throughput calculation
"""
# handle multiple eval datasets
override = eval_dataset is not None
eval_dataset = eval_dataset if override else self.eval_dataset
if isinstance(eval_dataset, dict):
metrics = {}
for eval_dataset_name, _eval_dataset in eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=_eval_dataset if override else eval_dataset_name,
ignore_keys=ignore_keys,
metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
return metrics
# memory metrics - must set up as early as possible
self._memory_tracker.start()
eval_dataloader = self.get_eval_dataloader(eval_dataset)
start_time = time.time()
self.start_time_after_warmup = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
output = eval_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
metric_key_prefix=metric_key_prefix,
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size
num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps
eval_steps = math.ceil(output.num_samples / total_batch_size)
if eval_steps <= self.args.throughput_warmup_steps:
logger.warning(
f" Warmup steps are taken into account for the throughput calculation because the number of evaluation steps ({eval_steps}) is smaller than the number of warmup steps ({self.args.throughput_warmup_steps})"
)
num_samples = output.num_samples
num_steps = eval_steps
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=num_samples,
num_steps=num_steps,
start_time_after_warmup=self.start_time_after_warmup,
)
)
self.log(output.metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
return output.metrics
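# Editorial usage sketch (not part of the original class): the warmup handling above is driven by
# `throughput_warmup_steps`, read from the training arguments (assumed to be a GaudiTrainingArguments field):
#
#     training_args = GaudiTrainingArguments(output_dir="out", use_habana=True, throughput_warmup_steps=3)
#     metrics = trainer.evaluate()
#     print(metrics["eval_samples_per_second"])   # first 3 evaluation steps excluded from the throughput
#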
def predict(
self, test_dataset: Dataset, ignore_keys: Optional[list[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
"""
From https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/trainer.py#L3904 with the following modifications:
1. comment out TPU-related code
2. use `throughput_warmup_steps` in the evaluation throughput calculation
"""
# memory metrics - must set up as early as possible
self._memory_tracker.start()
test_dataloader = self.get_test_dataloader(test_dataset)
start_time = time.time()
self.start_time_after_warmup = None
eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
output = eval_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
total_batch_size = self.args.eval_batch_size * self.args.world_size
if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
num_samples = output.num_samples - self.args.throughput_warmup_steps * total_batch_size
num_steps = math.ceil(output.num_samples / total_batch_size) - self.args.throughput_warmup_steps
logger.info(f"num_samples : {num_samples}, num_steps: {num_steps}")
output.metrics.update(
speed_metrics(
metric_key_prefix,
start_time,
num_samples=num_samples,
num_steps=num_steps,
start_time_after_warmup=self.start_time_after_warmup,
)
)
self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
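# Editorial usage sketch (not part of the original class; `test_ds` is an assumed dataset):
#
#     predictions = trainer.predict(test_ds)
#     print(predictions.metrics["test_samples_per_second"])
#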
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
args = self.args
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.deepspeed is None:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
start_time = time.time()
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled
or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile)
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
self.model_preparation_time = round(time.time() - start_time, 4)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
model.eval()
# if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
# self.optimizer.eval()
# Do not use HPU graphs while training is ongoing because wrapping the model in HPU graphs detaches gradients
if args.use_hpu_graphs_for_inference and not self.is_in_train:
logger.info("Using HPU graphs for inference.")
# Do not wrap the model in HPU graphs if it has already been done
if not self.already_wrapped_for_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(
model, disable_tensor_cache=args.disable_tensor_cache_hpu_graphs, max_graphs=args.max_hpu_graphs
)
self.already_wrapped_for_hpu_graphs = True
# if full bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = self.args.eval_batch_size
logger.info(f"\n***** Running {description} *****")
if has_length(dataloader):
logger.info(f" Num examples = {self.num_examples(dataloader)}")
else:
logger.info(" Num examples: Unknown")
logger.info(f" Batch size = {batch_size}")
self.callback_handler.eval_dataloader = dataloader
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)
if args.past_index >= 0:
self._past = None
# Initialize containers
all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
metrics = None
eval_set_kwargs = {}
# Will be useful when we have an iterable dataset whose length we don't know.
observed_num_examples = 0
# attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma and baichuan
_should_update_inputs, _inputs_update = _get_input_update_settings(self.model)
# set a default dtype of logits
logits_dtype: str = "float32"
# Main evaluation loop
start_time_eval = time.time()
for step, inputs in enumerate(dataloader):
if (
self.args.throughput_warmup_steps > 0
and not self.is_in_train
and step == self.args.throughput_warmup_steps
):
self.start_time_after_warmup = time.time()
self.compilation_time = self.start_time_after_warmup - start_time_eval
# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
observed_num_examples += observed_batch_size
# For batch samplers, batch_size is not known by the dataloader in advance.
if batch_size is None:
batch_size = observed_batch_size
# attn_softmax_bf16 and use_flash_attention are enabled only for llama, qwen2, starcoder2, gemma, baichuan and chatglm
if _should_update_inputs:
inputs.update(_inputs_update)
# Prediction step
losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = (
self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
)
# Update containers
if losses is not None:
losses = self.gather_function(losses.repeat(batch_size))
all_losses.add(losses)
if inputs_decode is not None:
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
inputs_decode = self.gather_function(inputs_decode)
if not self.args.batch_eval_metrics or description == "Prediction":
all_inputs.add(inputs_decode)
if labels is not None:
# Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
# Save the logits dtype since we need to convert them into floats during the process
# They will be converted back into their original dtype right before computing metrics
if logits is not None:
logits_dtype = get_dtype(logits)
if args.use_habana and logits_dtype != "float32":
logits = to_device_dtype(logits, target_dtype=torch.float32)
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
logits = self.gather_function(logits)
if not self.args.batch_eval_metrics or description == "Prediction":
all_preds.add(logits)
if labels is not None:
if self.args.context_parallel_size != 1:
labels = labels.clone()
labels = self.gather_function(labels)
if not self.args.batch_eval_metrics or description == "Prediction":
all_labels.add(labels)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
if self.args.batch_eval_metrics:
if self.compute_metrics is not None and logits is not None and labels is not None:
is_last_step = self.accelerator.gradient_state.end_of_dataloader
batch_kwargs = {}
batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
metrics = self.compute_metrics(
EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
compute_result=is_last_step,
)
del losses, logits, labels, inputs
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
all_losses.to_cpu_and_numpy()
all_preds.to_cpu_and_numpy()
all_labels.to_cpu_and_numpy()
all_inputs.to_cpu_and_numpy()
del losses, logits, labels, inputs
# nested concat accumulates tensors of variable length.
# A mark_step is added here to avoid graph recompilation.
if args.use_lazy_mode:
self.htcore.mark_step()
# After all calls to `.gather_function`, reset to `gather_for_metrics`:
self.gather_function = self.accelerator.gather_for_metrics
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
# Gather all remaining tensors and put them back on the CPU
all_losses = all_losses.get_arrays()
all_preds = all_preds.get_arrays()
all_labels = all_labels.get_arrays()
all_inputs = all_inputs.get_arrays()
# Number of samples
if has_length(eval_dataset):
num_samples = len(eval_dataset)
# The instance check is weird and does not actually check for the type, but whether the dataset has the right
# methods. Therefore we need to make sure it also has the attribute.
elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
num_samples = eval_dataset.num_examples
else:
if has_length(dataloader):
num_samples = self.num_examples(dataloader)
else: # both len(dataloader.dataset) and len(dataloader) fail
num_samples = observed_num_examples
if num_samples == 0 and observed_num_examples > 0:
num_samples = observed_num_examples
# Convert predictions back into their original dtype if necessary
if all_preds is not None:
all_preds = convert_into_dtypes(all_preds, logits_dtype)
# Metrics!
if (
self.compute_metrics is not None
and all_preds is not None
and all_labels is not None
and not self.args.batch_eval_metrics
):
eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
metrics = self.compute_metrics(
EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
)
elif metrics is None:
metrics = {}
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)
if isinstance(all_losses, list) and all_losses:
metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
elif isinstance(all_losses, np.ndarray):
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "model_preparation_time"):
metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
if hasattr(self, "compilation_time"):
metrics[f"{metric_key_prefix}_graph_compliation_duration"] = self.compilation_time
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
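# Editorial note: the HPU-graph wrapping above is controlled by inference flags read from the
# training arguments; a hedged sketch (assuming they are exposed as GaudiTrainingArguments fields):
#
#     training_args = GaudiTrainingArguments(
#         output_dir="out",
#         use_habana=True,
#         use_lazy_mode=True,
#         use_hpu_graphs_for_inference=True,   # only applied when evaluate()/predict() run outside train()
#     )
#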
def prediction_step(
self,
model: torch.nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`torch.nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
# `return_loss` is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"])
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
try:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.detach().mean()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
except RuntimeError as error:
if "cpu fallback is not supported during hpu graph capturing" in str(error):
error.args = (
f"{error}. You should run inference in lazy mode only with `use_lazy_mode=True` and `use_hpu_graphs_for_inference=False`.",
)
raise error
if self.args.use_lazy_mode and not (self.args.use_hpu_graphs_for_inference and not self.is_in_train):
self.htcore.mark_step()
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)
def _push_from_checkpoint(self, checkpoint_folder):
# Only push from one node.
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
return
# If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
return
output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME, GAUDI_CONFIG_NAME]
# Add sharded checkpoints if we have an index
for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
index_path = os.path.join(checkpoint_folder, index_file)
if os.path.isfile(index_path):
modeling_files.append(index_file)
with open(index_path) as f:
index = json.loads(f.read())
shard_files = list(set(index["weight_map"].values()))
modeling_files.extend(shard_files)
if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
# Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.args.save_strategy == SaveStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
else:
commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
model_push_job = upload_folder(
repo_id=self.hub_model_id,
folder_path=output_dir,
commit_message=commit_message,
token=self.args.hub_token,
run_as_future=True,
ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
)
push_jobs = [model_push_job]
if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
path_in_repo = (
"last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
)
checkpoint_push = upload_folder(
repo_id=self.hub_model_id,
folder_path=checkpoint_folder,
path_in_repo=path_in_repo,
commit_message=commit_message + ", checkpoint",
token=self.args.hub_token,
run_as_future=True,
)
push_jobs.append(checkpoint_push)
if self.push_in_progress is None or self.push_in_progress.is_done():
self.push_in_progress = PushInProgress(push_jobs)
else:
self.push_in_progress.jobs.extend(push_jobs)
#
# Deprecated code
#
def prediction_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
Works both with or without labels.
"""
args = self.args
if not has_length(dataloader):
raise ValueError("dataloader must implement a working __len__")
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
# if eval is called w/o train, handle model prep here
if self.is_deepspeed_enabled and self.deepspeed is None:
_, _ = deepspeed_init(self, num_training_steps=0, inference=True)
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
if len(self.accelerator._models) == 0 and model is self.model:
model = (
self.accelerator.prepare(model)
if self.is_deepspeed_enabled or self.is_fsdp_enabled
else self.accelerator.prepare_model(model, evaluation_mode=True)
)
if self.is_fsdp_enabled:
self.model = model
# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# backward compatibility
if self.is_deepspeed_enabled:
self.deepspeed = self.model_wrapped
model.eval()
# TODO
# if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
# self.optimizer.eval()
# Do not use HPU graphs while training is ongoing because wrapping the model in HPU graphs detaches gradients
if args.use_hpu_graphs_for_inference and not self.is_in_train:
logger.info("Using HPU graphs for inference.")
# Do not wrap the model in HPU graphs if it has already been done
if not self.already_wrapped_for_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = wrap_in_hpu_graph(
model, disable_tensor_cache=args.disable_tensor_cache_hpu_graphs, max_graphs=args.max_hpu_graphs
)
self.already_wrapped_for_hpu_graphs = True
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# while ``train`` is running, cast it to the right dtype first and then put on device
if not self.is_in_train:
if args.fp16_full_eval:
model = model.to(dtype=torch.float16, device=args.device)
elif args.bf16_full_eval:
model = model.to(dtype=torch.bfloat16, device=args.device)
batch_size = (
dataloader.total_batch_size
if getattr(dataloader, "_is_accelerate_prepared", False)
else dataloader.batch_size
)
if batch_size is None:
raise ValueError(
"Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
)
num_examples = self.num_examples(dataloader)
logger.info(f"\n***** Running {description} *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Batch size = {batch_size}")
losses_host: Optional[torch.Tensor] = None
preds_host: Union[torch.Tensor, list[torch.Tensor], None] = None
labels_host: Union[torch.Tensor, list[torch.Tensor], None] = None
inputs_host: Union[torch.Tensor, list[torch.Tensor], None] = None
metrics: Optional[dict] = None
eval_set_kwargs: dict = {}
world_size = max(1, args.world_size)
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
if not prediction_loss_only:
# The actual number of eval samples can be greater than num_examples in distributed settings (when we pass
# a batch size to the sampler)
make_multiple_of = None
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
make_multiple_of = dataloader.sampler.batch_size
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
if args.past_index >= 0:
self._past = None
self.callback_handler.eval_dataloader = dataloader
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
main_input_name = getattr(self.model, "main_input_name", "input_ids")
inputs_decode = (
self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
)
if loss is not None:
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
if inputs_decode is not None:
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
if self.args.batch_eval_metrics:
if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
is_last_step = self.accelerator.gradient_state.end_of_dataloader
batch_kwargs = {}
batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
metrics = self.compute_metrics(
EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
compute_result=is_last_step,
)
if self.args.batch_eval_metrics or (
args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
):
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
# Set back to None to begin a new accumulation
del losses_host, preds_host, labels_host, inputs_host
losses_host, preds_host, labels_host, inputs_host = None, None, None, None
# nested concat accumulates tensors of variable length.
# A mark_step is added here to avoid graph recompilation.
if args.use_lazy_mode:
self.htcore.mark_step()
if args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
# Gather all remaining tensors and put them back on the CPU
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize() if not prediction_loss_only else None
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
if (
self.compute_metrics is not None
and preds is not None
and label_ids is not None
and not self.args.batch_eval_metrics
):
eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
elif metrics is None:
metrics = {}
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)
if eval_loss is not None:
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
def create_accelerator_and_postprocess(self):
# We explicitly don't rely on the `Accelerator` to do gradient accumulation
grad_acc_kwargs = {}
if self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs:
if self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
else:
self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
accelerator_config = self.args.accelerator_config.to_dict()
# Extract dataloader config params from accelerator config
dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
dataloader_config = DataLoaderConfiguration(
**{param: accelerator_config.pop(param) for param in dataloader_params}
)
if is_accelerate_available("1.1.0"):
dataloader_config.data_seed = self.args.data_seed
non_blocking = accelerator_config.pop("non_blocking")
if non_blocking and not self.args.dataloader_pin_memory:
logger.warning(
"`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
)
dataloader_config.non_blocking = non_blocking
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")
args = {
"dataloader_config": dataloader_config,
"deepspeed_plugins": self.args.deepspeed_plugin,
# OH specific
"distribution_strategy": self.args.distribution_strategy,
"use_regional_compilation": self.args.use_regional_compilation,
"compiled_autograd_enable": self.args.use_compiled_autograd,
}
# TP is initialized during the Accelerator init phase, so
# its args must be prepared here
if self.args.tp_size > 1:
self.is_tp_enabled = True
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
# create accelerator object
self.accelerator = GaudiAccelerator(**args)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
self.gather_function = functools.partial(
self.gather_function, use_gather_object=self.args.eval_use_gather_object
)
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
for param in ["limit_all_gathers", "activation_checkpointing"]:
setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
"when using FSDP."
)
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
self.propagate_args_to_deepspeed()
# `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
if (
self.args.save_only_model
and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
and self.args.load_best_model_at_end
):
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
# `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
if (
self.is_deepspeed_enabled
and self.accelerator.state.deepspeed_plugin.zero_stage == 3
and self.args.auto_find_batch_size
):
raise ValueError(
"`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
)
if (
self.args.save_only_model
and self.is_fsdp_enabled
and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
):
raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")
def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
"""
Sets values in the deepspeed plugin based on the Trainer args
"""
from .integrations.deepspeed import GaudiTrainerDeepSpeedConfig
ds_plugin = self.accelerator.state.deepspeed_plugin
ds_plugin.hf_ds_config = GaudiTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
def _zero_model_grad(self, model):
if hasattr(model, "_zero_grad_kwargs"):
model.zero_grad(**model._zero_grad_kwargs)
else:
# The optimization of setting gradients to None (instead of zeroing them out) may only be used when gradients are not recorded using HPU graphs.
# HPU graphs rely on fixed tensor addresses; setting gradients to None would force their re-allocation during the backward pass at each step.
set_to_none = (
self.args.parallel_mode != ParallelMode.DISTRIBUTED or self.args.distribution_strategy == "ddp"
) and not self.args.use_hpu_graphs_for_training
try:
model.zero_grad(set_to_none=set_to_none)
model._zero_grad_kwargs = {"set_to_none": set_to_none}
except TypeError:
model.zero_grad()
model._zero_grad_kwargs = {}
def get_batch_samples_transformers(self, epoch_iterator, num_batches, device):
batch_samples = []
num_items_in_batch = None
for _ in range(num_batches):
try:
batch_samples.append(next(epoch_iterator))
except StopIteration:
break
count_num_items_in_batch = (
len(batch_samples) > 0
and "labels" in batch_samples[0]
and (
# num_items_in_batch is passed to model forward
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757
self.model_accepts_loss_kwargs
# num_items_in_batch is passed to compute_loss_func
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773
or self.compute_loss_func is not None
# num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func)
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
)
)
if count_num_items_in_batch:
# For now we don't support object detection
try:
num_items_in_batch = torch.cat([batch["labels"] for batch in batch_samples]).ne(-100).sum()
except (TypeError, AttributeError, RuntimeError):
pass
if num_items_in_batch is not None:
if self.args.average_tokens_across_devices:
num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
num_items_in_batch = num_items_in_batch.to(device)
return batch_samples, num_items_in_batch
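# Editorial illustration (not part of the original class): `num_items_in_batch` counts label tokens
# that are not the ignore index (-100) across the accumulated micro-batches, e.g.
#
#     labels = torch.tensor([[5, 6, -100], [-100, 7, 8]])
#     labels.ne(-100).sum()   # tensor(4), matching the concatenation-based count above
#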