# optimum/intel/neural_compressor/trainer.py (676 lines of code) (raw):
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import shutil
import sys
import time
from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import hp_params
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
# isort: on
import datasets
import torch
import torch.distributed as dist
from neural_compressor import training
from neural_compressor.compression import DistillationCallbacks
from packaging import version
from torch import nn
from torch.utils.data import Dataset, RandomSampler
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import TRAINER_STATE_NAME
from transformers.trainer_callback import TrainerCallback, TrainerState
from transformers.trainer_pt_utils import get_dataloader_sampler, get_model_param_count
from transformers.trainer_utils import (
EvalPrediction,
HPSearchBackend,
TrainOutput,
has_length,
speed_metrics,
)
from transformers.training_args import ParallelMode, TrainingArguments
from transformers.utils import (
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_sagemaker_mp_enabled,
logging,
)
from ..utils.constant import TRAINING_ARGS_NAME
from ..utils.import_utils import is_neural_compressor_version, is_transformers_version
from .configuration import INCConfig
if is_transformers_version(">=", "4.39.0"):
from transformers.utils import is_torch_xla_available
else:
from transformers.utils import is_torch_tpu_available as is_torch_xla_available
if is_accelerate_available():
from accelerate import __version__ as accelerate_version
from accelerate import skip_first_batches
if version.parse(accelerate_version) > version.parse("0.20.3"):
pass
DATA_SAMPLERS = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
from accelerate.data_loader import SeedableRandomSampler
DATA_SAMPLERS += [SeedableRandomSampler]
if is_deepspeed_available():
pass
if is_apex_available():
from apex import amp
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
if is_neural_compressor_version("<", "2.6"):
from neural_compressor.conf.pythonic_config import _BaseQuantizationConfig
else:
from neural_compressor.config import _BaseQuantizationConfig
__version__ = "4.46.0"
logger = logging.get_logger(__name__)
class INCTrainer(Trainer):
"""
INCTrainer enables Intel Neural Compressor quantization aware training, pruning and distillation.
"""
def __init__(
    self,
    model: Union[PreTrainedModel, torch.nn.Module] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Dataset] = None,
    processing_class: Optional[Union[PreTrainedTokenizerBase, FeatureExtractionMixin]] = None,
    model_init: Callable[[], PreTrainedModel] = None,
    compute_loss_func: Optional[Callable] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    quantization_config: Optional[_BaseQuantizationConfig] = None,
    pruning_config: Optional[_BaseQuantizationConfig] = None,
    distillation_config: Optional[_BaseQuantizationConfig] = None,
    task: Optional[str] = None,
    **kwargs,
):
    """
    Set up the underlying `Trainer`, then attach the Neural Compressor compression state.

    Besides the regular `Trainer` arguments, this constructor accepts:

    Args:
        quantization_config: Optional quantization configuration. When given, the trainer dtype is recorded
            as `"int8"` and its `backend` is copied onto the model config.
        pruning_config: Optional pruning configuration.
        distillation_config: Optional distillation configuration.
        task: Stored as `self.task`.
        **kwargs: Only the legacy `tokenizer` entry is read, as a fallback for `processing_class`.

    Note:
        `compute_loss_func` is accepted but is neither forwarded to `Trainer.__init__` nor stored here.
    """
    # Defined before delegating to `Trainer.__init__`.
    self.neftune_noise_alpha = None

    tokenizer_or_processor = processing_class or kwargs.get("tokenizer", None)
    super().__init__(
        model,
        args,
        data_collator,
        train_dataset,
        eval_dataset,
        tokenizer_or_processor,
        model_init=model_init,
        compute_metrics=compute_metrics,
        callbacks=callbacks,
        optimizers=optimizers,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    training_on_cuda = self.args.device.type == "cuda"
    if training_on_cuda and not is_neural_compressor_version(">", "2.0.0"):
        logger.warning(
            "Neural Compressor version must be > 2.0.0 to train on CUDA devices. "
            "Please upgrade Neural Compressor or train your model on CPU devices instead."
        )

    self.task = task
    self.quantization_config = quantization_config
    self.pruning_config = pruning_config
    self.distillation_config = distillation_config
    self._compression_manager = None
    self.distillation_callback = None
    # TODO : To deprecate once support transformers > 4.30.0
    self.deepspeed = None

    # Attach dtype and architecture to the config
    model_config = self.model.config
    if quantization_config is None:
        # e.g. "torch.float32" -> "float32"
        self.dtype = str(get_parameter_dtype(self.model)).split(".")[1]
        model_config.backend = "default"
    else:
        self.dtype = "int8"
        model_config.backend = quantization_config.backend
    model_config.torch_dtype = self.dtype
    model_config.framework = "pytorch_fx"
    model_config.architectures = [self.model.__class__.__name__]

    self._set_signature_columns_if_needed()

    # Keep only the compression configs that were actually provided, in pruning/distillation/quantization order.
    active_configs = [
        conf for conf in (pruning_config, distillation_config, quantization_config) if conf is not None
    ]
    if active_configs and self.args.do_train:
        # A single config is passed as-is, several configs are passed as a list.
        confs = active_configs if len(active_configs) > 1 else active_configs[0]
        self._compression_manager = training.prepare_compression(self.model, confs=confs)
        self.model = self._compression_manager.model.model
        self.model_wrapped = self.model
        # Remember the first distillation callback (if any) so the loss computation can reach its criterion.
        self.distillation_callback = next(
            (
                callback
                for callback in self._compression_manager.callbacks.callbacks_list
                if isinstance(callback, DistillationCallbacks)
            ),
            None,
        )

    self.inc_config = INCConfig(
        quantization=self.quantization_config,
        distillation=self.distillation_config,
        pruning=self.pruning_config,
    )
def _inner_training_loop(
    self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
):
    """
    Run the training loop, invoking the Neural Compressor callbacks next to the `Trainer` callbacks.

    The Neural Compressor hooks (`on_train_begin`, `on_epoch_begin`, `on_step_begin`,
    `on_before_optimizer_step`, `on_after_optimizer_step`, `on_step_end`, `on_epoch_end`, `on_train_end`)
    are only called when a compression manager was created in `__init__`.

    Args:
        batch_size: Per-step training batch size, stored as `self._train_batch_size`.
        args: Training arguments used throughout the loop (gradient accumulation, max steps, ...).
        resume_from_checkpoint: Optional checkpoint directory to resume the trainer state from.
        trial: Optional hyper-parameter search trial.
        ignore_keys_for_eval: Forwarded to `evaluate` through `_maybe_log_save_evaluate`.

    Returns:
        `TrainOutput` holding the final global step, the average training loss and the metrics dict.

    Raises:
        ValueError: If the dataloader has no length and `args.max_steps` is not positive, or if the
            underflow/overflow debug option is requested with more than one GPU.
    """
    self.accelerator.free_memory()
    self._train_batch_size = batch_size
    if self.args.auto_find_batch_size:
        self.state.train_batch_size = self._train_batch_size
    logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
    # Data loader and number of training steps
    train_dataloader = self.get_train_dataloader()

    # Setting up training control variables:
    # number of training epochs: num_train_epochs
    # number of training steps per epoch: num_update_steps_per_epoch
    # total number of training steps to execute: max_steps
    total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

    len_dataloader = None
    num_train_tokens = None
    if has_length(train_dataloader):
        len_dataloader = len(train_dataloader)
        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        num_examples = self.num_examples(train_dataloader)
        if args.max_steps > 0:
            max_steps = args.max_steps
            num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                args.max_steps % num_update_steps_per_epoch > 0
            )
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
            num_train_epochs = math.ceil(args.num_train_epochs)
            num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
    elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
        max_steps = args.max_steps
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_train_epochs = sys.maxsize
        num_update_steps_per_epoch = max_steps
        num_examples = total_train_batch_size * args.max_steps
        num_train_samples = args.max_steps * total_train_batch_size
    else:
        raise ValueError(
            "args.max_steps must be set to a positive value if dataloader does not have a length, was"
            f" {args.max_steps}"
        )

    if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
        if self.args.n_gpu > 1:
            # nn.DataParallel(model) replicates the model, creating new variables and module
            # references registered here no longer work on other gpus, breaking the module
            raise ValueError(
                "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                " (torch.distributed.launch)."
            )
        else:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

    is_fsdp_xla_enabled = (
        self.is_fsdp_xla_enabled if is_transformers_version(">=", "4.36.0") else self.fsdp is not None
    )
    delay_optimizer_creation = is_sagemaker_mp_enabled() or is_fsdp_xla_enabled or self.is_fsdp_enabled

    if self.is_deepspeed_enabled:
        self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

    if not delay_optimizer_creation:
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    if is_transformers_version(">=", "4.44.99"):
        from transformers.trainer_callback import ExportableState

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
    else:
        self.state = TrainerState()

    self.state.is_hyper_param_search = trial is not None
    self.state.train_batch_size = self._train_batch_size

    # Compute absolute values for logging, eval, and save if given as ratio
    if args.logging_steps is not None:
        if args.logging_steps < 1:
            self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
        else:
            self.state.logging_steps = args.logging_steps
    if args.eval_steps is not None:
        if args.eval_steps < 1:
            self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
        else:
            self.state.eval_steps = args.eval_steps
    if args.save_steps is not None:
        if args.save_steps < 1:
            self.state.save_steps = math.ceil(max_steps * args.save_steps)
        else:
            self.state.save_steps = args.save_steps

    # Activate gradient checkpointing if needed
    if args.gradient_checkpointing:
        if args.gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {}
        else:
            gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    model = self._wrap_model(self.model_wrapped)

    # as the model is wrapped, don't use `accelerator.prepare`
    # this is for unhandled cases such as
    # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
    use_accelerator_prepare = True if model is self.model else False

    if delay_optimizer_creation:
        if is_transformers_version("<", "4.36.0") and use_accelerator_prepare:
            self.model = self.accelerator.prepare(self.model)
        self.create_optimizer_and_scheduler(num_training_steps=max_steps)

    # prepare using `accelerator` prepare
    if use_accelerator_prepare:
        self.model.train()
        if hasattr(self.lr_scheduler, "step"):
            if self.use_apex:
                model = self.accelerator.prepare(self.model)
            else:
                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        else:
            # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
            model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                self.model, self.optimizer, self.lr_scheduler
            )

    if self.is_fsdp_enabled:
        self.model = self.model_wrapped = model

    # for the rest of this function `model` is the outside model, whether it was wrapped or not
    if model is not self.model:
        self.model_wrapped = model

    # backward compatibility
    if self.is_deepspeed_enabled:
        self.deepspeed = self.model_wrapped

    # ckpt loading
    if resume_from_checkpoint is not None:
        if self.is_deepspeed_enabled:
            deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
        elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
            self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

    # Check if saved optimizer or scheduler states exist
    self._load_optimizer_and_scheduler(resume_from_checkpoint)

    # important: at this point:
    # self.model         is the Transformers Model
    # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
    # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {num_examples:,}")
    logger.info(f" Num Epochs = {num_train_epochs:,}")
    logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
    if self.args.per_device_train_batch_size != self._train_batch_size:
        logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_steps:,}")
    logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

    self.state.epoch = 0
    start_time = time.time()
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    steps_trained_progress_bar = None

    # Check if continuing training from a checkpoint
    if resume_from_checkpoint is not None and os.path.isfile(
        os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
    ):
        self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
        epochs_trained = self.state.global_step // num_update_steps_per_epoch
        if not args.ignore_data_skip:
            steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
            steps_trained_in_current_epoch *= args.gradient_accumulation_steps
        else:
            steps_trained_in_current_epoch = 0

        logger.info(" Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f" Continuing training from epoch {epochs_trained}")
        logger.info(f" Continuing training from global step {self.state.global_step}")
        if not args.ignore_data_skip:
            logger.info(
                f" Will skip the first {epochs_trained} epochs then the first"
                f" {steps_trained_in_current_epoch} batches in the first epoch."
            )

    # Update the references
    self.callback_handler.model = self.model
    self.callback_handler.optimizer = self.optimizer
    self.callback_handler.lr_scheduler = self.lr_scheduler
    self.callback_handler.train_dataloader = train_dataloader
    if self.hp_name is not None and self._trial is not None:
        # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
        # parameter to Train when using DDP.
        self.state.trial_name = self.hp_name(self._trial)
    if trial is not None:
        assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
        self.state.trial_params = hp_params(assignments)
    else:
        self.state.trial_params = None
    # This should be the same if the state has been saved but in case the training arguments changed, it's safer
    # to set this after the load.
    self.state.max_steps = max_steps
    self.state.num_train_epochs = num_train_epochs
    self.state.is_local_process_zero = self.is_local_process_zero()
    self.state.is_world_process_zero = self.is_world_process_zero()

    # tr_loss is a tensor to avoid synchronization of TPUs through .item()
    tr_loss = torch.tensor(0.0).to(args.device)
    # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
    self._total_loss_scalar = 0.0
    self._globalstep_last_logged = self.state.global_step
    model.zero_grad()

    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

    if self._compression_manager is not None:
        self._compression_manager.callbacks.on_train_begin()

    # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
    if not args.ignore_data_skip:
        for epoch in range(epochs_trained):
            sampler = get_dataloader_sampler(train_dataloader)
            sampler_kinds = [RandomSampler]
            if version.parse(accelerate_version) > version.parse("0.23.0"):
                sampler_kinds.append(SeedableRandomSampler)
            is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
            if not is_random_sampler:
                # We just need to begin an iteration to create the randomization of the sampler.
                for _ in train_dataloader:
                    break
            else:
                # Otherwise we need to call the whooooole sampler cause there is some random operation added
                # AT THE VERY END!
                sampler = sampler if sampler is not None else []
                _ = list(sampler)

    total_batched_samples = 0
    for epoch in range(epochs_trained, num_train_epochs):
        epoch_iterator = train_dataloader
        if hasattr(epoch_iterator, "set_epoch"):
            epoch_iterator.set_epoch(epoch)

        # Reset the past mems state at the beginning of each epoch if necessary.
        if args.past_index >= 0:
            self._past = None

        steps_in_epoch = (
            len(epoch_iterator)
            if len_dataloader is not None
            else args.max_steps * args.gradient_accumulation_steps
        )
        self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

        if self._compression_manager is not None:
            self._compression_manager.callbacks.on_epoch_begin(epoch)

        if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
            self._load_rng_state(resume_from_checkpoint)

        rng_to_sync = False
        steps_skipped = 0
        if steps_trained_in_current_epoch > 0:
            epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
            steps_skipped = steps_trained_in_current_epoch
            steps_trained_in_current_epoch = 0
            rng_to_sync = True

        # Stays at -1 when the iterator yields nothing, which is detected after the loop.
        step = -1
        for step, inputs in enumerate(epoch_iterator):
            total_batched_samples += 1

            if is_transformers_version(">=", "4.36.0") and self.args.include_num_input_tokens_seen:
                main_input_name = getattr(self.model, "main_input_name", "input_ids")
                if main_input_name not in inputs:
                    logger.warning(
                        "Tried to track the number of tokens seen, however the current model is "
                        "not configured properly to know what item is the input. To fix this, add "
                        "a `main_input_name` attribute to the model class you are using."
                    )
                else:
                    self.state.num_input_tokens_seen += self.accelerator.gather(inputs[main_input_name]).numel()

            if rng_to_sync:
                self._load_rng_state(resume_from_checkpoint)
                rng_to_sync = False

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                if steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.update(1)
                if steps_trained_in_current_epoch == 0:
                    self._load_rng_state(resume_from_checkpoint)
                continue
            elif steps_trained_progress_bar is not None:
                steps_trained_progress_bar.close()
                steps_trained_progress_bar = None

            if step % args.gradient_accumulation_steps == 0:
                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                if self._compression_manager is not None:
                    self._compression_manager.callbacks.on_step_begin(step)

            with self.accelerator.accumulate(model):
                tr_loss_step = self.training_step(model, inputs)

            if (
                args.logging_nan_inf_filter
                and not is_torch_xla_available()
                and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
            ):
                # if loss is nan or inf simply add the average of previous logged losses
                tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
            else:
                tr_loss += tr_loss_step

            self.current_flos += float(self.floating_point_ops(inputs))

            is_last_step_and_steps_less_than_grad_acc = (
                steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
            )

            if (
                total_batched_samples % args.gradient_accumulation_steps == 0
                or
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                is_last_step_and_steps_less_than_grad_acc
            ):
                # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                # in accelerate. So, explicitly enable sync gradients to True in that case.
                if is_last_step_and_steps_less_than_grad_acc:
                    self.accelerator.gradient_state._set_sync_gradients(True)

                # Gradient clipping
                if args.max_grad_norm is not None and args.max_grad_norm > 0:
                    # deepspeed does its own clipping
                    if is_sagemaker_mp_enabled() and args.fp16:
                        self.optimizer.clip_master_grads(args.max_grad_norm)
                    elif self.use_apex:
                        # Revert to normal clipping otherwise, handling Apex or full precision
                        nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            args.max_grad_norm,
                        )
                    else:
                        self.accelerator.clip_grad_norm_(
                            model.parameters(),
                            args.max_grad_norm,
                        )

                if self._compression_manager is not None:
                    self._compression_manager.callbacks.on_before_optimizer_step()

                # Optimizer step
                self.optimizer.step()

                if self._compression_manager is not None:
                    self._compression_manager.callbacks.on_after_optimizer_step()

                optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                if optimizer_was_run:
                    # Delay optimizer scheduling until metrics are generated
                    if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.lr_scheduler.step()

                model.zero_grad()
                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                if self._compression_manager is not None:
                    self._compression_manager.callbacks.on_step_end()

                self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
            else:
                self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break
        if step < 0:
            logger.warning(
                "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                f" num_steps ({max_steps}) higher than the number of available samples."
            )
            self.control.should_training_stop = True

        self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
        if self._compression_manager is not None:
            self._compression_manager.callbacks.on_epoch_end()

        self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)

        if self.control.should_training_stop:
            break

    # Same `>= 0` test as the per-epoch reset above: a bare truthiness check skips the cleanup for
    # `past_index == 0`.
    if args.past_index >= 0 and hasattr(self, "_past"):
        # Clean the state at the end of training
        delattr(self, "_past")

    logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
    if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
        # Wait for everyone to get here so we are sure the model has been saved by process 0.
        if is_torch_xla_available():
            xm.rendezvous("load_best_model_at_end")
        elif args.parallel_mode == ParallelMode.DISTRIBUTED:
            dist.barrier()
        elif is_sagemaker_mp_enabled():
            smp.barrier()

        self._load_best_model()

    # add remaining tr_loss
    self._total_loss_scalar += tr_loss.item()
    # No optimizer step may have run (e.g. the empty-iterator case handled above), so guard the division.
    effective_global_step = max(self.state.global_step, 0.001)
    train_loss = self._total_loss_scalar / effective_global_step

    metrics = speed_metrics(
        "train",
        start_time,
        num_samples=num_train_samples,
        num_steps=self.state.max_steps,
        num_tokens=num_train_tokens,
    )
    self.store_flos()
    metrics["total_flos"] = self.state.total_flos
    metrics["train_loss"] = train_loss

    self.is_in_train = False

    self._memory_tracker.stop_and_update_metrics(metrics)

    self.log(metrics)

    run_dir = self._get_output_dir(trial)
    checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

    # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
    if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
        for checkpoint in checkpoints_sorted:
            if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                shutil.rmtree(checkpoint)

    self.control = self.callback_handler.on_train_end(args, self.state, self.control)

    # Wait for the checkpoint to be uploaded.
    self._finish_current_push()

    # After training we make sure to retrieve back the original forward pass method
    # for the embedding layer by removing the forward post hook.
    if self.neftune_noise_alpha is not None:
        self._deactivate_neftune(self.model)

    if self._compression_manager is not None:
        self._compression_manager.callbacks.on_train_end()

    return TrainOutput(self.state.global_step, train_loss, metrics)
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """
    Persist the model so that it can be reloaded with `from_pretrained()`.

    Nothing is written unless `self.args.should_save` is set.

    Args:
        output_dir: Destination directory. Falls back to `self.args.output_dir` when `None`.
        _internal_call: Currently unused by this implementation.
    """
    target_dir = self.args.output_dir if output_dir is None else output_dir
    if not self.args.should_save:
        return
    self._save(output_dir=target_dir)
    # TODO: push to hub if self.args.push_to_hub and not _internal_call
