#  Copyright 2021 The HuggingFace Team. All rights reserved.
#  Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import copy
import functools
import inspect
import math
import os
import random
import re
import shutil
import sys
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# Integrations must be imported before ML frameworks:
import numpy as np
import poptorch
import torch
from huggingface_hub import Repository
from packaging import version
from peft import PeftModel
from poptorch import DataLoaderMode, PoplarExecutor
from poptorch.optim import LAMB, AdamW
from torch import nn, optim
from torch.utils.data import Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow


from transformers.integrations import (  # isort: split
    get_reporting_integration_callbacks,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import is_torch_less_than_1_11
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import OPTIMIZER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    IterableDatasetShard,
    LabelSmoother,
    LengthGroupedSampler,
    find_batch_size,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_truncate,
    reissue_pt_warnings,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    EvalLoopOutput,
    EvalPrediction,
    HubStrategy,
    IntervalStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    TrainerMemoryTracker,
    TrainOutput,
    denumpify_detensorize,
    get_last_checkpoint,
    has_length,
    set_seed,
    speed_metrics,
)
from transformers.utils import (
    CONFIG_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    find_labels,
    get_full_repo_name,
    is_datasets_available,
)

from optimum.graphcore.version import __version__
from optimum.utils import logging

from .data.data_collator import pad_on_batch_axis
from .ipu_configuration import IPU_CONFIG_NAME, IPUConfig
from .modelcard import IPUTrainingSummary
from .modeling_utils import to_pipelined
from .trainer_utils import _WorkerInit
from .training_args import IPUTrainingArguments


if is_datasets_available():
    import datasets


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)

# Hard-wired to `False`: this switches off the `torch.Generator` seeding branch in
# `IPUTrainer._get_train_sampler`.
_is_torch_generator_available: bool = False

# Callbacks every `IPUTrainer` registers, and the progress callback it adds when tqdm is enabled
# (`PrinterCallback` is used instead when `args.disable_tqdm` is set).
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# TODO: Import from transformers.utils when updating transformers version.
ADAPTER_WEIGHTS_NAME: str = "adapter_model.bin"


@dataclass
class IPUTrainerState(TrainerState):
    """
    [`~transformers.TrainerState`] with one extra field used by `IPUTrainer`.
    """

    # -1.0 is the default "unset" value. NOTE(review): presumably filled in with a wall-clock timestamp
    # when training starts; the assignment is not visible in this chunk — confirm.
    start_time: float = -1.0


