# optimum/graphcore/trainer_seq2seq.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from poptorch._impl import rewrapModelIfNecessary, unwrapModelIfNecessary
from torch import nn
from torch.utils.data import Dataset
from transformers.generation.configuration_utils import GenerationConfig

from optimum.utils import logging

from .trainer import IPUConfig, IPUTrainer, IPUTrainingArguments


if TYPE_CHECKING:
    from transformers.data.data_collator import DataCollator
    from transformers.modeling_utils import PreTrainedModel
    from transformers.tokenization_utils_base import PreTrainedTokenizerBase
    from transformers.trainer_callback import TrainerCallback
    from transformers.trainer_utils import EvalPrediction, PredictionOutput


logger = logging.get_logger(__name__)


class IPUSeq2SeqTrainer(IPUTrainer):
    """
    The [`IPUSeq2SeqTrainer`] class is used to train seq2seq models. Its behaviour is exactly the same as
    [`IPUTrainer`] except that it expects [`IPUSeq2SeqTrainingArguments`] instead of [`IPUTrainingArguments`].
    """

    def __init__(
        self,
        model: Union["PreTrainedModel", nn.Module] = None,
        ipu_config: IPUConfig = None,
        args: "IPUTrainingArguments" = None,
        data_collator: Optional["DataCollator"] = None,
        eval_data_collator: Optional["DataCollator"] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        model_init: Callable[[], "PreTrainedModel"] = None,
        compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None,
        callbacks: Optional[List["TrainerCallback"]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        force_to_pipelined: bool = False,
    ):
        super().__init__(
            model=model,
            ipu_config=ipu_config,
            args=args,
            data_collator=data_collator,
            eval_data_collator=eval_data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            force_to_pipelined=force_to_pipelined,
        )

        # Override self.model.generation_config if a GenerationConfig is specified in args.
        # Priority: args.generation_config > model.generation_config > default GenerationConfig.
        if self.args.generation_config is not None:
            gen_config = self.load_generation_config(self.args.generation_config)
            self.model.generation_config = gen_config
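
    # Illustrative usage sketch: overriding the generation config through the
    # training arguments. `my_model`, `my_ipu_config`, and the argument values
    # below are hypothetical.
    #
    #     args = IPUSeq2SeqTrainingArguments(
    #         output_dir="out",
    #         predict_with_generate=True,
    #         generation_config="t5-small",  # model id, config file, or directory
    #     )
    #     trainer = IPUSeq2SeqTrainer(model=my_model, ipu_config=my_ipu_config, args=args)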

    @staticmethod
    def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig:
        """
        Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` argument.

        Args:
            gen_config_arg (`str` or [`~generation.GenerationConfig`]):
                `Seq2SeqTrainingArguments.generation_config` argument.

        Returns:
            A `~generation.GenerationConfig`.
        """
        # GenerationConfig provided, nothing to do
        if isinstance(gen_config_arg, GenerationConfig):
            return deepcopy(gen_config_arg)

        # str or Path
        pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
        config_file_name = None

        # Figure out whether it is a path pointing to a file, a path pointing to a directory,
        # or else a model id or URL. This step is required in order to determine config_file_name.
        if pretrained_model_name.is_file():
            config_file_name = pretrained_model_name.name
            pretrained_model_name = pretrained_model_name.parent
        # dir path
        elif pretrained_model_name.is_dir():
            pass
        # model id or URL
        else:
            pretrained_model_name = gen_config_arg

        gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
        return gen_config

    def _wrap_and_compile_model_for_evaluation(self, dataloader, prediction_loss_only):
        if prediction_loss_only:
            return super()._wrap_and_compile_model_for_evaluation(dataloader, prediction_loss_only)

        # Unwrap the model, including parameter and buffer annotations and the
        # model as a whole. This is needed to avoid issues with persistent buffers
        # when compiling an inference model for generation.
        unwrapModelIfNecessary(self.model)

        # Reparallelize for generation.
        self.model.deparallelize().ipu_config.eval()
        self.model.parallelize(for_generation=True, **self.model.ipu_config.inference_parallelize_kwargs)

        # Let IPUGenerationMixin::_call_generate handle compilation of the model.
        # Note though that self.model.poptorch_decoder and self.model.poptorch_encoder
        # (attributes added by IPUGenerationMixin::_call_generate)
        # are the actual models attached to the device.
        return self.model

    def _rewrap_model_for_training(self):
        self.model.deparallelize().ipu_config.train()
        self.model.parallelize(**self.model.ipu_config.parallelize_kwargs)
        # Restores the PoptorchParameter and PoptorchBuffer annotations in the model.
        rewrapModelIfNecessary(self.model)
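
    # Illustrative sketch of the three accepted forms of `gen_config_arg`; the
    # paths and model id below are hypothetical.
    #
    #     IPUSeq2SeqTrainer.load_generation_config("run/generation_config.json")  # file path
    #     IPUSeq2SeqTrainer.load_generation_config("run/checkpoint-500")          # directory
    #     IPUSeq2SeqTrainer.load_generation_config("t5-small")                    # model id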

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        **gen_kwargs,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a `datasets.Dataset`, columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the
                `__len__` method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "eval_bleu" if the prefix is `"eval"` (default).
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for the beam search that will be used when predicting with the generate method. 1
                means no beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        gen_kwargs = gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        self._gen_kwargs = gen_kwargs

        results = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
        self._rewrap_model_for_training()
        return results

    def predict(
        self,
        test_dataset: Dataset,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "test",
        **gen_kwargs,
    ) -> "PredictionOutput":
        """
        Runs prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
        will also return metrics, like `evaluate()`.

        Args:
            test_dataset (`Dataset`):
                Dataset to run the predictions on. If it is a `datasets.Dataset` dataset, the columns not accepted
                by the `model.forward()` method are automatically removed. Has to implement the method `__len__`.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
                An optional prefix to be used as the metrics key prefix. For example the metric "bleu" will be named
                "test_bleu" if the prefix is `"test"` (default).
            max_length (`int`, *optional*):
                The maximum target length to use when predicting with the generate method.
            num_beams (`int`, *optional*):
                Number of beams for the beam search that will be used when predicting with the generate method. 1
                means no beam search.
            gen_kwargs:
                Additional `generate` specific kwargs.

        <Tip>

        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
        padding in a token classification task) the predictions will be padded (on the right) to allow for
        concatenation into one array. The padding index is -100.

        </Tip>

        Returns: *NamedTuple* A namedtuple with the following keys:

            - predictions (`np.ndarray`): The predictions on `test_dataset`.
            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
            - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
              labels).
        """
        gen_kwargs = gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.args.generation_max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
        )
        self._gen_kwargs = gen_kwargs

        results = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
        self._rewrap_model_for_training()
        return results
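
    # Illustrative sketch of per-call generation kwargs; `trainer` and `test_ds`
    # are hypothetical. Unset values fall back to `args.generation_max_length`
    # and `args.generation_num_beams`.
    #
    #     metrics = trainer.evaluate(max_length=128, num_beams=4)
    #     outputs = trainer.predict(test_ds, max_new_tokens=64)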

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        is_last_batch: bool = False,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Performs an evaluation step on a model using the inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under
                the argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                If `True`, only returns the loss.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """
        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model,
                inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
                is_last_batch=is_last_batch,
            )

        inputs = self._prepare_inputs(inputs)

        # Priority (handled in generate):
        # gen_kwargs > model.generation_config > default GenerationConfig()
        gen_kwargs = self._gen_kwargs.copy()
        if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
            gen_kwargs["max_length"] = self.model.config.max_length
        gen_kwargs["num_beams"] = (
            gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
        )

        # Prepare generation inputs. Some encoder-decoder models can have varying
        # encoders and thus varying model input names.
        if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
            generation_inputs = inputs[self.model.encoder.main_input_name]
        else:
            generation_inputs = inputs[self.model.main_input_name]

        generated_tokens = model.generate(
            generation_inputs,
            attention_mask=inputs.get("attention_mask", None),
            **gen_kwargs,
        )

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation
        # loop.
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in
        # https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # Retrieves GenerationConfig from model.generation_config.
        gen_config = self.model.generation_config
        # In case the batch is shorter than max length, the output should be padded.
        if generated_tokens.shape[-1] < gen_config.max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)

        # with torch.no_grad():
        #     # with self.autocast_smart_context_manager():
        #     outputs = model(**inputs)
        # if has_labels:
        #     if self.label_smoother is not None:
        #         loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
        #     else:
        #         loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
        # else:
        #     loss = None

        # if self.args.prediction_loss_only:
        #     return (loss, None, None)

        # if has_labels:
        #     labels = inputs["labels"]
        #     if labels.shape[-1] < gen_kwargs["max_length"]:
        #         labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
        # else:
        #     labels = None
        loss = None
        labels = inputs["labels"]

        return (loss, generated_tokens, labels)
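
    # Illustrative sketch of a single generation-based step; `trainer` and
    # `batch` are hypothetical, and `args.predict_with_generate` is assumed True.
    #
    #     loss, generated_tokens, labels = trainer.prediction_step(
    #         trainer.model, batch, prediction_loss_only=False
    #     )
    #     # loss is None here; metrics are computed from `generated_tokens` vs `labels`.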

    def _pad_tensors_to_max_len(self, tensor, max_length):
        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
            # If PAD token is not defined at least EOS token has to be defined
            pad_token_id = (
                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
            )
        else:
            if self.model.config.pad_token_id is not None:
                pad_token_id = self.model.config.pad_token_id
            else:
                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

        padded_tensor = pad_token_id * torch.ones(
            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
        )
        padded_tensor[:, : tensor.shape[-1]] = tensor
        return padded_tensor
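
# Illustrative sketch of the padding behaviour of `_pad_tensors_to_max_len`;
# `trainer` and the token values are hypothetical, assuming pad_token_id == 0.
#
#     tokens = torch.tensor([[5, 6, 7]])
#     trainer._pad_tensors_to_max_len(tokens, max_length=5)
#     # -> tensor([[5, 6, 7, 0, 0]])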