def _save(self, output_dir=None, state_dict=None):
    """
    Write the model weights, configs, processing class, training arguments and INC config to `output_dir`.

    Args:
        output_dir: Destination directory. Falls back to `self.args.output_dir` when `None`.
        state_dict: Optional state dict to serialize instead of `self.model.state_dict()`. When the
            compression manager's model exposes `q_config`, it is added to this dict in place under the
            `"best_configure"` key.

    Returns:
        `None`. An error is logged and nothing is written when `output_dir` is an existing file.
    """
    # If we are executing this function, we are the process zero, so we don't check for that.
    output_dir = output_dir if output_dir is not None else self.args.output_dir

    if os.path.isfile(output_dir):
        logger.error(f"Provided path ({output_dir}) should be a directory, not a file")
        return

    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)

    # Save the config
    if self.model.can_generate():
        if is_transformers_version(">=", "4.44.99"):
            misplaced_generation_parameters = self.model.config._get_non_default_generation_parameters()
            if len(misplaced_generation_parameters) > 0:
                logger.warning(
                    "Moving the following attributes in the config to the generation config: "
                    f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
                    "generation parameters in the model config, as opposed to in the generation config.",
                )
                for param_name, param_value in misplaced_generation_parameters.items():
                    setattr(self.model.generation_config, param_name, param_value)
                    setattr(self.model.config, param_name, None)

        self.model.generation_config.save_pretrained(output_dir)

    if self.model.config is not None:
        self.model.config.save_pretrained(output_dir)

    # `__init__` hands the tokenizer/processor to `Trainer` as its processing class. Prefer that attribute
    # when the installed `transformers` exposes it, and only fall back to the legacy `tokenizer` attribute
    # otherwise, so that the deprecated accessor is not touched on recent versions.
    if hasattr(self, "processing_class"):
        processor = self.processing_class
    else:
        processor = getattr(self, "tokenizer", None)
    if processor is not None:
        processor.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    # Save the model
    if state_dict is None:
        state_dict = self.model.state_dict()
    if self._compression_manager is not None and hasattr(self._compression_manager.model, "q_config"):
        state_dict["best_configure"] = self._compression_manager.model.q_config
    torch.save(state_dict, output_model_file)

    if self.pruning_config is not None:
        self.inc_config.pruning["sparsity"] = round(self.get_model_sparsity(), 2)
    self.inc_config.save_pretrained(output_dir)

    logger.info(f"Model weights saved in {output_model_file}")
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
return dataset
self._set_signature_columns_if_needed()
signature_columns = self._signature_columns
signature_columns += ["teacher_logits"]
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
if len(ignored_columns) > 0:
dset_description = "" if description is None else f"in the {description} set"
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
" you can safely ignore this message."
)
[k for k in signature_columns if k in dataset.column_names]
return dataset.remove_columns(ignored_columns)
@staticmethod
def _get_logits(model_outputs):
output_names = ["logits", "start_logits", "end_logits"]
return tuple(model_outputs.get(name) for name in output_names if name in model_outputs)
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    """
    How the loss is computed by Trainer. By default, all models return the loss in the first element.

    When a distillation config and a distillation callback are available, the model loss is combined with
    the distillation loss using the criterion `loss_weights` and normalized by their sum.

    Args:
        model: The model to call with `**inputs`.
        inputs: Mapping of model inputs. `"labels"` (only when a label smoother is set) and
            `"teacher_logits"` are popped from it before the forward pass.
        return_outputs: Whether to return `(loss, outputs)` instead of `loss` only.
        num_items_in_batch: Accepted for interface compatibility; not used here.

    Returns:
        The loss, or a `(loss, outputs)` pair when `return_outputs` is set.

    Raises:
        ValueError: If the model outputs are a dict without a `"loss"` entry and no label smoothing is used.
    """
    if self.label_smoother is not None and "labels" in inputs:
        labels = inputs.pop("labels")
    else:
        labels = None

    teacher_outputs = inputs.pop("teacher_logits", None)

    outputs = model(**inputs)

    # Save past state if it exists
    if self.args.past_index >= 0:
        self._past = outputs[self.args.past_index]

    if labels is not None:
        if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
            loss = self.label_smoother(outputs, labels, shift_labels=True)
        else:
            loss = self.label_smoother(outputs, labels)
    else:
        if isinstance(outputs, dict) and "loss" not in outputs:
            raise ValueError(
                "The model did not return a loss from the inputs, only the following keys: "
                f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
            )
        # We don't use .loss here since the model may return tuples instead of ModelOutput.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

    if self.distillation_config is not None:
        student_outputs = self._get_logits(outputs)
        if teacher_outputs is not None:
            # Pre-computed teacher logits with a size-2 second dimension are split into a pair
            # (presumably start/end logits; verify against the dataset producing `teacher_logits`).
            # NOTE(review): any other tensor is left as-is, so `compute_distillation_loss` zips over its
            # first dimension; confirm this is the intended pairing for single-logit tasks.
            if len(teacher_outputs.shape) == 3 and teacher_outputs.shape[1] == 2:
                teacher_outputs = tuple(teacher_outputs.transpose(1, 0))
        else:
            teacher_model = self.distillation_config.teacher_model
            teacher_model.eval()
            teacher_model.to(model.device)
            # The teacher only provides targets: skip autograd bookkeeping for its forward pass.
            with torch.no_grad():
                teacher_outputs = teacher_model(**inputs)
            teacher_outputs = self._get_logits(teacher_outputs)

        if teacher_outputs is not None and self.distillation_callback is not None:
            distillation_loss = self.compute_distillation_loss(student_outputs, teacher_outputs)
            loss *= self.distillation_callback.criterion.loss_weights[0]
            loss += distillation_loss * self.distillation_callback.criterion.loss_weights[1]
            loss /= sum(self.distillation_callback.criterion.loss_weights)

            if isinstance(outputs, dict):
                outputs["loss"] = loss
            elif isinstance(outputs, tuple):
                # Tuples are immutable: rebuild the outputs with the combined loss in first position.
                outputs = (loss,) + outputs[1:]
            else:
                outputs[0] = loss

    return (loss, outputs) if return_outputs else loss
