# optimum/habana/trl/trainer/dpo_trainer.py
# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from collections import defaultdict
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
DataCollator,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from trl import DPOTrainer, create_reference_model
from trl.import_utils import is_peft_available, is_wandb_available
from trl.trainer.dpo_config import FDivergenceConstants
from trl.trainer.utils import (
DPODataCollatorWithPadding,
RunningMoments,
SyncRefModelCallback,
disable_dropout_in_model,
pad_to_length,
)
from ... import GaudiConfig, GaudiTrainer
from .dpo_config import GaudiDPOConfig
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
if is_wandb_available():
    import wandb

if is_deepspeed_available():
    import deepspeed
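
# Minimal usage sketch (hypothetical model id and dataset, shown for illustration only;
# `AutoTokenizer` is not imported in this module):
#
#     model = AutoModelForCausalLM.from_pretrained("my-org/my-model")
#     tokenizer = AutoTokenizer.from_pretrained("my-org/my-model")
#     args = GaudiDPOConfig(output_dir="./dpo-out", use_habana=True, use_lazy_mode=True)
#     trainer = GaudiDPOTrainer(
#         model=model,
#         ref_model=None,  # a frozen copy of `model` is created automatically (see __init__)
#         args=args,
#         gaudi_config=GaudiConfig(),
#         train_dataset=train_dataset,  # datasets.Dataset with "prompt"/"chosen"/"rejected" columns
#         tokenizer=tokenizer,
#     )
#     trainer.train()
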
class GaudiDPOTrainer(DPOTrainer, GaudiTrainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo", "bco_pair", "robust", "aot", "aot_pair"] = "sigmoid",
args: Optional[GaudiDPOConfig] = None,
gaudi_config: GaudiConfig = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
padding_value: int = None,
truncation_mode: str = "keep_end",
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
max_length: Optional[int] = None,
max_prompt_length: Optional[int] = None,
max_target_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
is_encoder_decoder: Optional[bool] = None,
disable_dropout: bool = True,
generate_during_eval: bool = False,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
precompute_ref_log_probs: bool = False,
dataset_num_proc: Optional[int] = None,
model_init_kwargs: Optional[Dict] = None,
ref_model_init_kwargs: Optional[Dict] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
reference_free: bool = False,
force_use_ref_model: bool = False,
):
"""
Copied from DPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/dpo_trainer.py#L134
The only differences are:
- add new args gaudi_config
- use graph for ref_model
- use GaudiTrainer instead of Trainer
- cast peft model to bf16.
"""
if model_init_kwargs is not None:
warnings.warn(
"You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.model_init_kwargs = model_init_kwargs
if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
raise ValueError(
"You passed model_init_kwargs to the DPOTrainer/DPOConfig, but your model is already instantiated."
)
else:
model_init_kwargs = args.model_init_kwargs
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if a str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
model_init_kwargs["torch_dtype"] = torch_dtype
if ref_model_init_kwargs is not None:
warnings.warn(
"You passed `ref_model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.ref_model_init_kwargs = ref_model_init_kwargs
if args.ref_model_init_kwargs is None:
ref_model_init_kwargs = {}
elif not isinstance(ref_model, str):
raise ValueError(
"You passed ref_model_init_kwargs to the DPOTrainer/DPOConfig, but your ref_model is already instantiated."
)
else:
ref_model_init_kwargs = args.ref_model_init_kwargs
            torch_dtype = ref_model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if a str is passed
if isinstance(torch_dtype, str) and torch_dtype != "auto":
torch_dtype = getattr(torch, torch_dtype)
if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"Invalid `torch_dtype` passed to the DPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
)
ref_model_init_kwargs["torch_dtype"] = torch_dtype
if isinstance(model, str):
warnings.warn(
"You passed a model_id to the DPOTrainer. This will automatically create an "
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
if isinstance(ref_model, str):
warnings.warn(
"You passed a ref model_id to the DPOTrainer. This will automatically create an `AutoModelForCausalLM`"
)
ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)
        # Initialize this variable to False. This helps track whether `peft_module_casting_to_bf16`
        # has been called, in order to properly call autocast if needed.
self._peft_has_been_casted_to_bf16 = False
if force_use_ref_model:
warnings.warn(
"You passed `force_use_ref_model` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.force_use_ref_model = force_use_ref_model
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
# if model is a peft model and we have a peft_config, we merge and unload it first
if isinstance(model, PeftModel):
model = model.merge_and_unload()
if ref_model is not None and not args.force_use_ref_model:
raise ValueError(
"You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
" model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
" if you want to use a different ref_model."
)
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
_support_gc_kwargs = hasattr(
args, "gradient_checkpointing_kwargs"
) and "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
if _support_gc_kwargs:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# get peft model with the given config
model = get_peft_model(model, peft_config)
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
model = model.to(torch.bfloat16)
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
self._peft_has_been_casted_to_bf16 = True
        # For models that use gradient_checkpointing, we need to attach a hook that enables the
        # inputs to explicitly have `requires_grad=True`; otherwise training will fail, either
        # silently or with an error.
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if generate_during_eval:
warnings.warn(
"You passed `generate_during_eval` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.generate_during_eval = generate_during_eval
if args.generate_during_eval and not is_wandb_available():
raise ValueError(
"`generate_during_eval=True` requires Weights and Biases to be installed."
" Please install `wandb` to resolve."
)
if is_encoder_decoder is not None:
warnings.warn(
"You passed `is_encoder_decoder` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.is_encoder_decoder = is_encoder_decoder
if model is not None:
self.is_encoder_decoder = model.config.is_encoder_decoder
elif args.is_encoder_decoder is None:
raise ValueError(
"When no model is provided, you need to pass the parameter is_encoder_decoder to the DPOTrainer/DPOConfig."
)
else:
self.is_encoder_decoder = args.is_encoder_decoder
if model is not None:
self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys()
else:
warnings.warn(
"No model provided, cannot determine if it is a vision model. Setting is_vision_model to False."
)
self.is_vision_model = False
if self.is_vision_model:
self.processor = tokenizer
self.tokenizer = tokenizer.tokenizer # tokenizer is actually a processor at this point
else:
self.tokenizer = tokenizer
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
if model_adapter_name is not None:
warnings.warn(
"You passed `model_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.model_adapter_name = model_adapter_name
self.model_adapter_name = args.model_adapter_name
if ref_adapter_name is not None:
warnings.warn(
"You passed `ref_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.ref_adapter_name = ref_adapter_name
self.ref_adapter_name = args.ref_adapter_name
if reference_free:
warnings.warn(
"You passed `reference_free` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.reference_free = reference_free
self.reference_free = args.reference_free
if precompute_ref_log_probs:
warnings.warn(
"You passed `precompute_ref_log_probs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.precompute_ref_log_probs = precompute_ref_log_probs
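        # Resolve the reference model: use the one passed in; with a PEFT model or precomputed
        # reference log-probs, the policy model with adapters disabled serves as the reference;
        # otherwise, create a frozen copy of the policy model.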
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
# The `model` with adapters turned off will be used as the reference model
self.ref_model = None
else:
self.ref_model = create_reference_model(model)
if tokenizer is None:
raise ValueError("tokenizer must be specified to tokenize a DPO dataset.")
if max_length is not None:
warnings.warn(
"You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_length = max_length
if args.max_length is None:
warnings.warn(
"`max_length` is not set in the DPOConfig's init"
" it will default to `512` by default, but you should do it yourself in the future.",
UserWarning,
)
args.max_length = 512
if max_prompt_length is not None:
warnings.warn(
"You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_prompt_length = max_prompt_length
if args.max_prompt_length is None:
warnings.warn(
"`max_prompt_length` is not set in the DPOConfig's init"
" it will default to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
args.max_prompt_length = 128
if max_target_length is not None:
warnings.warn(
"You passed `max_target_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.max_target_length = max_target_length
if args.max_target_length is None and self.is_encoder_decoder:
warnings.warn(
"When using an encoder decoder architecture, you should set `max_target_length` in the DPOConfig's init"
" it will default to `128` by default, but you should do it yourself in the future.",
UserWarning,
)
args.max_target_length = 128
if label_pad_token_id != -100:
warnings.warn(
"You passed `label_pad_token_id` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_pad_token_id = label_pad_token_id
if data_collator is None:
data_collator = DPODataCollatorWithPadding(
pad_token_id=self.tokenizer.pad_token_id,
label_pad_token_id=args.label_pad_token_id,
is_encoder_decoder=self.is_encoder_decoder,
)
if args.remove_unused_columns:
args.remove_unused_columns = False
# warn users
warnings.warn(
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
" we have set it for you, but you should do it yourself in the future.",
UserWarning,
)
self.use_dpo_data_collator = True
else:
self.use_dpo_data_collator = False
if not disable_dropout:
warnings.warn(
"You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.disable_dropout = disable_dropout
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.max_length = args.max_length
self.generate_during_eval = args.generate_during_eval
self.label_pad_token_id = args.label_pad_token_id
if padding_value is not None:
warnings.warn(
"You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.padding_value = padding_value
        self.padding_value = args.padding_value if args.padding_value is not None else self.tokenizer.pad_token_id
self.max_prompt_length = args.max_prompt_length
if truncation_mode != "keep_end":
warnings.warn(
"You passed `truncation_mode` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.truncation_mode = truncation_mode
self.truncation_mode = args.truncation_mode
self.max_target_length = args.max_target_length
self.precompute_ref_log_probs = args.precompute_ref_log_probs
        # Since reference log-probs are precomputed on the first call to get_train/eval_dataloader,
        # keep track of whether that has already happened to avoid recomputing them on later calls.
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False
if loss_type != "sigmoid":
warnings.warn(
"You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.loss_type = loss_type
if label_smoothing != 0:
warnings.warn(
"You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.label_smoothing = label_smoothing
if args.loss_type in ["hinge", "ipo", "bco_pair"] and args.label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
if args.loss_type == "kto_pair":
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")
if beta != 0.1:
warnings.warn(
"You passed `beta` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.beta = beta
self.beta = args.beta
self.label_smoothing = args.label_smoothing
self.loss_type = args.loss_type
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
self._stored_metrics = defaultdict(lambda: defaultdict(list))
self.f_divergence_type = args.f_divergence_type
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
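        # The f-divergence settings select which divergence variant the DPO loss penalizes;
        # the alpha-divergence variant additionally uses `args.f_alpha_divergence_coef`.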
if dataset_num_proc is not None:
warnings.warn(
"You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
)
args.dataset_num_proc = dataset_num_proc
self.dataset_num_proc = args.dataset_num_proc
        # Tokenize on the main process first so the other processes can reuse the cached dataset.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
train_dataset = train_dataset.map(self.tokenize_row, num_proc=self.dataset_num_proc, writer_batch_size=10)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
self.tokenize_row, num_proc=self.dataset_num_proc, writer_batch_size=10
)
GaudiTrainer.__init__(
self,
model=model,
args=args,
gaudi_config=gaudi_config,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
        # DeepSpeed ZeRO-3 does not support precompute_ref_log_probs
if self.is_deepspeed_enabled:
if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
)
if self.ref_model is None:
if not (self.is_peft_model or self.precompute_ref_log_probs):
raise ValueError(
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
)
if args.sync_ref_model:
raise ValueError(
"You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
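            # Wrapping the reference model in an HPU graph records its forward pass once and
            # replays it afterwards, avoiding per-step recompilation overhead on Gaudi. This is
            # safe because the reference model is frozen and used for inference only.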
from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model
ref_model = self.accelerator.unwrap_model(self.ref_model)
ref_model = wrap_in_hpu_graph(ref_model)
if args.sync_ref_model:
if precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
)
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)
@staticmethod
def concatenated_inputs(
batch: Dict[str, Union[List, torch.LongTensor]],
is_encoder_decoder: bool = False,
is_vision_model: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
padded_max_length: int = 0,
) -> Dict[str, torch.LongTensor]:
"""
Copied from DPOTrainer.concatenated_inputs: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/dpo_trainer.py#L979
- pad to self.max_length in Gaudi2
"""
concatenated_batch = {}
if is_encoder_decoder:
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
else:
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
if padded_max_length != 0: # pad to max_length in Gaudi
max_length = padded_max_length
for k in batch:
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("chosen", "concatenated")
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
for k in batch:
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
if "labels" in k or is_encoder_decoder:
pad_value = label_pad_token_id
elif k.endswith("_input_ids"):
pad_value = padding_value
elif k.endswith("_attention_mask"):
pad_value = 0
concatenated_key = k.replace("rejected", "concatenated")
concatenated_batch[concatenated_key] = torch.cat(
(
concatenated_batch[concatenated_key],
pad_to_length(batch[k], max_length, pad_value=pad_value),
),
dim=0,
).to(device=device)
if is_encoder_decoder:
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
concatenated_batch["concatenated_attention_mask"] = (
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
)
if is_vision_model:
concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
concatenated_batch["pixel_attention_mask"] = (
batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
)
return concatenated_batch
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""
Copied from DPOTrainer.concatenated_forward: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/dpo_trainer.py#L1240
- pad to self.max_length in Gaudi2
"""
concatenated_batch = self.concatenated_inputs(
batch,
is_encoder_decoder=self.is_encoder_decoder,
is_vision_model=self.is_vision_model,
label_pad_token_id=self.label_pad_token_id,
padding_value=self.padding_value,
device=self.accelerator.device,
padded_max_length=self.max_length,
)
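        # Padding every batch to `self.max_length` (via `padded_max_length`) keeps input shapes
        # static across steps, which avoids recompilation in HPU graph / lazy mode on Gaudi.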
len_chosen = batch["chosen_labels"].shape[0]
model_kwargs = {}
if self.is_encoder_decoder:
model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
model_kwargs["decoder_input_ids"] = concatenated_batch.pop("concatenated_decoder_input_ids", None)
if self.is_vision_model:
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits
all_logps, size_completion = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
# average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss
labels = concatenated_batch["concatenated_labels"].clone()
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
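        # For IPO, the per-sequence log-probability is averaged over completion tokens rather
        # than summed, hence the division by `size_completion` below.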
if self.loss_type == "ipo":
all_logps = all_logps / size_completion
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
def log(self, logs: Dict[str, float], **kwargs) -> None:
"""
Changes:
- add `**kwargs` to the method arguments to make sure it's compatible with Transformers
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
        return super().log(logs, **kwargs)
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
"""
Copied from DPOTrainer.compute_loss: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/dpo_trainer.py#L1393
- add num_items_in_batch to work with transformers 4.48
- use hpu autocast
"""
if not self.use_dpo_data_collator:
warnings.warn(
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
)
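        # If the PEFT model was cast to bf16, run the forward pass under HPU autocast so that
        # mixed bf16/fp32 ops resolve correctly; otherwise no special context is needed.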
compute_loss_context_manager = (
partial(torch.autocast, device_type="hpu", dtype=torch.bfloat16)
if self._peft_has_been_casted_to_bf16
else nullcontext
)
with compute_loss_context_manager():
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
        # Move the loss to the device where the `Trainer` class accumulates the overall loss:
loss = loss.to(self.args.device)
# force log the metrics
self.store_metrics(metrics, train_eval="train")
if return_outputs:
return (loss, metrics)
return loss