optimum/onnxruntime/trainer.py (740 lines of code) (raw):

# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ The ORTTrainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task with ONNX Runtime. """ import functools import math import os import shutil import sys import time import types import warnings from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union # Integrations must be imported before ML frameworks: # isort: off import safetensors from transformers.integrations import hp_params from transformers.utils import is_accelerate_available from packaging import version if is_accelerate_available(): from accelerate import __version__ as accelerate_version from accelerate.utils import DistributedDataParallelKwargs if version.parse(accelerate_version) >= version.parse("0.16"): from accelerate import skip_first_batches else: skip_first_batches = None from accelerate.utils import DistributedType else: raise ImportError( "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install." ) # isort: on import huggingface_hub.utils as hf_hub_utils import torch import torch.distributed as dist from torch import nn from torch.utils.data import Dataset, RandomSampler from transformers.data.data_collator import DataCollator from transformers.debug_utils import DebugOption, DebugUnderflowOverflow from transformers.modeling_utils import PreTrainedModel, unwrap_model from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer import Trainer from transformers.trainer_callback import ExportableState, TrainerCallback, TrainerState from transformers.trainer_pt_utils import ( get_model_param_count, get_module_class_from_name, get_parameter_names, ) from transformers.trainer_utils import ( EvalPrediction, HPSearchBackend, TrainOutput, enable_full_determinism, find_executable_batch_size, get_last_checkpoint, has_length, set_seed, speed_metrics, ) from transformers.training_args import ParallelMode from transformers.utils import ( SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_apex_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, ) from ..utils import logging from ..utils.import_utils import is_transformers_version from .training_args import ORTOptimizerNames, ORTTrainingArguments from .utils import ( is_onnxruntime_training_available, ) if is_apex_available(): from apex import amp if is_transformers_version(">=", "4.33"): from transformers.integrations.deepspeed import ( deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled, ) else: from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled if is_transformers_version(">=", "4.39"): from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available if is_torch_tpu_xla_available(): import torch_xla.core.xla_model as xm else: from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available if is_torch_tpu_xla_available(check_device=False): import torch_xla.core.xla_model as xm if TYPE_CHECKING: import optuna # Name of the files used for checkpointing TRAINER_STATE_NAME = "trainer_state.json" TRAINING_ARGS_NAME = "training_args.bin" logger = logging.get_logger(__name__) class ModuleWithLoss(PreTrainedModel): def __init__(self, model, args, label_smoother): super().__init__() self._original_model = model self.args = args # Label smoothing self.label_smoother = label_smoother def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs, num_items_in_batch): # The compute_model_plus_loss_internal is assigned once the class is instantiated. # It should have same signature as Trainer.compute_loss(). # We do this to avoid potential un-synced states if we duplicated compute loss codes . return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs, num_items_in_batch) @property def module(self): """The original `torch.nn.Module` that this module wraps. This property provides access to methods and properties on the original module.""" return self._original_model.module @property def config(self): return self._original_model.config class ORTTrainer(Trainer): """ ORTTrainer is a simple but feature-complete training and eval loop for ONNX Runtime, optimized for 🤗 Transformers. Args: model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*): The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. <Tip> [`ORTTrainer`] is optimized to work with the [`~transformers.PreTrainedModel`] provided by the transformers library. You can still use your own models defined as `torch.nn.Module` for training with ONNX Runtime backend and inference with PyTorch backend as long as they work the same way as the 🤗 Transformers models. </Tip> args ([`ORTTrainingArguments`], *optional*): The arguments to tweak for training. Will default to a basic instance of [`ORTTrainingArguments`] with the `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. data_collator ([`~transformers.DataCollator`], *optional*): The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will default to [`~transformers.default_data_collator`] if no `tokenizer` is provided, an instance of [`~transformers.DataCollatorWithPadding`] otherwise. train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a `torch.Generator` for the randomization that must be identical on all processes (and the ORTTrainer will manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally sets the seed of the RNGs used. eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each dataset prepending the dictionary key to the metric name. tokenizer ([`~transformers.PreTrainedTokenizerBase`], *optional*): The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. model_init (`Callable[[], PreTrainedModel]`, *optional*): A function that instantiates the model to be used. If provided, each call to [`ORTTrainer.train`] will start from a new instance of the model as given by this function. The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to be able to choose different architectures according to hyper parameters (such as layer count, sizes of inner layers, dropout probabilities etc). compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return a dictionary string to metric values. callbacks (List of [`TrainerCallback`], *optional*): A list of callbacks to customize the training loop. Will add those to the list of default callbacks detailed in [here](callback). If you want to remove one of the default callbacks used, use the [`ORTTrainer.remove_callback`] method. optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): A function that preprocess the logits right before caching them at each evaluation step. Must take two tensors, the logits and the labels, and return the logits once processed as desired. The modifications made by this function will be reflected in the predictions received by `compute_metrics`. Note that the labels (second parameter) will be `None` if the dataset does not have them. Important attributes: - **model** -- Always points to the core model. If using a transformers model, it will be a [`~transformers.PreTrainedModel`] subclass. - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, the inner model is first wrapped in `ORTModule` and then in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from data parallelism, this means some of the model layers are split on different GPUs). - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set to `False` if model parallel or deepspeed is used, or if the default `ORTTrainingArguments.place_model_on_device` is overridden to return `False` . - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while in `train`) """ def __init__( self, model: Union[PreTrainedModel, nn.Module] = None, args: ORTTrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, ): super().__init__( model=model, args=args, data_collator=data_collator, train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) # We leverage both training_model and inference_model in conjunction with model. # _training_model will be wrapped so it will use ORT and will use the overriden functions in ModuleWithLoss. # _training_model will be storing the default version of the model and will unwrap it in case of eval/test. # Only Wrap the model if we pass --use_module_with_loss flag. if args.use_module_with_loss: self._training_model = self.create_model_with_loss() self.model = model if self.args.local_rank: torch.cuda.set_device(self.args.local_rank) # this method will create a ModuleWithLoss Instance to use if you are passing --use_module_with_loss flag. # It will help reducing the peak memory usage by computing loss inside training. def create_model_with_loss(self): model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother) model_with_loss.compute_model_plus_loss_internal = types.MethodType(Trainer.compute_loss, model_with_loss) return model_with_loss # we assume that training_model and inference_model have the same forward signature column. # self._signature_columns attribute only stores the first-time parsed signature def _set_signature_columns_if_needed(self): if self._signature_columns is None: # Inspect model forward signature to keep only the arguments it accepts. import inspect if isinstance(self.model, ModuleWithLoss): signature = inspect.signature(self.model._original_model.forward) else: signature = inspect.signature(self.model.forward) self._signature_columns = list(signature.parameters.keys()) # Labels may be named label or label_ids, the default data collator handles that. self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) def compute_loss(self, model_with_loss, inputs, return_outputs=False, num_items_in_batch=None): # Run model forward + loss compute. if isinstance(self.model, ModuleWithLoss): # ORTModule Does not support the BatchEncoding Type so we have to convert to a dict. dict_inputs = dict(inputs.items()) return model_with_loss(dict_inputs, return_outputs, num_items_in_batch) else: return super().compute_loss(model_with_loss, inputs, return_outputs, num_items_in_batch) def train( self, resume_from_checkpoint: Optional[Union[str, bool]] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None, ignore_keys_for_eval: Optional[List[str]] = None, **kwargs, ): """ Main entry point for training with ONNX Runtime accelerator. Args: resume_from_checkpoint (`str` or `bool`, *optional*): If a `str`, local path to a saved checkpoint as saved by a previous instance of [`ORTTrainer`]. If a `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance of [`ORTTrainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): The trial run or the hyperparameter dictionary for hyperparameter search. ignore_keys_for_eval (`List[str]`, *optional*) A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments used to hide deprecated arguments """ if not is_onnxruntime_training_available(): raise ImportError( "You need to install `onnxruntime-training` to use `ORTTrainer` for training. Check out " "https://huggingface.co/docs/optimum/onnxruntime/usage_guides/trainer#install-onnx-runtime." ) if self.args.use_module_with_loss: self.model = self._training_model if resume_from_checkpoint is False: resume_from_checkpoint = None # memory metrics - must set up as early as possible self._memory_tracker.start() args = self.args self.is_in_train = True # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: resume_from_checkpoint = kwargs.pop("model_path") warnings.warn( "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " "instead.", FutureWarning, ) if len(kwargs) > 0: raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") # This might change the seed so needs to run first. self._hp_search_setup(trial) self._train_batch_size = self.args.train_batch_size # Model re-init model_reloaded = False if self.model_init is not None: # Seed must be set before instantiating the model when using model_init. enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) self.model = self.call_model_init(trial) model_reloaded = True # Reinitializes optimizer and scheduler self.optimizer, self.lr_scheduler = None, None # Load potential model checkpoint if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: resume_from_checkpoint = get_last_checkpoint(args.output_dir) if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") if ( resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled ): self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped if model_reloaded: if self.place_model_on_device: self._move_model_to_device(self.model, args.device) self.model_wrapped = self.model inner_training_loop = find_executable_batch_size( self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size ) if args.push_to_hub: try: # Disable progress bars when uploading models during checkpoints to avoid polluting stdout hf_hub_utils.disable_progress_bars() return inner_training_loop( args=args, resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, ) finally: hf_hub_utils.enable_progress_bars() else: return inner_training_loop( args=args, resume_from_checkpoint=resume_from_checkpoint, trial=trial, ignore_keys_for_eval=ignore_keys_for_eval, ) def _inner_training_loop( self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None ): from torch_ort import ORTModule self.accelerator.free_memory() self._train_batch_size = batch_size logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() # Setting up training control variables: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None if has_length(train_dataloader): len_dataloader = len(train_dataloader) num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) num_examples = self.num_examples(train_dataloader) if args.max_steps > 0: max_steps = args.max_steps num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( args.max_steps % num_update_steps_per_epoch > 0 ) # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's # the best we can do. num_train_samples = args.max_steps * total_train_batch_size else: max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) num_train_epochs = math.ceil(args.num_train_epochs) num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size max_steps = args.max_steps # Setting a very large number of epochs so we go as many times as necessary over the iterator. num_train_epochs = sys.maxsize num_update_steps_per_epoch = max_steps num_examples = total_train_batch_size * args.max_steps num_train_samples = args.max_steps * total_train_batch_size else: raise ValueError( "args.max_steps must be set to a positive value if dataloader does not have a length, was" f" {args.max_steps}" ) if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: if self.args.n_gpu > 1: # nn.DataParallel(model) replicates the model, creating new variables and module # references registered here no longer work on other gpus, breaking the module raise ValueError( "Currently --debug underflow_overflow is not supported under DP. Please use DDP" " (torch.distributed.launch)." ) else: debug_overflow = DebugUnderflowOverflow(self.model) # noqa delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled # Wrap the model with `ORTModule` logger.info("Wrap ORTModule for ONNX Runtime training.") if self.args.save_onnx: from torch_ort import DebugOptions model = ORTModule( self.model, DebugOptions(save_onnx=self.args.save_onnx, onnx_prefix=self.args.onnx_prefix) ) else: model = ORTModule(self.model) self.model_wrapped = model self.model = model # We need to reset the scheduler, as its parameters may be different on subsequent calls if self._created_lr_scheduler: self.lr_scheduler = None self._created_lr_scheduler = False if self.is_deepspeed_enabled: if is_deepspeed_zero3_enabled(): raise NotImplementedError( "`ORTTrainer` does not support ZeRO stage 3 for the moment. Please use DeepSpeed stage 1 or 2 instead." ) if args.bf16: warnings.warn( "ONNX Runtime doesn't support BF16 when executing some operators. The execution will fail if there are any" " op which doesn't support BF16 in the IR.", RuntimeWarning, ) self.model = model self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState( stateful_callbacks=[ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) ] ) self.state.is_hyper_param_search = trial is not None self.state.train_batch_size = self._train_batch_size # Compute absolute values for logging, eval, and save if given as ratio if args.logging_steps is not None: if args.logging_steps < 1: self.state.logging_steps = math.ceil(max_steps * args.logging_steps) else: self.state.logging_steps = args.logging_steps if args.eval_steps is not None: if args.eval_steps < 1: self.state.eval_steps = math.ceil(max_steps * args.eval_steps) else: self.state.eval_steps = args.eval_steps if args.save_steps is not None: if args.save_steps < 1: self.state.save_steps = math.ceil(max_steps * args.save_steps) else: self.state.save_steps = args.save_steps # Activate gradient checkpointing if needed if args.gradient_checkpointing: self.model.gradient_checkpointing_enable() model = self._wrap_model(self.model_wrapped) # Wrap unless the ORTModule is already wrapped, eg. wrap DDP if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) # as the model is wrapped, don't use `accelerator.prepare` # this is for unhandled cases such as # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX use_accelerator_prepare = True if model is self.model else False if delay_optimizer_creation: if use_accelerator_prepare: self.model = self.accelerator.prepare(self.model) self.create_optimizer_and_scheduler(num_training_steps=max_steps) # prepare using `accelerator` prepare if use_accelerator_prepare: self.model.train() self.accelerator.distributed_type == DistributedType.DEEPSPEED if hasattr(self.lr_scheduler, "step"): if self.use_apex: model = self.accelerator.prepare(self.model) else: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) else: # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( self.model, self.optimizer, self.lr_scheduler ) self.model = unwrap_model(model) if self.is_fsdp_enabled: self.model = model # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model # backward compatibility if self.is_deepspeed_enabled: self.deepspeed = self.model_wrapped # deepspeed ckpt loading if resume_from_checkpoint is not None and self.is_deepspeed_enabled: deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) # Important: at this point if enabled distributed training features: # self.model is the Transformers Model # self.model_wrapped is DDP(ORTModule(Transformers Model)), Deepspeed(ORTModule(Transformers Model)), etc. # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {num_examples:,}") logger.info(f" Num Epochs = {num_train_epochs:,}") logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") if self.args.per_device_train_batch_size != self._train_batch_size: logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_steps:,}") logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() epochs_trained = 0 steps_trained_in_current_epoch = 0 steps_trained_progress_bar = None # Check if continuing training from a checkpoint if resume_from_checkpoint is not None and os.path.isfile( os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) ): self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) epochs_trained = self.state.global_step // num_update_steps_per_epoch if not args.ignore_data_skip: steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) steps_trained_in_current_epoch *= args.gradient_accumulation_steps else: steps_trained_in_current_epoch = 0 logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: logger.info( f" Will skip the first {epochs_trained} epochs then the first" f" {steps_trained_in_current_epoch} batches in the first epoch." ) # Update the references self.callback_handler.model = self.model self.callback_handler.optimizer = self.optimizer self.callback_handler.lr_scheduler = self.lr_scheduler self.callback_handler.train_dataloader = train_dataloader if self.hp_name is not None and self._trial is not None: # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial # parameter to Train when using DDP. self.state.trial_name = self.hp_name(self._trial) if trial is not None: assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial self.state.trial_params = hp_params(assignments) else: self.state.trial_params = None # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() # tr_loss is a tensor to avoid synchronization of TPUs through .item() tr_loss = torch.tensor(0.0).to(args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step model.zero_grad() self.control = self.callback_handler.on_train_begin(args, self.state, self.control) # Temp: remove after transformers 4.34 release def get_dataloader_sampler(dataloader): if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None: return get_dataloader_sampler(dataloader.batch_sampler) elif hasattr(dataloader, "sampler"): return dataloader.sampler # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): sampler = get_dataloader_sampler(train_dataloader) is_random_sampler = isinstance(sampler, RandomSampler) if not is_random_sampler: # We just need to begin an iteration to create the randomization of the sampler. for _ in train_dataloader: break else: # Otherwise we need to call the whooooole sampler cause there is some random operation added # AT THE VERY END! sampler = sampler if sampler is not None else [] _ = list(sampler) total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. if args.past_index >= 0: self._past = None steps_in_epoch = ( len(epoch_iterator) if len_dataloader is not None else args.max_steps * args.gradient_accumulation_steps ) self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False steps_skipped = 0 if steps_trained_in_current_epoch > 0: epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 rng_to_sync = True step = -1 for step, inputs in enumerate(epoch_iterator): total_batched_samples += 1 if rng_to_sync: self._load_rng_state(resume_from_checkpoint) rng_to_sync = False # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 if steps_trained_progress_bar is not None: steps_trained_progress_bar.update(1) if steps_trained_in_current_epoch == 0: self._load_rng_state(resume_from_checkpoint) continue elif steps_trained_progress_bar is not None: steps_trained_progress_bar.close() steps_trained_progress_bar = None if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) if ( args.logging_nan_inf_filter and not is_torch_tpu_xla_available() and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) ): # if loss is nan or inf simply add the average of previous logged losses tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs)) is_last_step_and_steps_less_than_grad_acc = ( steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch ) if ( total_batched_samples % args.gradient_accumulation_steps == 0 or # last step in epoch but step is always smaller than gradient_accumulation_steps is_last_step_and_steps_less_than_grad_acc ): # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered # in accelerate. So, explicitly enable sync gradients to True in that case. if is_last_step_and_steps_less_than_grad_acc or ( version.parse(accelerate_version) <= version.parse("0.20.3") ): self.accelerator.gradient_state._set_sync_gradients(True) # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping if is_sagemaker_mp_enabled() and args.fp16: _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm) elif hasattr(self.optimizer, "clip_grad_norm"): # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping _grad_norm = self.optimizer.clip_grad_norm(args.max_grad_norm) elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping _grad_norm = model.clip_grad_norm_(args.max_grad_norm) else: _grad_norm = self.accelerator.clip_grad_norm_( model.parameters(), args.max_grad_norm, ) if ( is_accelerate_available() and self.accelerator.distributed_type == DistributedType.DEEPSPEED ): grad_norm = model.get_global_grad_norm() else: grad_norm = _grad_norm.item() if _grad_norm is not None else None # Optimizer step self.optimizer.step() optimizer_was_run = not self.accelerator.optimizer_step_was_skipped if optimizer_was_run: # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) if is_transformers_version(">=", "4.47.0"): self._maybe_log_save_evaluate( tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time ) else: self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) if self.control.should_epoch_stop or self.control.should_training_stop: break if step < 0: logger.warning( f"There seems to be not a single sample in your train dataloader, stopping training at step" f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" f" num_steps ({max_steps}) higher than the number of available samples." ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) if is_transformers_version(">=", "4.47.0"): self._maybe_log_save_evaluate( tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time ) else: self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) if DebugOption.TPU_METRICS_DEBUG in self.args.debug: logger.warning( "You enabled PyTorch/XLA debug metrics which is not supported by ONNX " "Runtime. Check your training configuration if this is unexpected." ) if self.control.should_training_stop: break if args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sur the model has been saved by process 0. if args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() self._load_best_model() # add remaining tr_loss self._total_loss_scalar += tr_loss.item() train_loss = self._total_loss_scalar / self.state.global_step metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) self.store_flos() metrics["total_flos"] = self.state.total_flos metrics["train_loss"] = train_loss self.is_in_train = False self._memory_tracker.stop_and_update_metrics(metrics) self.log(metrics) run_dir = self._get_output_dir(trial) checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: for checkpoint in checkpoints_sorted: if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint) self.control = self.callback_handler.on_train_end(args, self.state, self.control) # Wait for the checkpoint to be uploaded. self._finish_current_push() return TrainOutput(self.state.global_step, train_loss, metrics) def _wrap_model(self, model, training=True, dataloader=None): # TODO: ipex only works with inference with PyTorch, will move `inference_with_ort` to training arguments and # whether be able to use ipex will depend on both `self.args.use_ipex` and `self.args.ort_mode_eval`. if self.args.use_ipex: dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 model = self.ipex_optimize_model(model, training, dtype=dtype) if is_sagemaker_mp_enabled(): raise NotImplementedError( "Sagemaker's distrubuted data parallel features are not supported by `ORTTrainer`." ) # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again if unwrap_model(model) is not model: from torch_ort import ORTModule if not isinstance(model, ORTModule): return model # Mixed precision training with apex (torch < 1.6) if self.use_apex and training: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) if self.args.fp16: from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer self.optimizer = FP16_Optimizer(self.optimizer) # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): model = nn.DataParallel(model) if self.args.jit_mode_eval: start_time = time.time() model = self.torch_jit_model_eval(model, dataloader, training) self.jit_compilation_time = round(time.time() - start_time, 4) # Note: in torch.distributed mode, there's no point in wrapping the model # inside a DistributedDataParallel as we'll be under `no_grad` anyways. if not training: return model # Distributed training using PyTorch FSDP if self.is_fsdp_xla_enabled: try: from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP from torch_xla.distributed.fsdp import checkpoint_module from torch_xla.distributed.fsdp.wrap import ( size_based_auto_wrap_policy, transformer_auto_wrap_policy, ) except ImportError: raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") auto_wrap_policy = None auto_wrapper_callable = None default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get( "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap ) if self.args.fsdp_config["min_num_params"] > 0: auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"] ) elif fsdp_transformer_layer_cls_to_wrap is not None: transformer_cls_to_wrap = set() for layer_class in fsdp_transformer_layer_cls_to_wrap: transformer_cls = get_module_class_from_name(model, layer_class) if transformer_cls is None: raise Exception("Could not find the transformer layer class to wrap in the model.") else: transformer_cls_to_wrap.add(transformer_cls) auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, # Transformer layer class to wrap transformer_layer_cls=transformer_cls_to_wrap, ) fsdp_kwargs = self.args.xla_fsdp_config if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: # Apply gradient checkpointing to auto-wrapped sub-modules if specified def auto_wrapper_callable(m, *args, **kwargs): return FSDP(checkpoint_module(m), *args, **kwargs) # Wrap the base model with an outer FSDP wrapper self.model = model = FSDP( model, auto_wrap_policy=auto_wrap_policy, auto_wrapper_callable=auto_wrapper_callable, **fsdp_kwargs, ) # Patch `xm.optimizer_step` should not reduce gradients in this case, # as FSDP does not need gradient reduction over sharded parameters. def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): loss = optimizer.step(**optimizer_args) if barrier: xm.mark_step() return loss xm.optimizer_step = patched_optimizer_step elif is_sagemaker_dp_enabled(): raise NotImplementedError( "Sagemaker's distrubuted data parallel features are not supported by `ORTTrainer` yet." ) elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: kwargs = {} if self.args.ddp_find_unused_parameters is not None: kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing else: kwargs["find_unused_parameters"] = True if self.args.ddp_bucket_cap_mb is not None: kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb if self.args.ddp_broadcast_buffers is not None: kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) return model def create_optimizer(self): """ Setup the optimizer. We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the ORTTrainer's init through `optimizers`, or subclass and override this method in a subclass. """ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model if self.optimizer is None: decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] optimizer_grouped_parameters = [ { "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters], "weight_decay": self.args.weight_decay, }, { "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0, }, ] if self.args.optim in ORTOptimizerNames: optimizer_cls, optimizer_kwargs = ORTTrainer.get_ort_optimizer_cls_and_kwargs(self.args) else: optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes manager = bitsandbytes.optim.GlobalOptimManager.get_instance() skipped = 0 for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") logger.info(f"skipped: {skipped/2**20}M params") if is_sagemaker_mp_enabled(): raise NotImplementedError( "Sagemaker's distributed data parallel features are not supported by `ORTTrainer` yet." ) return self.optimizer @staticmethod def get_ort_optimizer_cls_and_kwargs(args: ORTTrainingArguments) -> Tuple[Any, Any]: """ Returns the optimizer class and optimizer parameters implemented in ONNX Runtime based on `ORTTrainingArguments`. Args: args (`ORTTrainingArguments`): The training arguments for the training session. """ optimizer_kwargs = {"lr": args.learning_rate} adam_kwargs = {"betas": (args.adam_beta1, args.adam_beta2), "eps": args.adam_epsilon} if args.optim == ORTOptimizerNames.ADAMW_ORT_FUSED: try: from onnxruntime.training.optim import FusedAdam optimizer_cls = FusedAdam optimizer_kwargs.update(adam_kwargs) except ImportError: raise ImportError( "ORTTrainer tried to instantiate ORT FusedAdam but onnxruntime-training is not correctly installed!" ) else: raise ValueError(f"ORTTrainer cannot instantiate unsupported optimizer: {args.optim}") return optimizer_cls, optimizer_kwargs def _save(self, output_dir: Optional[str] = None, state_dict=None): # If we are executing this function, we are the process zero, so we don't check for that. output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") supported_classes = (PreTrainedModel,) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, supported_classes): if state_dict is None: state_dict = self.model.state_dict() if isinstance(self.accelerator.unwrap_model(self.model), supported_classes): self.accelerator.unwrap_model(self.model).save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") if self.args.save_safetensors: safetensors.torch.save_model( self.model, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"} ) else: torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: self.model.save_pretrained( output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors ) if self.processing_class is not None: self.processing_class.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))