class IPUTrainer:
    """
    `IPUTrainer` is a simple but feature-complete training and evaluation
      loop on Graphcore IPUs for PyTorch, optimized for 🤗 Transformers.

    Args:
        model ([`transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*):
            The model to train, evaluate or use for predictions. If not provided, a `model_init` function must be passed.

            <Tip>

            [`IPUTrainer`] is optimized to work with the [`transformers.PreTrainedModel`] class provided by the 🤗 Transformers
            library. You can still use your own models defined as `torch.nn.Module` as long as they work in the same way as
            the 🤗 Transformers models.

            </Tip>

        args ([`IPUTrainingArguments`], *optional*):
            The arguments to tweak for training. Will default to a basic
            instance of [`IPUTrainingArguments`] with `output_dir` set to a
            directory named *tmp_trainer* in the current directory if not
            provided.
        data_collator ([`transformers.data.data_collator.DataCollator`], *optional*):
            The function to use to form a batch from a list of elements of
            `train_dataset` or `eval_dataset`. Will default to
            [`transformers.data.default_data_collator`] if no `tokenizer` is
            provided, or an instance of
            [`~transformers.data.DataCollatorWithPadding`] otherwise.
        train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
            The dataset to use for training. If it is a [`~datasets.Dataset`]
            dataset, the columns not accepted by the
            `model.forward()` method are automatically removed.

            Note that if it's a `torch.utils.data.IterableDataset` dataset with
            some randomization and you are training in a distributed fashion,
            your iterable dataset should either use an internal attribute
            `generator` that is a `torch.Generator` object for the randomization that
            must be identical on all processes (and the trainer will manually
            set the seed of this `generator` at each epoch) or have a
            `set_epoch()` method that internally sets the seed of the RNGs used.
        eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]], *optional*):
            The dataset to use for evaluation. If it is a [`~datasets.Dataset`] dataset, the columns not accepted by the
            `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
            dataset prepending the dictionary key to the metric name.
        tokenizer ([`transformers.PreTrainedTokenizerBase`], *optional*):
            The tokenizer used to preprocess the data. If provided, it will be
            used to automatically pad the inputs to the maximum length when
            batching inputs, and it will be saved along the model to make it
            easier to rerun an interrupted training or reuse the fine-tuned
            model.
        model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to [`IPUTrainer.train`] will start
            from a new instance of the model as given by this function.

            The function may have no arguments, or a single argument containing the optuna/Ray Tune/SigOpt trial object, to
            be able to choose different architectures according to hyper parameters (such as layer count, sizes of
            inner layers and dropout probabilities). **Note: this feature is not supported for now.**

        compute_metrics (`Callable[[~transformers.trainer_utils.EvalPrediction], Dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a
            [`~transformers.trainer_utils.EvalPrediction`] and return a dictionary of strings to metric values.
        callbacks (List of [`transformers.trainer_callback.TrainerCallback`], *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed [here](callback).

            If you want to remove one of the default callbacks used, use the [`IPUTrainer.remove_callback`] method.
        optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
            containing the optimizer and the scheduler to use. Will default to an instance of `poptorch.AdamW` on your model
            and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
    """

    from transformers.trainer_pt_utils import log_metrics, metrics_format, save_metrics, save_state

    from .trainer_pt_utils import _get_learning_rate

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        ipu_config: IPUConfig = None,
        args: IPUTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        eval_data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        force_to_pipelined: bool = False,
    ):
        """
        Builds the trainer. Arguments shared with the 🤗 Transformers `Trainer` are documented on the class.

        IPU-specific arguments:
            ipu_config ([`IPUConfig`]):
                The IPU configuration. A deep copy is stored, so the caller's object is never modified. It is
                updated from `args.n_ipu` and `args.ipu_config_overrides` (in that order) and its seed is set to
                `args.seed`.
            eval_data_collator ([`transformers.data.data_collator.DataCollator`], *optional*):
                Collator used for evaluation. Defaults to the training data collator.
            force_to_pipelined (`bool`, *optional*, defaults to `False`):
                Forwarded as `force` to `to_pipelined` when converting `model`.

        Raises:
            `RuntimeError`: If `model` is missing, or if `model_init` is combined with `optimizers`.
            `ValueError`: If `ipu_config` is missing or a data collator is not a plain callable.

        When `args.compile_only` is set, the requested models are compiled and the process exits with `sys.exit(0)`.
        """
        if args is None:
            output_dir = "tmp_trainer"
            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
            args = IPUTrainingArguments(output_dir=output_dir)
        self.args = args
        # Seed must be set before instantiating the model when using model
        set_seed(self.args.seed)
        self.is_in_train = False

        # memory metrics - must set up as early as possible
        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
        self._memory_tracker.start()

        # set the correct log level depending on the node
        log_level = args.get_process_log_level()
        logging.set_verbosity(log_level)

        # force device and distributed setup init explicitly (accessing the property triggers the setup)
        args._setup_devices

        if model is None:
            if model_init is not None:
                raise RuntimeError("`model_init` is not supported by `IPUTrainer` yet")
            else:
                raise RuntimeError("`IPUTrainer` requires either a `model` or `model_init` argument")
        else:
            if model_init is not None:
                warnings.warn(
                    "`IPUTrainer` requires either a `model` or `model_init` argument, but not both. "
                    "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.",
                    FutureWarning,
                )
            self.model_init = model_init

        # TODO: not sure about setting the data_collator?
        default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
        self.data_collator = data_collator if data_collator is not None else default_collator
        # If no eval_data_collator is specified then use the train data_collator
        self.eval_data_collator = eval_data_collator if eval_data_collator is not None else self.data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer

        # Fail with a clear message instead of an `AttributeError` on `None` a few lines below.
        if ipu_config is None:
            raise ValueError("`IPUTrainer` requires an `ipu_config` argument")
        self.ipu_config = copy.deepcopy(ipu_config)
        # set replication factor using n_ipu (can be overruled by ipu_config_overrides)
        if (n_ipu := self.args.n_ipu) is not None:
            if self.ipu_config.replication_factor > 1 or self.ipu_config.inference_replication_factor > 1:
                warnings.warn(
                    "IPUTrainer is overwriting the replication factors set in self.ipu_config because `--n_ipu` was provided."
                )
            self.ipu_config.replication_factor = n_ipu // self.ipu_config.ipus_per_replica
            self.ipu_config.inference_replication_factor = n_ipu // self.ipu_config.inference_ipus_per_replica
        if args.ipu_config_overrides:
            logger.info(f"Overriding IPU config: {args.ipu_config_overrides}")
            self.ipu_config.update_from_string(args.ipu_config_overrides)
        # Decide on tokenizer parallelism only once the replication factors are final, i.e. after both
        # `n_ipu` and `ipu_config_overrides` have been applied.
        if self.ipu_config.replication_factor > 1 or self.ipu_config.inference_replication_factor > 1:
            os.environ["TOKENIZERS_PARALLELISM"] = "true"
        self.ipu_config.seed = self.args.seed
        self.opts = self.ipu_config.to_options(compile_only=args.compile_only)
        self.eval_opts = self.ipu_config.to_options(for_inference=True, compile_only=args.compile_only)

        # If batch axis padding enabled, wrap train/eval data collators with `pad_on_batch_axis` wrapper
        if self.args.pad_on_batch_axis:
            logger.info(
                "Padding on batch axis enabled. Each batch fed to the compiled model during training will have the proper size"
            )
            if self.args.do_train:
                data_collator_wrapper = pad_on_batch_axis(
                    self.args.per_device_train_batch_size * self.ipu_config.batch_size_factor()
                )
                self.data_collator = data_collator_wrapper(self.data_collator)

            if self.args.do_eval:
                data_collator_wrapper = pad_on_batch_axis(
                    self.args.per_device_eval_batch_size * self.ipu_config.batch_size_factor(for_inference=True),
                )
                self.eval_data_collator = data_collator_wrapper(self.eval_data_collator)

        self.model = to_pipelined(model, self.ipu_config, force=force_to_pipelined)
        self.model.parallelize(**self.ipu_config.parallelize_kwargs)

        self.original_model = model

        if not self.args.fp32:
            self.model = self.model.half()

        self.training_model = None
        self.inference_model = None

        self.compute_metrics = compute_metrics
        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
        self.optimizer, self.lr_scheduler = optimizers
        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
            raise RuntimeError(
                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
            )

        # A user-supplied plain PyTorch optimizer is converted and re-pointed at the pipelined model's parameters.
        if self.optimizer is not None and not isinstance(self.optimizer, poptorch.optim.Optimizer):
            self.optimizer = self.pytorch_optimizer_to_poptorch(self.optimizer, model, self.model)

        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(
            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
        self._loggers_initialized = False

        # Create clone of distant repo and output directory if needed
        if self.args.push_to_hub:
            self.init_git_repo()

        if self.args.should_save:
            os.makedirs(self.args.output_dir, exist_ok=True)

        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            raise ValueError("`data_collator` should be a simple callable (function, class with `__call__`).")

        if not callable(self.eval_data_collator) and callable(getattr(self.eval_data_collator, "collate_batch", None)):
            raise ValueError("`eval_data_collator` should be a simple callable (function, class with `__call__`).")

        if args.max_steps > 0:
            logger.info("max_steps is given. It will override any value given in num_train_epochs")

        if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0:
            raise ValueError("train_dataset does not implement __len__. max_steps has to be specified")

        self._signature_columns = None

        # Label smoothing
        if self.args.label_smoothing_factor != 0:
            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
        else:
            self.label_smoother = None

        self.state = IPUTrainerState()
        self.control = TrainerControl()
        # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
        # returned to 0 every time flos need to be logged
        self.current_flos = 0
        self.hp_search_backend = None
        self.use_tune_checkpoints = False
        default_label_names = find_labels(model.__class__)
        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

        # very last
        self._memory_tracker.stop_and_update_metrics()

        # If compile-only then compile and exit
        if args.compile_only:
            logger.info("Called with compile_only=True. Compiling models then exiting.")
            if args.do_train:
                train_dl = self.get_train_dataloader()
                model = self.wrap_model(self.model)
                try:
                    model_inputs = next(iter(train_dl))
                except StopIteration:
                    raise ValueError(
                        "Couldn't get first sample from dataloader, please check for warnings "
                        "during dataloader construction."
                    ) from None
                self.compile_model(model, model_inputs, log=True)
            if args.do_eval:
                # Same thing with _wrap_and_compile_for_evaluation
                eval_dl = self.get_eval_dataloader()
                self._wrap_and_compile_model_for_evaluation(eval_dl, False)
            logger.info("Exiting after compiling models with compile_only=True")
            sys.exit(0)

    def pytorch_optimizer_to_poptorch(
        self,
        optimizer: optim.Optimizer,
        model: Union[PreTrainedModel, nn.Module],
        pipelined_model: Union[PreTrainedModel, nn.Module],
    ) -> poptorch.optim.Optimizer:
        """
        Converts a PyTorch optimizer to a PopTorch optimizer.

        Supported optimizer classes (exact class match): `torch.optim.SGD`, `Adam`, `AdamW` and `RMSprop`.

        Args:
            optimizer (`torch.optim.Optimizer`):
                The PyTorch optimizer to convert.
            model ([`transformers.PreTrainedModel`] or `torch.nn.Module`):
                The original model the optimizer has parameter references to.
            pipelined_model ([`transformers.PreTrainedModel`] or `torch.nn.Module`):
                The pipelined version of the model. Its parameters will be used by the PopTorch optimizer.

        Returns:
            `poptorch.optim.Optimizer`: The converted PopTorch optimizer.

        Raises:
            `KeyError`: If the optimizer class has no PopTorch counterpart, or if the optimizer references a
                parameter that cannot be found in `model` / `pipelined_model`.
        """
        first_order_type = torch.float32 if self.args.fp32 else torch.float16
        optimizer_kwargs = {
            "loss_scaling": self.args.loss_scaling,
            "accum_type": first_order_type,
            "first_order_momentum_accum_type": first_order_type,
            "second_order_momentum_accum_type": torch.float32,
        }
        # TODO: gradient clipping is disabled because `max_grad_norm` makes things fail, fix it.
        # `None` is passed directly rather than temporarily clearing and restoring `self.args.max_grad_norm`.
        max_grad_norm = None
        pytorch_to_poptorch_mapping = {
            optim.SGD: (poptorch.optim.SGD, {"loss_scaling": self.args.loss_scaling}),
            optim.Adam: (poptorch.optim.Adam, {"max_grad_norm": max_grad_norm, **optimizer_kwargs}),
            optim.AdamW: (poptorch.optim.AdamW, {"max_grad_norm": max_grad_norm, **optimizer_kwargs}),
            optim.RMSprop: (poptorch.optim.RMSprop, optimizer_kwargs),
        }
        poptorch_optimizer_cls, kwargs = pytorch_to_poptorch_mapping.get(optimizer.__class__, (None, {}))
        if poptorch_optimizer_cls is None:
            raise KeyError(f"Could not find a PopTorch counterpart for optimizer {optimizer.__class__.__name__}")

        # Some dummy value that should be overridden by the real value with .load_state_dict, using some absurd value to
        # make clear if the value is not properly overridden.
        dummy_lr = 1e4
        poptorch_optimizer = poptorch_optimizer_cls(optimizer.param_groups, lr=dummy_lr, **kwargs)
        poptorch_optimizer.load_state_dict({"ipu_state": None, "ipu_param": None, **optimizer.state_dict()})

        # Currently poptorch_optimizer contains references to the original model parameters, so we need to change those
        # to references to the pipelined model parameters (matched by parameter name).
        id2name = {id(param): name for name, param in model.named_parameters()}
        name2param = dict(pipelined_model.named_parameters())
        for group in poptorch_optimizer.param_groups:
            for idx, param in enumerate(group["params"]):
                name = id2name.get(id(param))
                if name is None:
                    raise KeyError(
                        "The optimizer references a parameter that does not belong to the model it was converted for."
                    )
                if name not in name2param:
                    raise KeyError(f"Parameter {name!r} was not found in the pipelined model.")
                group["params"][idx] = name2param[name]

        return poptorch_optimizer

    def compile_model(
        self,
        model: poptorch.PoplarExecutor,
        sample_batch: Union[Dict[str, torch.Tensor], Tuple[torch.Tensor]],
        log: bool = False,
    ):
        """
        Compiles an already-wrapped PopTorch model, using `sample_batch` to fix its input shapes.

        Nothing happens when `model.isCompiled()` already reports `True`.

        Args:
            model (`poptorch.PoplarExecutor`):
                The wrapped model to compile.
            sample_batch (`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`):
                Example inputs, passed through `self._prepare_inputs`. A tuple is forwarded positionally, anything
                else is unpacked as keyword arguments.
            log (`bool`, *optional*, defaults to `False`):
                When `True`, logs the start of compilation and how long it took.
        """
        if not model.isCompiled():
            if log:
                logger.info("Compiling Model...")
            prepared = self._prepare_inputs(sample_batch)
            started_at = time.perf_counter()
            if isinstance(prepared, tuple):
                positional, keyword = prepared, {}
            else:
                positional, keyword = (), prepared
            model.compile(*positional, **keyword)
            elapsed = time.perf_counter() - started_at
            if log:
                logger.info(f"Compiled/Loaded model in {elapsed} secs")

    def add_callback(self, callback):
        """
        Adds a callback to the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of [`~transformers.TrainerCallback`]. In the
                first case, will instantiate a member of that class.
        """
        self.callback_handler.add_callback(callback)

    def pop_callback(self, callback):
        """
        Removes a callback from the current list of [`~transformers.TrainerCallback`] and returns it.

        If the callback is not found, returns `None` (and no error is raised).

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of [`~transformers.TrainerCallback`]. In the
                first case, will pop the first member of that class found in the list of callbacks.

        Returns:
            [`~transformers.TrainerCallback`]: The removed callback, if found.
        """
        return self.callback_handler.pop_callback(callback)

    def remove_callback(self, callback):
        """
        Removes a callback from the current list of [`~transformers.TrainerCallback`].

        Args:
            callback (`type` or [`~transformers.TrainerCallback`]):
                A [`~transformers.TrainerCallback`] class or an instance of [`~transformers.TrainerCallback`]. In the
                first case, will remove the first member of that class found in the list of callbacks.
        """
        self.callback_handler.remove_callback(callback)

    def _set_signature_columns_if_needed(self):
        """
        Lazily fills `self._signature_columns` with the column names the model can consume.

        These are the parameter names of `self.model.forward`, followed by the possible label columns
        (`label`, `label_ids` and `self.label_names`). The value is cached, so later calls do nothing.
        """
        if self._signature_columns is None:
            # Inspect model forward signature to keep only the arguments it accepts.
            signature = inspect.signature(self.model.forward)
            # Labels may be named label or label_ids, the default data collator handles that.
            candidates = [*signature.parameters, "label", "label_ids", *self.label_names]
            # `dict.fromkeys` de-duplicates while keeping first-seen order. A `set` would give an order that
            # changes between processes (string hashing is randomized), and that order leaks into the column
            # selection of `_remove_unused_columns`.
            self._signature_columns = list(dict.fromkeys(candidates))

    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
        """
        Drops the columns of a 🤗 Datasets `dataset` that `self.model.forward` does not accept.

        Args:
            dataset (`datasets.Dataset`):
                The dataset to filter.
            description (`str`, *optional*):
                Name of the split (e.g. `"training"`), only used in the log message.

        Returns:
            The dataset unchanged if `args.remove_unused_columns` is disabled. Otherwise, the same dataset with a
            restricted format for `datasets < 1.4.0`, or a new dataset returned by `remove_columns`.
        """
        if not self.args.remove_unused_columns:
            return dataset
        self._set_signature_columns_if_needed()
        signature_columns = self._signature_columns
        accepted = set(signature_columns)

        # Keep the dataset's own column order so the log message and the removal are deterministic.
        ignored_columns = [k for k in dataset.column_names if k not in accepted]
        if len(ignored_columns) > 0:
            dset_description = "" if description is None else f"in the {description} set "
            logger.info(
                f"The following columns {dset_description} don't have a corresponding argument in "
                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
            )

        if version.parse(datasets.__version__) < version.parse("1.4.0"):
            # Old `datasets` releases lack `remove_columns`: hide the unused columns through the format instead.
            columns = [k for k in signature_columns if k in dataset.column_names]
            dataset.set_format(
                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
            )
            return dataset
        else:
            return dataset.remove_columns(ignored_columns)

    def _get_collator_with_removed_columns(
        self, data_collator: Callable, description: Optional[str] = None
    ) -> Callable:
        """
        Returns a collator that drops features not accepted by `self.model.forward` before collating.

        When `args.remove_unused_columns` is disabled, `data_collator` is returned as-is. Otherwise it is wrapped in a
        `RemoveColumnsCollator` configured with the model's signature columns; `description` is passed through for
        its log message.
        """
        if self.args.remove_unused_columns:
            self._set_signature_columns_if_needed()
            return RemoveColumnsCollator(
                data_collator=data_collator,
                signature_columns=self._signature_columns,
                logger=logger,
                description=description,
                model_name=self.model.__class__.__name__,
            )
        return data_collator

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        """
        Builds the sampler for the training dataloader.

        Returns:
            `None` if `train_dataset` does not implement `__len__`. Otherwise a `LengthGroupedSampler` when
            `args.group_by_length` is set, and a `RandomSampler` in all other cases. Both samplers receive the same
            optional `torch.Generator` (only created when `_is_torch_generator_available` is true).
        """
        if not isinstance(self.train_dataset, collections.abc.Sized):
            return None
        generator = None
        if _is_torch_generator_available:
            generator = torch.Generator()
            # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
            # `args.seed`) if data_seed isn't provided.
            if self.args.data_seed is None:
                seed = int(torch.empty((), dtype=torch.int64).random_().item())
            else:
                seed = self.args.data_seed
            generator.manual_seed(seed)

        # Build the sampler.
        if self.args.group_by_length:
            combined_batch_size = self.args.per_device_train_batch_size * self.ipu_config.batch_size_factor()
            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
                lengths = (
                    self.train_dataset[self.args.length_column_name]
                    if self.args.length_column_name in self.train_dataset.column_names
                    else None
                )
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None

            return LengthGroupedSampler(
                combined_batch_size,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
                generator=generator,
            )

        # Pass the same (possibly `None`) generator so random sampling is seeded like the length-grouped path.
        return RandomSampler(self.train_dataset, generator=generator)

    def _check_dataset_can_fill_batch(self, dataset: torch.utils.data.Dataset, for_inference: bool = False) -> None:
        """
        Logs a warning if `dataset` is shorter than one global batch.

        The global batch size is the per-device batch size times the replication factor, the gradient accumulation
        steps (fixed to 1 for inference) and the device iterations of the selected mode. Datasets whose length
        cannot be obtained are not checked.
        """
        cfg = self.ipu_config
        if for_inference:
            replication_factor = cfg.inference_replication_factor
            gradient_accumulation_steps = 1
            device_iterations = cfg.inference_device_iterations
            micro_batch_size = self.args.per_device_eval_batch_size
        else:
            replication_factor = cfg.replication_factor
            gradient_accumulation_steps = cfg.gradient_accumulation_steps
            device_iterations = cfg.device_iterations
            micro_batch_size = self.args.per_device_train_batch_size
        global_batch_size = micro_batch_size * replication_factor * gradient_accumulation_steps * device_iterations

        try:
            dataset_length = len(dataset)
        except Exception:
            # Best effort: without a usable length there is nothing to compare against.
            return
        if dataset_length >= global_batch_size:
            return

        # NB: the `{name=}` f-string specifiers below print the local variable names, so they must not be renamed.
        mode_str = "inference_" if for_inference else ""
        logger.warning(
            f"The provided dataset is of length {dataset_length}, but the total dataset batch size is {global_batch_size}. "
            f"This batch size is calculated as:\n"
            f"  per_device_{'eval' if for_inference else 'train'}_batch_size={micro_batch_size}\n"
            f"* {mode_str}{replication_factor=}\n"
            f"* {mode_str}{gradient_accumulation_steps=}\n"
            f"* {mode_str}{device_iterations=}\n"
            "Please disregard this warning if you believe the dataset is reporting an incorrect length, such as 1."
        )

    def get_train_dataloader(self) -> poptorch.DataLoader:
        """
        Returns the training `poptorch.DataLoader`.

        Will not use a sampler if `train_dataset` is a `torch.utils.data.IterableDataset` or does not implement
        `__len__`. Otherwise it uses a random (or length-grouped) sampler, and PopTorch's
        `auto_distributed_partitioning` is enabled.

        When `args.remove_unused_columns` is set, unused columns are dropped from a [`~datasets.Dataset`] directly.
        For any other dataset, the data collator is wrapped so that it drops them.

        Subclass and override this method if you want to inject some custom behavior.

        Raises:
            `ValueError`: If the trainer has no `train_dataset`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        poptorch_specific_kwargs = {
            "auto_distributed_partitioning": not isinstance(train_dataset, torch.utils.data.IterableDataset),
            "mode": self.args.dataloader_mode,
            "worker_init_fn": _WorkerInit(123),
        }

        if isinstance(train_dataset, torch.utils.data.IterableDataset):
            # NOTE(review): this branch uses `args.train_batch_size` while the map-style branch below uses
            # `args.per_device_train_batch_size`; confirm the difference is intended.
            return poptorch.DataLoader(
                self.opts,
                train_dataset,
                batch_size=self.args.train_batch_size,
                # Use the collator prepared above (it may drop unused columns), not `self.data_collator`.
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                drop_last=self.args.dataloader_drop_last,
                pin_memory=self.args.dataloader_pin_memory,
                **poptorch_specific_kwargs,
            )

        train_sampler = self._get_train_sampler()
        combined_batch_size = self.args.per_device_train_batch_size * self.ipu_config.batch_size_factor()
        rebatched_worker_size = (
            2 * (combined_batch_size // self.args.dataloader_num_workers)
            if self.args.dataloader_num_workers
            else combined_batch_size
        )

        self._check_dataset_can_fill_batch(train_dataset, for_inference=False)

        return poptorch.DataLoader(
            self.opts,
            train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            sampler=train_sampler,
            # Use the collator prepared above (it may drop unused columns), not `self.data_collator`.
            collate_fn=data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            rebatched_worker_size=rebatched_worker_size,
            **poptorch_specific_kwargs,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
        return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> poptorch.DataLoader:
        """
        Returns the evaluation `poptorch.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`] dataset, the columns not accepted
                by the `model.forward()` method are automatically removed. If it is a `torch.utils.data.IterableDataset`,
                no sampler is used and PopTorch's automatic distributed partitioning is disabled.

        Raises:
            ValueError: if neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        data_collator = self.eval_data_collator

        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

        # Decide on partitioning only once the dataset actually used is known: when the caller passes no
        # `eval_dataset`, the type of `self.eval_dataset` is what matters.
        is_iterable = isinstance(eval_dataset, torch.utils.data.IterableDataset)
        poptorch_specific_kwargs = {
            "auto_distributed_partitioning": not is_iterable,
            "mode": DataLoaderMode.Sync,
            "worker_init_fn": _WorkerInit(123),
        }

        if is_iterable:
            return poptorch.DataLoader(
                self.eval_opts,
                eval_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
                **poptorch_specific_kwargs,
            )

        eval_sampler = self._get_eval_sampler(eval_dataset)

        self._check_dataset_can_fill_batch(eval_dataset, for_inference=True)

        return poptorch.DataLoader(
            self.eval_opts,
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            **poptorch_specific_kwargs,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> poptorch.DataLoader:
        """
        Returns the test `poptorch.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (`torch.utils.data.Dataset`):
                The test dataset to use. If it is a [`~datasets.Dataset`] dataset, the columns not accepted by the
                `model.forward()` method are automatically removed. If it is a `torch.utils.data.IterableDataset`,
                no sampler is used and PopTorch's automatic distributed partitioning is disabled.
        """
        data_collator = self.eval_data_collator
        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            test_dataset = self._remove_unused_columns(test_dataset, description="test")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

        is_iterable = isinstance(test_dataset, torch.utils.data.IterableDataset)
        poptorch_specific_kwargs = {
            "auto_distributed_partitioning": not is_iterable,
            "mode": DataLoaderMode.Sync,
            "worker_init_fn": _WorkerInit(123),
        }

        if is_iterable:
            return poptorch.DataLoader(
                self.eval_opts,
                test_dataset,
                batch_size=self.args.per_device_eval_batch_size,
                collate_fn=data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
                **poptorch_specific_kwargs,
            )

        test_sampler = self._get_eval_sampler(test_dataset)

        self._check_dataset_can_fill_batch(test_dataset, for_inference=True)

        # We use the same batch_size and worker count as for eval.
        return poptorch.DataLoader(
            self.eval_opts,
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
            **poptorch_specific_kwargs,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Builds the optimizer, then the learning rate scheduler on top of it.

        To use something other than the defaults, pass a tuple through the trainer's `optimizers` init argument,
        or override this method (or `create_optimizer` / `create_scheduler`) in a subclass.

        Args:
            num_training_steps (`int`): Total number of optimization steps, forwarded to the scheduler.
        """
        self.create_optimizer()
        # `create_optimizer` populates `self.optimizer`, so it must only be read after that call.
        built_optimizer = self.optimizer
        self.create_scheduler(optimizer=built_optimizer, num_training_steps=num_training_steps)

    def create_optimizer(self):
        """
        Sets up the PopTorch optimizer unless one is already set, and returns it.

        Parameter names returned by `get_parameter_names(self.model, [nn.LayerNorm])` that do not contain "bias"
        receive `args.weight_decay`; everything else gets no weight decay. With `args.lamb` or
        `args.lamb_no_bias_correction` a `LAMB` optimizer is built, with bias parameters in a dedicated group
        (`weight_decay=0.0`, `max_weight_norm=0.0`). Otherwise an `AdamW` optimizer is built. Only parameters with
        `requires_grad` are included.

        To use something else, pass a tuple through the trainer's `optimizers` init argument, or override this
        method in a subclass.

        Returns:
            The trainer's optimizer (`self.optimizer`).
        """
        if self.optimizer is not None:
            return self.optimizer

        use_lamb = self.args.lamb or self.args.lamb_no_bias_correction
        decay_names = {name for name in get_parameter_names(self.model, [nn.LayerNorm]) if "bias" not in name}
        named_params = list(self.model.named_parameters())

        def trainable(keep):
            # Trainable parameters (in `named_parameters()` order) whose name satisfies `keep`.
            return [param for name, param in named_params if keep(name) and param.requires_grad]

        if use_lamb:
            bias_names = {name for name, _ in named_params if "bias" in name}
            param_groups = [
                {"params": trainable(lambda name: name in decay_names), "weight_decay": self.args.weight_decay},
                {
                    # Disable LAMB updates for bias parameters
                    "params": trainable(lambda name: name in bias_names),
                    "weight_decay": 0.0,
                    "max_weight_norm": 0.0,
                },
                {
                    "params": trainable(lambda name: name not in decay_names and name not in bias_names),
                    "weight_decay": 0.0,
                },
            ]
            optimizer_cls = LAMB
            optimizer_kwargs = {
                "max_weight_norm": None,
                "bias_correction": not self.args.lamb_no_bias_correction,
                "eps": self.args.adam_epsilon,
            }
        else:
            param_groups = [
                {"params": trainable(lambda name: name in decay_names), "weight_decay": self.args.weight_decay},
                {"params": trainable(lambda name: name not in decay_names), "weight_decay": 0.0},
            ]
            optimizer_cls = AdamW
            optimizer_kwargs = {
                # TODO: disabled max_grad_norm because it make things fail, fix it.
                #  "max_grad_norm": self.args.max_grad_norm,
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
                "bias_correction": False,
            }

        # Accumulator and first-order momentum use FP32 only when `args.fp32` is set; second-order is always FP32.
        accum_dtype = torch.float32 if self.args.fp32 else torch.float16
        optimizer_kwargs.update(
            lr=self.args.learning_rate,
            loss_scaling=self.args.loss_scaling,
            accum_type=accum_dtype,
            first_order_momentum_accum_type=accum_dtype,
            second_order_momentum_accum_type=torch.float32,
        )

        self.optimizer = optimizer_cls(param_groups, **optimizer_kwargs)

        # Mark these attributes as constant for PopTorch (presumably so they are not treated as runtime-updatable
        # optimizer state — confirm against the poptorch.optim docs).
        if use_lamb:
            self.optimizer.variable_attrs.markAsConstant("max_weight_norm")
        self.optimizer.variable_attrs.markAsConstant("weight_decay")

        return self.optimizer

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """
        Sets up the learning rate scheduler and returns it.

        The optimizer must either already be set on the trainer or be passed in explicitly.

        - If no scheduler exists yet, one of type `args.lr_scheduler_type` is built with `get_scheduler`.
        - If `self.lr_scheduler` is a `functools.partial`, it is treated as a factory and called with the optimizer.
        - Any other existing scheduler is returned unchanged.

        Args:
            num_training_steps (int): The number of training steps to execute.
            optimizer (`torch.optim.Optimizer`, *optional*): Overrides `self.optimizer` when given.

        Returns:
            `self.lr_scheduler`.
        """
        target_optimizer = optimizer if optimizer is not None else self.optimizer

        if self.lr_scheduler is not None:
            if isinstance(self.lr_scheduler, functools.partial):
                self.lr_scheduler = self.lr_scheduler(target_optimizer)
            return self.lr_scheduler

        self.lr_scheduler = get_scheduler(
            self.args.lr_scheduler_type,
            optimizer=target_optimizer,
            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            num_training_steps=num_training_steps,
        )
        # Pretend one optimizer step already happened. Presumably this silences PyTorch's "lr_scheduler.step()
        # before optimizer.step()" warning, since `optimizer.step()` is never called on the host here — confirm.
        target_optimizer._step_count = 1
        return self.lr_scheduler

    def num_examples(self, dataloader: poptorch.DataLoader) -> int:
        """
        Returns the number of samples in a `poptorch.DataLoader` object by accessing its dataset.

        When `dataloader.dataset` does not exist or has no length, returns an estimate instead: the number of
        batches times the number of samples per batch.

        Raises:
            TypeError: if neither the dataset nor the dataloader has a length.
        """
        try:
            return len(dataloader.dataset)
        except (AttributeError, TypeError):
            # No dataset, or a dataset without `__len__`: estimate from the number of batches. For a
            # `poptorch.DataLoader`, `combinedBatchSize` is the number of samples consumed per model execution;
            # plain torch dataloaders only expose `batch_size`.
            samples_per_batch = getattr(dataloader, "combinedBatchSize", None) or dataloader.batch_size
            return len(dataloader) * samples_per_batch

    def wrap_model(self, model: Union[PreTrainedModel, PoplarExecutor], training=True) -> PoplarExecutor:
        """
        Returns a PopTorch executor for `model`, building and caching it on first use.

        A `poptorch.PoplarExecutor` is returned as-is. Otherwise the model is de-parallelized, switched to the
        requested mode and re-parallelized before wrapping:

        - training: the optimizer is created, then the model is wrapped with `poptorch.trainingModel` and cached in
          `self.training_model`;
        - inference: the model is wrapped with `poptorch.inferenceModel` and cached in `self.inference_model`.

        A cached executor is reused without re-wrapping. If the resulting executor is compiled but not attached to
        the device (for example because a previous loop detached it), it is re-attached first.

        Args:
            model ([`transformers.PreTrainedModel`] or `poptorch.PoplarExecutor`):
                The model to wrap.
            training (`bool`, *optional*, defaults to `True`):
                Whether to produce the training executor (`True`) or the inference executor (`False`).

        Returns:
            `poptorch.PoplarExecutor`: The wrapped model.
        """
        if isinstance(model, PoplarExecutor):
            executor = model
        elif not training:
            if self.inference_model is None:
                model.deparallelize()
                model.ipu_config.eval()
                model.parallelize(**model.ipu_config.inference_parallelize_kwargs)
                eval_mode_model = model.eval()
                self.inference_model = poptorch.inferenceModel(eval_mode_model, options=self.eval_opts)
            executor = self.inference_model
        else:
            if self.training_model is None:
                model.deparallelize()
                model.ipu_config.train()
                model.parallelize(**model.ipu_config.parallelize_kwargs)
                self.create_optimizer()
                train_mode_model = model.train()
                self.training_model = poptorch.trainingModel(
                    train_mode_model, options=self.opts, optimizer=self.optimizer
                )
            executor = self.training_model

        needs_reattach = executor.isCompiled() and not executor.isAttachedToDevice()
        if needs_reattach:
            executor.attachToDevice()
        return executor

    def _detach_training_model(self):
        """
        Detach the training model from IPUs.
        """
        self.training_model.detachFromDevice()

    def _detach_inference_model(self):
        """
        Detach the inference model from IPUs.
        """
        self.inference_model.detachFromDevice()

    def _reattach_training_model(self):
        """
        Reattach the training model to IPUs.
        """
        self.training_model.attachToDevice()

    def train(
        self,
        resume_from_checkpoint: Optional[Union[str, bool]] = None,
        trial: Union["optuna.Trial", Dict[str, Any]] = None,
        ignore_keys_for_eval: Optional[List[str]] = None,
        **kwargs,
    ):
        """
        Main training entry point.

        Args:
            resume_from_checkpoint (`str` or `bool`, *optional*):
                Indicates that training will resume from the model, optimizer or
                scheduler states loaded here. If `str`, local path to a saved
                checkpoint as saved by a previous instance of [`IPUTrainer`]. If
                `bool` and `True`, load the last checkpoint in *args.output_dir*
                as saved by a previous instance of [`IPUTrainer`]. `False` is
                treated like `None`.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for a
                hyperparameter search. **Note**: Feature not supported; it is
                forwarded to the training loop, which raises a `ValueError` when
                it is not `None`.
            ignore_keys_for_eval (`List[str]`, *optional*)
                A list of keys in the output of your model (if it is a
                dictionary) that should be ignored when gathering predictions
                for evaluation during the training.
            kwargs:
                Additional keyword arguments used to hide deprecated arguments.
                Only `model_path` (deprecated alias of `resume_from_checkpoint`)
                is accepted.

        Raises:
            TypeError: if unexpected keyword arguments are passed.
            ValueError: if `resume_from_checkpoint=True` but no checkpoint exists in `args.output_dir`.
        """
        if resume_from_checkpoint is False:
            resume_from_checkpoint = None

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        args = self.args

        self.is_in_train = True

        if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
                "instead.",
                FutureWarning,
            )
        if kwargs:
            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(kwargs)}.")

        # Load potential model checkpoint
        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
            if resume_from_checkpoint is None:
                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

        if resume_from_checkpoint is not None:
            self._load_from_checkpoint(resume_from_checkpoint)

        # `trial` is forwarded so that the training loop's "not supported" check actually fires instead of the
        # argument being silently ignored.
        return self._inner_training_loop(
            args=args,
            resume_from_checkpoint=resume_from_checkpoint,
            trial=trial,
            ignore_keys_for_eval=ignore_keys_for_eval,
        )

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        """
        Runs the training loop.

        Builds the training dataloader and derives the number of epochs and optimization steps. It then wraps and
        compiles the model, optionally restores trainer state from `resume_from_checkpoint`, and iterates over the
        data. Each iteration calls `training_step`, steps the learning rate scheduler, pushes the optimizer to the
        PopTorch executor, and lets the callback handler drive logging, saving and evaluation. The training model
        is detached from the device at the end.

        PopTorch already accounts for gradient accumulation, so one iteration over the dataloader is counted as
        one optimization step (`global_step` is incremented once per iteration).

        Args:
            batch_size (`int`, *optional*):
                Per-device batch size; defaults to `args.per_device_train_batch_size`.
            args:
                The training arguments.
            resume_from_checkpoint (`str`, *optional*):
                Checkpoint directory to restore trainer state, optimizer/scheduler and RNG state from.
            trial:
                Must be `None`; hyperparameter search is not supported.
            ignore_keys_for_eval (`List[str]`, *optional*):
                Output keys to ignore when evaluating during training.

        Returns:
            `TrainOutput` with the global step, the average training loss and the training metrics.

        Raises:
            ValueError: if `trial` is not `None`, if the dataloader has no length and `args.max_steps` is not
                positive, or if no first batch can be drawn from the dataloader.
        """
        # For now, it will always be None.
        if batch_size is None:
            batch_size = args.per_device_train_batch_size

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = batch_size * self.ipu_config.batch_size_factor()

        if has_length(train_dataloader):
            # No need to divide by the number of gradient accumulation steps as poptorch already accounts for that.
            num_update_steps_per_epoch = max(len(train_dataloader), 1)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
            # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
            # the best we can do.
            num_train_samples = max_steps * total_train_batch_size
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_train_samples = args.max_steps * total_train_batch_size
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        self.state = IPUTrainerState()
        if trial is not None:
            raise ValueError("Hyperparameter tuning is not supported by the IPUTrainer.")
        # Past the check above `trial` is always None, so this is never a hyperparameter search.
        self.state.is_hyper_param_search = False

        self.training_model = self.wrap_model(self.model)

        self.create_scheduler(num_training_steps=max_steps)

        # TODO: handle optimizer and scheduler creation
        # if delay_optimizer_creation:
        #     self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # Compile ahead of the loop using the first batch as example inputs.
        try:
            model_inputs = next(iter(train_dataloader))
        except StopIteration:
            raise ValueError(
                "Couldn't get first sample from dataloader, please check for warnings "
                "during dataloader construction."
            )
        self.compile_model(self.training_model, model_inputs, log=True)

        # Train!
        num_examples = (
            self.num_examples(train_dataloader)
            if has_length(train_dataloader)
            else total_train_batch_size * args.max_steps
        )
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num epochs = {num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {batch_size}")
        logger.info(
            f"  Total training batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}"
        )
        logger.info(f"  Gradient accumulation steps = {self.ipu_config.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = IPUTrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            if self.state.start_time < 0:
                self.state.start_time = start_time
            start_time = self.state.start_time
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                # No need to multiply by the number of gradient accumulation steps as poptorch already accounts for that.
                # steps_trained_in_current_epoch *= self.ipu_config.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
                    "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
                    "flag to your launch command, but you will resume the training on data already seen by your model."
                )
                # Only show the skipping progress bar when tqdm is *enabled*.
                if not args.disable_tqdm:
                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
                    steps_trained_progress_bar.set_description("Skipping the first batches")

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        self.state.trial_name = None
        self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.start_time = start_time

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step

        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
        if not args.ignore_data_skip:
            for _ in range(epochs_trained):
                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
                    train_dataloader.sampler, RandomSampler
                )
                if is_torch_less_than_1_11 or not is_random_sampler:
                    # We just need to begin an iteration to create the randomization of the sampler.
                    # That was before PyTorch 1.11 however...
                    for _ in train_dataloader:
                        break
                else:
                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
                    # AT THE VERY END!
                    _ = list(train_dataloader.sampler)

        for epoch in range(epochs_trained, num_train_epochs):
            if isinstance(train_dataloader, poptorch.DataLoader) and isinstance(
                train_dataloader.sampler, DistributedSampler
            ):
                train_dataloader.sampler.set_epoch(epoch)
            elif isinstance(train_dataloader.dataset, IterableDatasetShard):
                train_dataloader.dataset.set_epoch(epoch)

            epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            # One iteration is one optimization step (poptorch handles gradient accumulation), so the fallback for
            # a dataloader without length must not be scaled by the gradient accumulation steps.
            steps_in_epoch = len(epoch_iterator) if has_length(train_dataloader) else args.max_steps

            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            step = -1
            for step, inputs in enumerate(epoch_iterator):
                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                tr_loss_step = self.training_step(self.training_model, inputs)

                if args.logging_nan_inf_filter and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    tr_loss += tr_loss_step

                # TODO: see how to enable this (if necessary), slows down training a lot.
                self.current_flos += float(self.floating_point_ops(inputs))

                # Every iteration counts as an optimizer step: advance the LR schedule and hand the updated
                # optimizer to the PopTorch executor.
                self.lr_scheduler.step()
                self.training_model.setOptimizer(self.optimizer)

                self.state.global_step += 1
                self.state.epoch = epoch + (step + 1) / steps_in_epoch
                self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                self._maybe_log_save_evaluate(tr_loss, self.training_model, epoch, ignore_keys_for_eval)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    break

            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, self.training_model, epoch, ignore_keys_for_eval)

            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        # Guard against runs where no step was executed (e.g. an exhausted iterable dataset), which would otherwise
        # raise ZeroDivisionError.
        train_loss = self._total_loss_scalar / max(self.state.global_step, 1)

        metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        # Detaching model from device to let the inference model attach itself
        self._detach_training_model()

        return TrainOutput(self.state.global_step, train_loss, metrics)

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        """
        Loads model weights from a checkpoint directory.

        Supported layouts:

        - a single weights file (`WEIGHTS_NAME`), loaded on CPU with `strict=False`;
        - PEFT adapter weights, when `model` is a `PeftModel`;
        - a sharded checkpoint (`WEIGHTS_INDEX_NAME`), loaded with `load_sharded_checkpoint(strict=False)`.

        Args:
            resume_from_checkpoint (`str`): Path to the checkpoint directory.
            model (*optional*): The model to load into; defaults to `self.model`.

        Raises:
            ValueError: if none of the weights file, sharded index or adapter weights exists in the directory.
        """
        if model is None:
            model = self.model

        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)

        if not any(os.path.isfile(f) for f in (weights_file, weights_index_file, adapter_weights_file)):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")

        if os.path.isfile(config_file):
            config = PretrainedConfig.from_json_file(config_file)
            checkpoint_version = config.transformers_version
            if checkpoint_version is not None and checkpoint_version != __version__:
                logger.warning(
                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
                    f"Transformers but your current version is {__version__}. This is not recommended and could "
                    "yield to errors or unwanted behavior."
                )

        if os.path.isfile(weights_file):
            # We load the model state dict on the CPU to avoid an OOM error.
            state_dict = torch.load(weights_file, map_location="cpu")
            # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
            # which takes *args instead of **kwargs
            load_result = model.load_state_dict(state_dict, False)
            # release memory
            del state_dict
            self._issue_warnings_after_load(load_result)

        # Load adapters following PR # 24096 (> 4.29.2)
        elif isinstance(model, PeftModel):
            # If training a model using PEFT & LoRA, assume that adapter has been saved properly.
            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                if os.path.exists(resume_from_checkpoint):
                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly, "
                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
                    )
            else:
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")

        elif os.path.isfile(weights_index_file):
            # A sharded checkpoint passes the validity check above, so it must actually be loaded rather than
            # silently skipped.
            from transformers.modeling_utils import load_sharded_checkpoint

            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=False)
            self._issue_warnings_after_load(load_result)

        else:
            # Only adapter weights are present, but the model is not a PEFT model: nothing can be loaded.
            logger.warning(
                f"Found {ADAPTER_WEIGHTS_NAME} in {resume_from_checkpoint} but the model is not a `PeftModel`; "
                "no weights were loaded."
            )

    def _load_best_model(self):
        """
        Load the weights of the checkpoint recorded in `self.state.best_model_checkpoint` into `self.model`.

        For a `PeftModel`, the adapter saved in that checkpoint is loaded. Otherwise the full state dict is loaded on
        the CPU and applied through `_load_state_dict_in_model`. In every case where nothing can be loaded, a warning
        is logged and the method returns without doing anything silently or failing on a missing file.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        model = self.model
        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
        not_found_message = (
            f"Could not locate the best model at {best_model_path}. If you are running a distributed training "
            "on multiple nodes, you should activate `--save_on_each_node`."
        )

        if not (os.path.exists(best_model_path) or os.path.exists(best_adapter_model_path)):
            logger.warning(not_found_message)
            return

        if isinstance(model, PeftModel):
            # If training a model using PEFT & LoRA, assume that adapter has been saved properly.
            if not (hasattr(model, "active_adapter") and hasattr(model, "load_adapter")):
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
            elif os.path.exists(best_adapter_model_path):
                model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
            else:
                logger.warning(
                    "The intermediate checkpoints of PEFT may not be saved correctly, "
                    f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
                    "Check some examples here: https://github.com/huggingface/peft/issues/96"
                )
        elif os.path.exists(best_model_path):
            # We load the model state dict on the CPU to avoid an OOM error.
            state_dict = torch.load(best_model_path, map_location="cpu")
            self._load_state_dict_in_model(state_dict)
        else:
            # Only an adapter file exists but the model is not a `PeftModel`: there is no full state dict to load.
            logger.warning(not_found_message)

    def _load_state_dict_in_model(self, state_dict):
        self.model.deparallelize()
        load_result = self.model.load_state_dict(state_dict, strict=False)
        self.model.parallelize(**self.model.ipu_config._parallelize_kwargs)
        if not self.args.fp32:
            self.model.half()

        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) != set(
                self.model._keys_to_ignore_on_save
            ):
                logger.warn(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warn(f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.")

    def _issue_warnings_after_load(self, load_result):
        if len(load_result.missing_keys) != 0:
            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                self.model._keys_to_ignore_on_save
            ):
                self.model.tie_weights()
            else:
                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
        if len(load_result.unexpected_keys) != 0:
            logger.warning(
                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
            )

    def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval):
        if self.control.should_log:
            logs: Dict[str, float] = {}

            tr_loss_scalar = tr_loss.mean().item()

            # reset tr_loss to zero
            tr_loss -= tr_loss

            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
            logs["learning_rate"] = self._get_learning_rate()

            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            self._detach_training_model()
            if isinstance(self.eval_dataset, dict):
                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
                    metrics = self.evaluate(
                        eval_dataset=eval_dataset,
                        ignore_keys=ignore_keys_for_eval,
                        metric_key_prefix=f"eval_{eval_dataset_name}",
                    )
            else:
                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
            self._reattach_training_model()

        if self.control.should_save:
            self._save_checkpoint(model, metrics=metrics)
            self.control = self.callback_handler.on_save(self.args, self.state, self.control)

    def _load_rng_state(self, checkpoint):
        # Load RNG states from `checkpoint`
        if checkpoint is None:
            return

        # TODO: validate that.
        local_rank = -1
        if local_rank != -1:
            rng_file = os.path.join(checkpoint, f"rng_state_{local_rank}.pth")
            if not os.path.isfile(os.path.join(checkpoint, rng_file)):
                logger.info(
                    f"Didn't find an RNG file for process {local_rank}. If you are resuming a training that "
                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
                )
                return
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            if not os.path.isfile(rng_file):
                logger.info(
                    "Didn't find an RNG file. If you are resuming a training that was launched in a distributed "
                    "fashion, reproducibility is not guaranteed."
                )
                return

        checkpoint_rng_state = torch.load(rng_file)
        random.setstate(checkpoint_rng_state["python"])
        np.random.set_state(checkpoint_rng_state["numpy"])
        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
        # TODO: enable this when SDK 2.5 is out.
        # self.training_model.rng_state = checkpoint_rng_state["ipu"]

    def _save_checkpoint(self, model, metrics=None):
        """
        Write a training checkpoint to `<args.output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`.

        The checkpoint holds the model (via `save_model`), the optimizer and scheduler states, the trainer state and
        the RNG states. When both `metrics` and `args.metric_for_best_model` are given, the best metric / best
        checkpoint tracked in `self.state` are updated before the trainer state is written. Older checkpoints may
        then be rotated out.

        Args:
            model:
                Unused here.
            metrics (`Dict[str, float]`, *optional*):
                The latest evaluation metrics.
        """
        args = self.args
        run_dir = args.output_dir
        self.store_flos()

        output_dir = os.path.join(run_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
        self.save_model(output_dir, _internal_call=True)

        if args.should_save:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
            with warnings.catch_warnings(record=True) as scheduler_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(scheduler_warnings)

        # Track the best metric / best checkpoint seen so far.
        best_metric_name = args.metric_for_best_model
        if metrics is not None and best_metric_name is not None:
            if not best_metric_name.startswith("eval_"):
                best_metric_name = "eval_" + best_metric_name
            current_value = metrics[best_metric_name]

            is_better = np.greater if args.greater_is_better else np.less
            no_best_yet = self.state.best_metric is None or self.state.best_model_checkpoint is None
            if no_best_yet or is_better(current_value, self.state.best_metric):
                self.state.best_metric = current_value
                self.state.best_model_checkpoint = output_dir

        if args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # RNG states are written without a rank suffix (non-distributed training).
        rng_states = {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "cpu": torch.random.get_rng_state(),
            # TODO: enable this when SDK 2.5 is out.
            # "ipu": self.training_model.rng_state,
        }
        # `save_model` only writes (and creates the directory) when `args.should_save`, so make sure it exists.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

        if args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """
        Restore the optimizer and LR scheduler states from `checkpoint` when both state files are present.

        After loading, the optimizer is handed to `self.training_model` through `setOptimizer`.

        Args:
            checkpoint (`str`, *optional*):
                Path to the checkpoint directory. Nothing happens when it is `None` or a state file is missing.
        """
        if checkpoint is None:
            return

        optimizer_path = os.path.join(checkpoint, OPTIMIZER_NAME)
        scheduler_path = os.path.join(checkpoint, SCHEDULER_NAME)
        if not (os.path.isfile(optimizer_path) and os.path.isfile(scheduler_path)):
            return

        self.optimizer.load_state_dict(torch.load(optimizer_path))
        with warnings.catch_warnings(record=True) as scheduler_warnings:
            self.lr_scheduler.load_state_dict(torch.load(scheduler_path))
        reissue_pt_warnings(scheduler_warnings)

        # Hand the restored optimizer to the training model.
        self.training_model.setOptimizer(self.optimizer)

    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the trainer state and forward them to the `on_log` callbacks.

        When an epoch is known, `logs` is updated in place with the epoch rounded to two decimals. The history entry
        additionally carries the current global step.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        epoch = self.state.epoch
        if epoch is not None:
            logs["epoch"] = round(epoch, 2)

        history_entry = dict(logs)
        history_entry["step"] = self.state.global_step
        self.state.log_history.append(history_entry)

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
        """
        Prepares a single data sample before feeding it to the model.

        The data sample can be it a tensor or a nested list or dictionary of tensors.
        """
        if isinstance(data, dict):
            return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
        elif isinstance(data, (tuple, list)):
            return type(data)(self._prepare_input(v) for v in data)
        elif isinstance(data, torch.Tensor):
            return data
        return data

    def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepares inputs before feeding them to the model.

        This method converts the inputs to tensors if they are not already tensors and handles the potential state.
        """
        inputs = self._prepare_input(inputs)
        if len(inputs) == 0:
            raise ValueError(
                "The batch received was empty. Your model won't be able to train on it. Double-check that your "
                f"training dataset contains the keys expected by the model: {','.join(self._signature_columns)}."
            )
        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(
        self, model: poptorch.PoplarExecutor, inputs: Dict[str, Union[torch.Tensor, Any]]
    ) -> torch.Tensor:
        """
        Run one training step on a batch and return its mean loss.

        Subclass and override to inject custom behavior.

        Args:
            model (`poptorch.PoplarExecutor`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The training loss on this batch, averaged with `.mean()`.
        """
        prepared_inputs = self._prepare_inputs(inputs)
        return self.compute_loss(model, prepared_inputs).mean()

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the loss of `model` on a batch of training inputs.

        When a label smoother is configured and `inputs` has a `labels` entry, the labels are popped from `inputs`
        (mutating it) and the loss is computed by the label smoother. Otherwise the loss is read from the model
        outputs: the `"loss"` key for dict outputs, the first element otherwise.

        Args:
            model:
                The model to train.
            inputs:
                The inputs and targets of the model.
            return_outputs (defaults to `False`):
                If `True`, returns the outputs with the loss. If `False`, only returns the loss.

        Raises:
            `ValueError`: if the model returns a dict without a `"loss"` key and no label smoothing applies.

        Subclass and override for custom behavior.
        """
        use_label_smoothing = self.label_smoother is not None and "labels" in inputs
        labels = inputs.pop("labels") if use_label_smoothing else None

        outputs = model(**inputs)

        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        elif isinstance(outputs, dict):
            if "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            loss = outputs["loss"]
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs[0]

        if return_outputs:
            return loss, outputs
        return loss

    def is_world_process_zero(self) -> bool:
        """
        Whether this process is the main process; always `True` for this trainer.

        Only provided because `log_metrics` relies on it.
        """
        is_main_process = True
        return is_main_process

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Save the model so it can be reloaded with `from_pretrained()`.

        Files are written only when `self.args.should_save` is set. When called by the user (not internally by the
        trainer) and `self.args.push_to_hub` is set, the model is also pushed to the Hub.

        Args:
            output_dir (`str`, *optional*):
                Target directory. Defaults to `self.args.output_dir`.
            _internal_call (`bool`, defaults to `False`):
                Set by the trainer itself to skip the Hub push.
        """
        target_dir = self.args.output_dir if output_dir is None else output_dir

        if self.args.should_save:
            self._save(target_dir)

        user_requested_push = self.args.push_to_hub and not _internal_call
        if user_requested_push:
            self.push_to_hub(commit_message="Model save")

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """
        Write the model, tokenizer (if any), IPU config and training arguments to `output_dir`.

        If the training model is attached to a device, its weights are first copied back to the host.
        `PreTrainedModel`/`PeftModel` instances are saved with `save_pretrained` after being temporarily
        deparallelized. Other models only have their state dict saved.

        Args:
            output_dir (`str`, *optional*):
                Target directory. Defaults to `self.args.output_dir`.
            state_dict (*optional*):
                State dict to save instead of the model's own.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Updating self.model weights with the weights stored on device.
        # TODO: can this be deleted? It would make things faster.
        if self.training_model is not None and self.training_model.isAttachedToDevice():
            self.training_model.copyWeightsToHost()

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, (PreTrainedModel, PeftModel)):
            logger.info(
                "Trainer.model is not a `transformers.PreTrainedModel` or `peft.PeftModel`, only saving its state dict."
            )
            if state_dict is None:
                state_dict = self.model.state_dict()
            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        else:
            # The CPU torch RNG state is captured before and restored after this sequence. The `finally` guarantees
            # the model is parallelized again (and the RNG state restored) even if `save_pretrained` raises.
            rng_state = torch.random.get_rng_state()
            self.model.deparallelize()
            try:
                self.model.save_pretrained(output_dir, state_dict=state_dict)
            finally:
                self.model.parallelize(**self.model.ipu_config.parallelize_kwargs)
                torch.random.set_rng_state(rng_state)

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        self.ipu_config.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

    def store_flos(self):
        """
        Add the floating-point operations counted since the last call to `self.state.total_flos`, then reset the counter.

        The count is scaled by the IPU config's `batch_size_factor()`.
        """
        # TODO: Validate that this is right. (It's most likely wrong)
        flos_since_last_store = self.current_flos * self.ipu_config.batch_size_factor()
        self.state.total_flos += flos_since_last_store
        self.current_flos = 0

    def _sorted_checkpoints(
        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
    ) -> List[str]:
        """
        List the `<checkpoint_prefix>-<step>` directories in `output_dir` in ascending order.

        The order is by modification time when `use_mtime` is `True`, otherwise by the step number parsed from the
        name. If the best model checkpoint is among them and is not already one of the last two entries, it is moved
        to the second-to-last position. This keeps it away from rotation while the most recent checkpoint stays last.

        Args:
            output_dir (`str`):
                Directory containing the checkpoints.
            checkpoint_prefix (`str`):
                Prefix of the checkpoint directory names.
            use_mtime (`bool`, defaults to `False`):
                Sort by modification time instead of step number.

        Returns:
            `List[str]`: The checkpoint paths.
        """
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = [path for _, path in sorted(ordering_and_checkpoint_path)]

        # Make sure we don't delete the best model. It may be absent from this listing (e.g. removed externally or
        # saved under another directory), in which case `list.index` would raise, so only reorder when present.
        if self.state.best_model_checkpoint is not None:
            best_model_path = str(Path(self.state.best_model_checkpoint))
            if best_model_path in checkpoints_sorted:
                best_model_index = checkpoints_sorted.index(best_model_path)
                for i in range(best_model_index, len(checkpoints_sorted) - 2):
                    checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
        # we don't do to allow resuming.
        save_total_limit = self.args.save_total_limit
        if (
            self.state.best_model_checkpoint is not None
            and self.args.save_total_limit == 1
            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
        ):
            save_total_limit = 2

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
            shutil.rmtree(checkpoint)

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Runs an evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute the metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`] dataset, the columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # No point gathering the predictions if there are no metrics, otherwise we defer to
        # self.args.prediction_loss_only
        prediction_loss_only = True if self.compute_metrics is None else self.args.prediction_loss_only

        # Compile here, even though `evaluation_loop` wraps and compiles again, so that compilation time does not
        # pollute the inference speed metrics measured below.
        _ = self._wrap_and_compile_model_for_evaluation(eval_dataloader, prediction_loss_only)

        start_time = time.time()

        output = self.evaluation_loop(
            eval_dataloader,
            description="Evaluation",
            prediction_loss_only=prediction_loss_only,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        # If we are using a padded data collator, drop the padded part of the output.
        if self.args.pad_on_batch_axis:
            eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
            dataset_len = len(eval_dataset)
            # `predictions` is `None` when only the loss was computed, and may be a single array rather than a tuple of
            # arrays. Iterating over it directly would then fail or truncate the wrong axis, so use `nested_truncate`
            # (as `evaluation_loop` does) to cut every array along its first axis.
            if output.predictions is not None:
                output = output._replace(predictions=nested_truncate(output.predictions, dataset_len))
            output = output._replace(num_samples=dataset_len)

        total_batch_size = self.args.per_device_eval_batch_size * self.ipu_config.batch_size_factor(for_inference=True)
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
        self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """
        Returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like in `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is an `datasets.Dataset` dataset, the columns not accepted by the
                `model.forward()` method are automatically removed. Has to implement the method `__len__`
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "test_bleu" if the prefix is "test" (default)

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
        one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The dictionary of potential metrics (if the dataset contained
              labels).
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)

        # As in `evaluate`: compile before starting the timer (`evaluation_loop` wraps and compiles again) so that
        # compilation time is not counted in the speed metrics.
        _ = self._wrap_and_compile_model_for_evaluation(test_dataloader, self.args.prediction_loss_only)

        start_time = time.time()

        output = self.evaluation_loop(
            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
        )

        # If we are using a padded data collator, drop the padded part of the output.
        if self.args.pad_on_batch_axis:
            dataset_len = len(test_dataset)
            # Either field may be `None` and may be a single array or a nested tuple of arrays; `nested_truncate`
            # (as used by `evaluation_loop`) cuts every array along its first axis. Labels are truncated too so that
            # they stay aligned with the predictions.
            if output.predictions is not None:
                output = output._replace(predictions=nested_truncate(output.predictions, dataset_len))
            if output.label_ids is not None:
                output = output._replace(label_ids=nested_truncate(output.label_ids, dataset_len))
            output = output._replace(num_samples=dataset_len)

        total_batch_size = self.args.eval_batch_size * self.ipu_config.batch_size_factor(for_inference=True)
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)

    def _wrap_and_compile_model_for_evaluation(self, dataloader, prediction_loss_only):
        model = self.wrap_model(self.model, training=False)
        try:
            model_inputs = next(iter(dataloader))
        except StopIteration:
            raise ValueError(
                "Couldn't get first sample from dataloader, please check for warnings "
                "during dataloader construction."
            )
        self.compile_model(model, model_inputs, log=True)
        return model

    def evaluation_loop(
        self,
        dataloader: poptorch.DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by [`IPUTrainer.evaluate`] and [`IPUTrainer.predict`].

        Works both with or without labels.

        Args:
            dataloader (`poptorch.DataLoader`):
                The dataset to be used.
            description (`str`):
                The description of what is being run.
            prediction_loss_only (`bool`):
                If `True`, only returns the loss. If `False`, returns loss,
                logits and labels (if present).
            ignore_keys (`Lst[str]`, *optional*):
                A list of keys in the output of your model (if it is a
                dictionary) that should be ignored when gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For
                example the metric "bleu" will be named "eval_bleu" if the
                prefix is "eval" (default).
        """
        prediction_loss_only = (
            prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
        )

        self.inference_model = self._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only)

        batch_size = dataloader.batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        self.callback_handler.eval_dataloader = dataloader
        # eval_dataset = getattr(dataloader, "dataset", None)

        if self.args.past_index >= 0:
            self._past = None

        # Initialize containers
        # losses/preds/labels on IPU (accumulated for eval_accumulation_steps, legacy code for IPUs)
        losses_host = None
        preds_host = None
        labels_host = None
        # losses/preds/labels on CPU (final containers)
        all_losses = None
        all_preds = None
        all_labels = None
        # Will be useful when we have an iterable dataset so don't know its length.

        observed_num_examples = 0
        # Main evaluation loop
        for step, inputs in enumerate(dataloader):
            # Update the observed num examples
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            # Prediction step
            # If dataset is not sized, is_last_batch is False because we cannot know.
            is_last_batch = (
                step == len(dataloader) - 1 if isinstance(dataloader.dataset, collections.abc.Sized) else False
            )
            loss, logits, labels = self.prediction_step(
                self.inference_model,
                inputs,
                prediction_loss_only,
                ignore_keys=ignore_keys,
                is_last_batch=is_last_batch,
            )

            # Update containers on host
            if loss is not None:
                loss = loss.mean(dim=0, keepdim=True)
                # If only one IPU is used, loss is a zero dimensional tensor, we unsqueeze to be able to concatenate.
                if loss.dim() == 0:
                    loss = loss.unsqueeze(0)
                losses_host = loss if losses_host is None else torch.cat((losses_host, loss), dim=0)
            if logits is not None:
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            if labels is not None:
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)

            self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if losses_host is not None:
            losses = nested_numpify(losses_host)
            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
        if preds_host is not None:
            logits = nested_numpify(preds_host)
            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
        if labels_host is not None:
            labels = nested_numpify(labels_host)
            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

        # Number of samples
        # In the original Trainer, TODO: should we use this instead?
        # if has_length(eval_dataset):
        #     num_samples = len(eval_dataset)
        # # The instance check is weird and does not actually check for the type, but whether the dataset has the right
        # # methods. Therefore we need to make sure it also has the attribute.
        # elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
        #     num_samples = eval_dataset.num_examples
        # else:
        #     if has_length(dataloader):
        #         num_samples = self.num_examples(dataloader)
        #     else:  # both len(dataloader.dataset) and len(dataloader) fail
        #         num_samples = observed_num_examples
        # if num_samples == 0 and observed_num_examples > 0:
        #     num_samples = observed_num_examples
        num_samples = observed_num_examples

        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
        # samplers has been rounded to a multiple of batch_size, so we truncate.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        # Detaching model from device to let the training model attach itself
        self._detach_inference_model()

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

    def prediction_step(
        self,
        model: poptorch.PoplarExecutor,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        is_last_batch: bool = False,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Runs a single evaluation step on `model`.

        Subclass and override to inject custom behavior.

        Args:
            model (`poptorch.PoplarExecutor`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. The dictionary is unpacked as keyword arguments when the model
                is called. Most models expect the targets under the argument `labels`; check your model's
                documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                If `True`, only the loss is returned. If `False`, the loss, logits and labels (when present) are
                returned.
            ignore_keys (`List[str]`, *optional*):
                Keys of the model output (when it is a dictionary) to leave out when gathering predictions. When not
                given, `self.model.config.keys_to_ignore_at_inference` is used if available, otherwise `[]`.
            is_last_batch (`bool`, *optional*, defaults to `False`):
                Whether `inputs` is the last batch of the evaluation data. When `True`, NaN entries are removed from
                the loss before it is returned.

        Return:
            `Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]`: the loss, logits and
            labels, each of which may be `None`.
        """
        has_labels = all(inputs.get(label_key) is not None for label_key in self.label_names)
        inputs = self._prepare_inputs(inputs)

        if ignore_keys is None:
            ignore_keys = (
                getattr(self.model.config, "keys_to_ignore_at_inference", [])
                if hasattr(self.model, "config")
                else []
            )

        # Grab the labels before the forward pass: computing the loss may pop them from `inputs`
        # (label smoothing for instance).
        labels = None
        if has_labels:
            labels = nested_detach(tuple(inputs.get(label_key) for label_key in self.label_names))
            if len(labels) == 1:
                labels = labels[0]

        with torch.no_grad():
            if not has_labels:
                loss = None
                outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(value for key, value in outputs.items() if key not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]
            else:
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                # An incomplete last batch can leave some replicas with nothing to compute, which shows up as NaN
                # losses; drop them so they do not corrupt the aggregated evaluation loss.
                if is_last_batch:
                    loss = loss[~loss.isnan()]
                loss = loss.detach()
                if isinstance(outputs, dict):
                    logits = tuple(value for key, value in outputs.items() if key not in ignore_keys + ["loss"])
                else:
                    # The first element of a tuple output is the loss.
                    logits = outputs[1:]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]
        return (loss, logits, labels)

    def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Returns the number of floating-point operations for every forward and backward pass on `inputs`.

        Delegates to the `floating_point_ops` method of the original (non-IPU) model when it has one, as
        [`transformers.PreTrainedModel`] subclasses do. For any other model, either implement a
        `floating_point_ops` method on the model or subclass and override this method; otherwise `0` is reported.

        Args:
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

        Returns:
            `int`: The number of floating-point operations.
        """
        # Query `self.original_model` rather than `self.model`: the latter is the underlying model used by the IPUs,
        # and calling `floating_point_ops` on it slows things down a lot.
        reference_model = self.original_model
        if not hasattr(reference_model, "floating_point_ops"):
            return 0
        return reference_model.floating_point_ops(inputs)

    def init_git_repo(self, at_init: bool = False):
        """
        Initializes a local Git repo in `self.args.output_dir`, cloned from the Hub repo `self.args.hub_model_id`.

        When `self.args.hub_model_id` is not set, the repo name is the name of `self.args.output_dir`. The repo is
        pulled, and unless `self.args.hub_strategy` is `HubStrategy.ALL_CHECKPOINTS`, a `.gitignore` excluding the
        `checkpoint-*/` folders is written if none exists yet. Does nothing on processes other than the main one.

        Args:
            at_init (`bool`, *optional*, defaults to `False`):
                If `True`, this function is called before any training. If
                `self.args.overwrite_output_dir` is `True` and `at_init` is
                `True`, the path to the repo (which is `self.args.output_dir`)
                might be wiped out.
        """
        if not self.is_world_process_zero():
            return
        use_auth_token = True if self.args.hub_token is None else self.args.hub_token
        if self.args.hub_model_id is None:
            repo_name = Path(self.args.output_dir).absolute().name
        else:
            repo_name = self.args.hub_model_id
        if "/" not in repo_name:
            repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)

        def _clone_repo():
            # Both attempts go through this helper so that they pass identical arguments; in particular the retry
            # after wiping `output_dir` must request the same `private` visibility as the first attempt.
            return Repository(
                self.args.output_dir,
                clone_from=repo_name,
                use_auth_token=use_auth_token,
                private=self.args.hub_private_repo,
            )

        try:
            self.repo = _clone_repo()
        except EnvironmentError:
            if not (self.args.overwrite_output_dir and at_init):
                raise
            # Try again after wiping output_dir
            shutil.rmtree(self.args.output_dir)
            self.repo = _clone_repo()

        self.repo.git_pull()

        # By default, ignore the checkpoint folders
        gitignore_path = os.path.join(self.args.output_dir, ".gitignore")
        if not os.path.exists(gitignore_path) and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS:
            with open(gitignore_path, "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])

        # No asynchronous push has been started yet.
        self.push_in_progress = None

    def create_model_card(
        self,
        language: Optional[str] = None,
        license: Optional[str] = None,
        tags: Union[str, List[str], None] = None,
        model_name: Optional[str] = None,
        finetuned_from: Optional[str] = None,
        tasks: Union[str, List[str], None] = None,
        dataset_tags: Union[str, List[str], None] = None,
        dataset: Union[str, List[str], None] = None,
        dataset_args: Union[str, List[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to
        [`IPUTrainer`] and writes it to `README.md` in `self.args.output_dir`.

        Only the main process writes the model card; other processes return immediately.

        Args:
            language (`str`, *optional*):
                The language of the model (if applicable)
            license (`str`, *optional*):
                The license of the model. Will default to the license of the pretrained model used, if the original
                model given to [`IPUTrainer`] comes from a repo on the Hub.
            tags (`str` or `List[str]`, *optional*):
                Some tags to be included in the metadata of the model card.
            model_name (`str`, *optional*):
                The name of the model.
            finetuned_from (`str`, *optional*):
                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
                of the original model given to [`IPUTrainer`] (if it comes from the Hub).
            tasks (`str` or `List[str]`, *optional*):
                One or several task identifiers, to be included in the metadata of the model card.
            dataset_tags (`str` or `List[str]`, *optional*):
                One or several dataset tags, to be included in the metadata of the model card.
            dataset (`str` or `List[str]`, *optional*):
                One or several dataset identifiers, to be included in the metadata of the model card.
            dataset_args (`str` or `List[str]`, *optional*):
                One or several dataset arguments, to be included in the metadata of the model card.
        """
        if not self.is_world_process_zero():
            return

        training_summary = IPUTrainingSummary.from_trainer(
            self,
            language=language,
            license=license,
            tags=tags,
            model_name=model_name,
            finetuned_from=finetuned_from,
            tasks=tasks,
            dataset_tags=dataset_tags,
            dataset=dataset,
            dataset_args=dataset_args,
        )
        model_card = training_summary.to_model_card()
        # Write as UTF-8 explicitly so that non-ASCII content (model/dataset names, languages, ...) does not depend on
        # the platform's default encoding.
        with open(os.path.join(self.args.output_dir, "README.md"), "w", encoding="utf-8") as f:
            f.write(model_card)

    def _push_from_checkpoint(self, checkpoint_folder):
        """
        Pushes the checkpoint just saved in `checkpoint_folder` to the Hub without blocking.

        Nothing is pushed on processes other than the main one, when `self.args.hub_strategy` is `HubStrategy.END`,
        or while the previous asynchronous push is still running. The modeling files are copied from the checkpoint
        into `self.args.output_dir`, and the tokenizer and training arguments are re-saved there. With
        `HubStrategy.CHECKPOINT`, the checkpoint folder is also moved to `output_dir/last-checkpoint` for the
        duration of the `push_to_hub` call and moved back afterwards.

        Args:
            checkpoint_folder (`str`):
                Path of the checkpoint folder to push.
        """
        # Only push from one node: `init_git_repo` returns early on the other processes, so they have neither
        # `self.repo` nor `self.push_in_progress`.
        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
            return
        # If we haven't finished the last push, we don't do this one.
        if self.push_in_progress is not None and not self.push_in_progress.is_done:
            return

        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        for modeling_file in (CONFIG_NAME, WEIGHTS_NAME, IPU_CONFIG_NAME):
            checkpoint_file = os.path.join(checkpoint_folder, modeling_file)
            if os.path.isfile(checkpoint_file):
                shutil.copy(checkpoint_file, os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        # Set only once the checkpoint has actually been moved, so that the `finally` clause never tries to move back
        # a folder that is not there (which would raise and hide the original exception).
        moved_checkpoint = None
        try:
            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)
                moved_checkpoint = tmp_checkpoint

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f"Training in progress, step {self.state.global_step}"
            else:
                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
            _, self.push_in_progress = self.repo.push_to_hub(
                commit_message=commit_message, blocking=False, auto_lfs_prune=True
            )
        finally:
            if moved_checkpoint is not None:
                # Move back the checkpoint to its place
                shutil.move(moved_checkpoint, checkpoint_folder)

    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
        """
        Uploads *self.model* and *self.tokenizer* to the 🤗 Models Hub on the repo *self.args.hub_model_id*.

        Parameters:
            commit_message (`str`, *optional*, defaults to `"End of training"`):
                Message for the commit.
            blocking (`bool`, *optional*, defaults to `True`):
                If `True` (default), the function only returns when the `git push` command has completed. If `False`, returns immediately.
            kwargs:
                Additional keyword arguments passed along to [`~Trainer.create_model_card`]. A `model_name` given here
                is used for the model card instead of the name derived from `self.args.hub_model_id` (or
                `self.args.output_dir` when no Hub model id is set).

        Returns:
            If `blocking=True`, the URL of the commit of your model in the given repository. If `blocking=False`, a
            tuple with the URL of the commit and an object to track the progress of the push. On processes other
            than the main one, the model is saved but nothing is pushed and `None` is returned.
        """
        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
        # it might fail.
        if not hasattr(self, "repo"):
            self.init_git_repo()

        # Pop `model_name` so it is not passed twice to `create_model_card` (once explicitly, once via `kwargs`).
        model_name = kwargs.pop("model_name", None)
        if model_name is None and self.args.should_save:
            if self.args.hub_model_id is None:
                model_name = Path(self.args.output_dir).name
            else:
                model_name = self.args.hub_model_id.split("/")[-1]

        # Executed on all processes; only the process(es) determined by `self.args.should_save` actually save.
        self.save_model(_internal_call=True)

        # Only push from one node: `init_git_repo` returns early on the other processes, so they have neither
        # `self.repo` nor `self.push_in_progress`.
        if not self.is_world_process_zero():
            return None

        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
            self.push_in_progress._process.kill()
            self.push_in_progress = None

        git_head_commit_url = self.repo.push_to_hub(
            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
        )
        # Push the model card separately so that it is independent from the rest of the model.
        if self.args.should_save:
            self.create_model_card(model_name=model_name, **kwargs)
            try:
                self.repo.push_to_hub(
                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
                )
            except EnvironmentError as exc:
                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")

        return git_head_commit_url