def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
"""
Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
"""
if isinstance(data, Mapping):
return type(data)({k: self._prepare_input(v) for k, v in data.items()})
elif isinstance(data, (tuple, list)):
return type(data)(self._prepare_input(v) for v in data)
elif isinstance(data, torch.Tensor):
kwargs = {"device": self.model.device}
if self.deepspeed and data.dtype != torch.int64:
# NLP models inputs are int64 and those get adjusted to the right dtype of the
# embedding. Other models such as wav2vec2's inputs are already float and thus
# may need special handling to match the dtypes of the model
kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()})
return data.to(**kwargs)
return data
@staticmethod
def _get_logits(model_outputs):
output_names = ["logits", "start_logits", "end_logits"]
return tuple(model_outputs.get(name) for name in output_names if name in model_outputs)
def compute_distillation_loss(self, student_outputs, teacher_outputs):
    """
    How the distillation loss is computed given the student and teacher outputs.

    Each student/teacher pair is divided by the criterion temperature, passed to the criterion
    `teacher_student_loss_cal`, and the per-pair losses are summed and scaled by `temperature**2`.

    Args:
        student_outputs: Iterable of student outputs.
        teacher_outputs: Iterable of teacher outputs, paired positionally with `student_outputs`.

    Returns:
        The accumulated distillation loss.

    Raises:
        ValueError: If there is no student/teacher pair to compare.
    """
    criterion = self.distillation_callback.criterion
    temperature = criterion.temperature

    distillation_loss = None
    for student_output, teacher_output in zip(student_outputs, teacher_outputs):
        student_output = student_output / temperature
        teacher_output = teacher_output / temperature
        loss = criterion.teacher_student_loss_cal(student_output, teacher_output)
        distillation_loss = loss if distillation_loss is None else distillation_loss + loss

    if distillation_loss is None:
        # Without this guard the scaling below fails with an opaque `NoneType` TypeError.
        raise ValueError(
            "Unable to compute the distillation loss: no student/teacher output pair was found. "
            "Make sure both models return `logits` or `start_logits`/`end_logits`."
        )

    distillation_loss *= temperature**2
    return distillation_loss
def evaluate(
    self,
    eval_dataset: Optional[Dataset] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> Dict[str, float]:
    """
    Adjust the device and mixed-precision settings for compressed models, then run `Trainer.evaluate`.

    When a quantization config is set the model is moved to CPU. For the `ipex` backend, and for int8
    `pytorch`/`pytorch_fx` models, CPU automatic mixed precision is switched off before evaluating.
    """
    if self.quantization_config is not None:
        logger.warning("Evaluation of quantized models is not supported by the CUDA backend.")
        self.model.to("cpu")

    model_config = self.model.config
    if getattr(model_config, "backend", None) == "ipex":
        self.args.use_ipex = False
        self.args.bf16 = False
        self.use_cpu_amp = False

    is_int8 = getattr(model_config, "torch_dtype", None) == "int8"
    if is_int8 and getattr(model_config, "framework", None) in {"pytorch", "pytorch_fx"} and self.use_cpu_amp:
        logger.warning(
            f"{model_config.framework} quantized model doesn't support BFloat16 input, setting `use_cpu_amp` to False."
        )
        self.use_cpu_amp = False

    return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)
def predict(self, *args, **kwargs):
    """
    Move the model to CPU when a quantization config is set, then delegate to `Trainer.predict`.

    All positional and keyword arguments are forwarded unchanged.
    """
    has_quantization = self.quantization_config is not None
    if has_quantization:
        logger.warning("Evaluation of quantized models is not supported by the CUDA backend.")
        self.model.to("cpu")

    return super().predict(*args, **kwargs)
def get_model_sparsity(self):
    """
    Return the last entry of the compression manager's sparsity report, or `0.0` without a manager.
    """
    manager = self._compression_manager
    if manager is None:
        return 0.0
    return manager.model.report_sparsity()[-1]
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
# TODO : can be removed once transformers >= v4.38.0
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step()
logs: Dict[str, float] = {}
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# reset tr_loss to zero
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs)
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)
# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
self.lr_scheduler.step(metrics[metric_to_check])
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)