# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import typing
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Union
import habana_frameworks.torch as ht
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from datasets import Dataset
from torch.optim import Adam
from transformers import (
DataCollatorForLanguageModeling,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
from trl import PPOTrainer
from trl.core import (
WANDB_PADDING,
PPODecorators,
convert_to_scalar,
logprobs_from_logits,
stack_dicts,
stats_to_np,
)
from trl.import_utils import is_torch_greater_2_0
from trl.models import (
SUPPORTED_ARCHITECTURES,
PreTrainedModelWrapper,
create_reference_model,
unwrap_model_for_generation,
)
from trl.trainer import (
AdaptiveKLController,
BaseTrainer,
FixedKLController,
RunningMoments,
)
from ...utils import HabanaGenerationTime, set_seed
from . import GaudiPPOConfig
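
# Module-level cache for the HPU graph that records the backward pass in
# `train_minibatch`: the graph is captured once on the first minibatch and then
# replayed for subsequent minibatches.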
_recorded_graph = None
class GaudiPPOTrainer(PPOTrainer):
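    """
    Adaptation of `trl.PPOTrainer` for Intel Gaudi (HPU): it uses `GaudiAccelerator`,
    optional static-shape padding, and HPU graphs to accelerate generation, the forward
    passes and the backward pass.
    """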
def __init__(
self,
config: Optional[GaudiPPOConfig] = None,
model: Optional[PreTrainedModelWrapper] = None,
ref_model: Optional[PreTrainedModelWrapper] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
data_collator: Optional[typing.Callable] = None,
num_shared_layers: Optional[int] = None,
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
training_data_collator: Optional[typing.Callable] = None,
):
"""
Copied from PPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L148
The only differences are:
        - handle the new Gaudi-specific arguments in the config
- use GaudiAccelerator instead of Accelerator
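
        Example (a minimal sketch; the model name below is only an illustrative
        assumption, not part of this module):

            from transformers import AutoTokenizer
            from trl import AutoModelForCausalLMWithValueHead

            from optimum.habana.trl import GaudiPPOConfig, GaudiPPOTrainer

            config = GaudiPPOConfig(model_name="gpt2", batch_size=8, mini_batch_size=2, use_habana=True)
            model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
            tokenizer = AutoTokenizer.from_pretrained(config.model_name)
            tokenizer.pad_token = tokenizer.eos_token
            ppo_trainer = GaudiPPOTrainer(config=config, model=model, tokenizer=tokenizer)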
"""
BaseTrainer.__init__(self, config)
# initial seed for reproducible experiments
set_seed(config.seed)
# Step 0: check positional arguments validity
if not isinstance(config, GaudiPPOConfig):
raise ValueError(f"config must be a PPOConfig, got {type(config)}")
if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
raise ValueError(
f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
)
if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
raise ValueError(
f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
)
# Step 1: Initialize Accelerator
if config.use_habana:
from ...accelerate import GaudiAccelerator as Accelerator
else:
from accelerate import Accelerator
self.accelerator = Accelerator(
log_with=config.log_with,
gradient_accumulation_steps=config.gradient_accumulation_steps,
project_config=ProjectConfiguration(**config.project_kwargs),
**config.accelerator_kwargs,
)
# Step 1.1 Runtime variables filled by the accelerator
config.world_size = self.accelerator.num_processes
config.global_backward_batch_size = config.backward_batch_size * config.world_size
config.global_batch_size = config.batch_size * config.world_size
self.model = model.to(self.accelerator.device.type)
self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
self.is_peft_model = getattr(self.model, "is_peft_model", False)
config.is_encoder_decoder = self.is_encoder_decoder
config.is_peft_model = self.is_peft_model
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
self.accelerator.init_trackers(
config.tracker_project_name,
config=({"trl_ppo_trainer_config": config.to_dict()} if not is_using_tensorboard else config.to_dict()),
init_kwargs=config.tracker_kwargs,
)
self.is_using_text_environment = getattr(config, "use_text_environment", False)
if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
self.ref_model = ref_model.to(self.accelerator.device.type)
if num_shared_layers is not None:
warnings.warn(
"num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
"model and the reference model and no layers are shared.",
UserWarning,
)
elif ref_model is None and not self.is_peft_model:
self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
elif self.is_peft_model:
self.ref_model = None
else:
raise ValueError(
f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
f"architectures are: {SUPPORTED_ARCHITECTURES} "
)
self.optional_peft_ctx = (
self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
if self.is_peft_model
else nullcontext
)
if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
raise ValueError(
"tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
)
self.tokenizer = tokenizer
if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
elif dataset is None:
warnings.warn(
"No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
UserWarning,
)
self.dataset = dataset
self._signature_columns = None
if self.dataset is not None:
self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
elif self.dataset is None and self.accelerator.num_processes > 1:
warnings.warn(
"No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
" prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
" and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
" refer to the documentation for more details.",
UserWarning,
)
self.dataloader = None
else:
self.dataloader = None
# Step 3: Initialize optimizer and data collator
if training_data_collator is None:
self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
else:
self.data_collator = training_data_collator
if optimizer is None:
self.optimizer = Adam(
filter(lambda p: p.requires_grad, self.model.parameters()),
lr=self.config.learning_rate,
)
else:
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if self.lr_scheduler is not None:
lr_scheduler_class = (
torch.optim.lr_scheduler._LRScheduler
if not is_torch_greater_2_0()
else torch.optim.lr_scheduler.LRScheduler
)
if not isinstance(self.lr_scheduler, lr_scheduler_class):
raise ValueError(
"lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
)
if self.config.adap_kl_ctrl:
self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
else:
self.kl_ctl = FixedKLController(self.config.init_kl_coef)
if self.accelerator.distributed_type == "MULTI_HPU":
from accelerate.utils import DistributedDataParallelKwargs
kwargs = {}
kwargs["find_unused_parameters"] = True
kwargs["gradient_as_bucket_view"] = True
self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
# Safety checkers for DS integration
is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
self.accelerator.state, "deepspeed_plugin"
)
if config.gradient_checkpointing:
self.model.gradient_checkpointing_enable()
if hasattr(self.model, "enable_input_require_grads"):
self.model.enable_input_require_grads()
else:
# For backward compatibility with older versions of transformers
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
(
self.model,
self.optimizer,
self.data_collator,
self.dataloader,
self.lr_scheduler,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.data_collator,
self.dataloader,
self.lr_scheduler,
)
if is_deepspeed_used:
# Quantized models are already set on the correct device
if not self.is_peft_model and not (
getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
):
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare(self.ref_model)
# In a distributed setup, only logging needs to be performed on the main process
# check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
# or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
self.is_distributed = self.accelerator.num_processes > 1
# init the current step
self.current_step = 0
# init variables for pushing model to hub
if config.push_to_hub_if_best_kwargs:
if "repo_id" not in config.push_to_hub_if_best_kwargs:
raise ValueError("You have to specify repo_id in order to push the model to the hub!")
self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
self.compare_step = 0
self.highest_reward = torch.tensor(-float("inf"))
# post process for PP
if not getattr(self.model, "is_sequential_parallel", False):
self.current_device = self.accelerator.device
else:
if self.accelerator.device.type == "hpu":
self.current_device = torch.device("hpu")
else:
self.current_device = torch.device("cpu")
PPODecorators.optimize_device_cache = self.config.optimize_device_cache
self.running = RunningMoments(self.accelerator)
if config.use_habana:
import habana_frameworks.torch.core as htcore
self.htcore = htcore
def generate(
self,
query_tensor: Union[torch.Tensor, List[torch.Tensor]],
length_sampler: Callable = None,
batch_size: int = 4,
return_prompt: bool = True,
generate_ref_response: bool = False,
**generation_kwargs,
):
"""
Copied from PPOTrainer.generate: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L455
The only differences are:
        - wrap the model in HPU graphs to accelerate generation
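
        Example (illustrative sketch; `query_tensors` is assumed to be a list of 1-D
        `torch.LongTensor` prompts and `tokenizer` the trainer's tokenizer):

            generation_kwargs = {"max_new_tokens": 32, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
            responses = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)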
"""
if generate_ref_response:
ref_model = self.model if self.is_peft_model else self.ref_model
if isinstance(query_tensor, List):
if self.config.use_habana:
self.wrap_generation_for_hpu_graph_mode(self.model)
response = self._generate_batched(
self.model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
if generate_ref_response:
if self.config.use_habana:
self.wrap_generation_for_hpu_graph_mode(ref_model)
ref_response = self._generate_batched(
ref_model,
query_tensor,
length_sampler=length_sampler,
batch_size=batch_size,
return_prompt=return_prompt,
**generation_kwargs,
)
else:
if len(query_tensor.shape) == 2:
raise ValueError(
"query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
)
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
if self.config.use_habana:
self.wrap_generation_for_hpu_graph_mode(self.model)
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
if generate_ref_response:
if self.config.use_habana:
self.wrap_generation_for_hpu_graph_mode(ref_model)
with unwrap_model_for_generation(
ref_model, self.accelerator, is_peft_model=self.is_peft_model
) as unwrapped_model:
ref_response = unwrapped_model.generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
)
if not return_prompt and not self.is_encoder_decoder:
response = response[:, query_tensor.shape[0] :]
if generate_ref_response:
ref_response = ref_response[:, query_tensor.shape[0] :]
if generate_ref_response:
return response, ref_response
return response
def _generate_batched(
self,
model: PreTrainedModelWrapper,
query_tensors: List[torch.Tensor],
length_sampler: Optional[Callable] = None,
batch_size: int = 4,
return_prompt: bool = True,
pad_to_multiple_of: Optional[int] = None,
remove_padding: bool = True,
**generation_kwargs,
):
"""
Copied from PPOTrainer._generate_batched: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L535
The only differences are:
        - pad inputs to `pad_max_input_len` to get static shapes and speed up generation
        - use lazy mode and HPU graphs for generation on HPU
"""
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
if length_sampler is not None:
generation_kwargs["max_new_tokens"] = length_sampler()
            # prevent overflow if the number of query tensors is not an even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
if self.config.pad_for_acceleration and self.config.pad_max_input_len > 0:
padded_inputs = self.tokenizer.pad(
inputs,
padding="max_length",
max_length=self.config.pad_max_input_len,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
else:
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
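            # Gaudi-specific generation flags: honor the EOS token (ignore_eos=False) and run
            # generation in lazy mode with HPU graphs enabled.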
if self.config.use_habana:
generation_kwargs["ignore_eos"] = False
generation_kwargs["lazy_mode"] = True
generation_kwargs["hpu_graphs"] = True
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not return_prompt and not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
if remove_padding and self.tokenizer.eos_token_id in output:
pad_mask = output == self.tokenizer.eos_token_id
pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
output = output[: pad_start + 1] # keep the eos token at the end
outputs.append(output)
self.tokenizer.padding_side = padding_side_default
return outputs
@PPODecorators.empty_device_cache()
def step(
self,
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
response_masks: Optional[List[torch.LongTensor]] = None,
):
"""
Copied from PPOTrainer.step: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L647
The only differences are:
        - use HPU graphs for sampling and training
        - skip the duplicated padding when padding is already done in prepare_model_inputs
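
        Example (illustrative sketch of one PPO iteration; `batch`, `reward_fn` and
        `generation_kwargs` are assumed to come from the surrounding training loop):

            queries = batch["input_ids"]
            responses = ppo_trainer.generate(queries, return_prompt=False, **generation_kwargs)
            scores = [torch.tensor(reward_fn(q, r)) for q, r in zip(queries, responses)]
            stats = ppo_trainer.step(queries, responses, scores)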
"""
bs = self.config.batch_size
queries, responses, scores, response_masks = self._step_safety_checker(
bs, queries, responses, scores, response_masks
)
scores = torch.tensor(scores, device=self.current_device)
if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
tensor_to_kwargs = {"dtype": scores.dtype, "device": scores.device}
score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
if self.config.use_score_norm:
scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
else:
scores /= score_scaling_factor
if self.config.score_clip is not None:
# Score clipping
scores_dtype = scores.dtype
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = scores.mean()
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
# push model to hub
self.push_to_hub(**self.push_to_hub_kwargs)
self.compare_step += 1
timing = {}
timer = HabanaGenerationTime()
timer.start()
model_inputs = self.prepare_model_inputs(queries, responses)
if self.is_distributed and not self.config.pad_for_acceleration:
pad_first = self.tokenizer.padding_side == "left"
model_inputs["input_ids"] = self.accelerator.pad_across_processes(
model_inputs["input_ids"],
dim=1,
pad_index=self.tokenizer.pad_token_id,
pad_first=pad_first,
)
model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
)
if self.is_encoder_decoder:
model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
model_inputs["decoder_input_ids"],
dim=1,
pad_index=self.tokenizer.pad_token_id,
pad_first=pad_first,
)
model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
model_inputs["decoder_attention_mask"],
dim=1,
pad_index=0,
pad_first=pad_first,
)
model_inputs_names = list(model_inputs.keys())
full_kl_penalty = self.config.kl_penalty == "full"
with torch.no_grad():
if self.config.use_habana:
self.unwrap_generation_for_hpu_graph_mode(self.model)
self.wrap_fw_for_hpu_graph_mode(self.model)
if self.ref_model is not None:
self.unwrap_generation_for_hpu_graph_mode(self.ref_model)
self.wrap_fw_for_hpu_graph_mode(self.ref_model)
all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
self.model,
queries,
responses,
model_inputs,
response_masks=response_masks,
return_logits=full_kl_penalty,
)
with self.optional_peft_ctx():
ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
self.model if self.is_peft_model else self.ref_model,
queries,
responses,
model_inputs,
return_logits=full_kl_penalty,
)
timer.step()
timing["time/ppo/forward_pass"] = timer.last_duration
with torch.no_grad():
timer.step()
if full_kl_penalty:
active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)
rewards, non_score_reward, kls = self.compute_rewards(
scores, active_full_logprobs, ref_full_logprobs, masks
)
else:
rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
timer.step()
timing["time/ppo/compute_rewards"] = timer.last_duration
timer.step()
values, advantages, returns = self.compute_advantages(values, rewards, masks)
timer.step()
timing["time/ppo/compute_advantages"] = timer.last_duration
# upcast to float32 to avoid dataset issues
batch_dict = {
"queries": queries,
"responses": responses,
"logprobs": all_logprobs.to(torch.float32),
"values": values.to(torch.float32),
"masks": masks,
"advantages": advantages,
"returns": returns,
}
batch_dict.update(model_inputs)
timer.step()
all_stats = []
early_stop = False
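        # When running on Gaudi, drop the eval-time HPU graph wrapping of the forward pass and
        # wrap the unwrapped model in a training HPU graph, cached on the module as
        # `wrap_train_in_graph` so the wrapping only happens once.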
if self.config.use_habana:
self.unwrap_fw_for_hpu_graph_mode(self.model)
import habana_frameworks.torch as ht
model = self.accelerator.unwrap_model(self.model)
if not hasattr(model, "wrap_train_in_graph"):
model = ht.hpu.wrap_in_hpu_graph(model)
setattr(model, "wrap_train_in_graph", model.forward)
else:
model.forward = getattr(model, "wrap_train_in_graph")
for _ in range(self.config.ppo_epochs):
if early_stop:
break
b_inds = np.random.permutation(bs)
for backward_batch_start in range(0, bs, self.config.backward_batch_size):
backward_batch_end = backward_batch_start + self.config.backward_batch_size
backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]
for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
mini_batch_end = mini_batch_start + self.config.mini_batch_size
mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
mini_batch_dict = {
"logprobs": batch_dict["logprobs"][mini_batch_inds],
"values": batch_dict["values"][mini_batch_inds],
"masks": batch_dict["masks"][mini_batch_inds],
# hacks: the queries and responses are ragged.
"queries": [batch_dict["queries"][i] for i in mini_batch_inds],
"responses": [batch_dict["responses"][i] for i in mini_batch_inds],
"advantages": batch_dict["advantages"][mini_batch_inds],
"returns": batch_dict["returns"][mini_batch_inds],
}
for k in model_inputs_names:
mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
with self.accelerator.accumulate(self.model):
model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}
logprobs, logits, vpreds, _ = self.batched_forward_pass(
self.model,
mini_batch_dict["queries"],
mini_batch_dict["responses"],
model_inputs,
return_logits=True,
)
train_stats = self.train_minibatch(
mini_batch_dict["logprobs"],
mini_batch_dict["values"],
logprobs,
logits,
vpreds,
mini_batch_dict["masks"],
mini_batch_dict["advantages"],
mini_batch_dict["returns"],
)
all_stats.append(train_stats)
# typically, early stopping is done at the epoch level
if self.config.early_stopping:
policykl = train_stats["policy/policykl"]
early_stop = self._early_stop(policykl)
if early_stop:
break
timer.step()
timing["time/ppo/optimize_step"] = timer.last_duration
timer.step()
train_stats = stack_dicts(all_stats)
# reshape advantages/ratios such that they are not averaged.
train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)
stats = self.record_step_stats(
scores=scores,
logprobs=all_logprobs,
ref_logprobs=ref_logprobs,
non_score_reward=non_score_reward,
train_stats=train_stats,
kl_coef=self.kl_ctl.value,
masks=masks,
queries=queries,
responses=responses,
kls=kls,
)
# Gather/Reduce stats from all processes
if self.is_distributed:
stats = self.gather_stats(stats)
stats = stats_to_np(stats)
timer.step()
timing["time/ppo/calc_stats"] = timer.last_duration
stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]
# Update the KL control - multiply the batch_size by the number of processes
self.kl_ctl.update(
stats["objective/kl"],
self.config.batch_size * self.accelerator.num_processes,
)
# Log the total ppo time
timing["time/ppo/total"] = timer.total_time()
stats.update(timing)
# post-process stats for tensorboard and other loggers
if self.config.log_with != "wandb":
stats = convert_to_scalar(stats)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return stats
def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
"""
Copied from PPOTrainer.prepare_model_inputs: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L949
The only differences are:
- add padding to model inputs for static shape support in forward
"""
if self.is_encoder_decoder:
input_data = self.data_collator(
[{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
).to(self.current_device)
decoder_inputs = self.data_collator(
[{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
).to(self.current_device)
input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
else:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator(
[{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
).to(self.current_device)
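        # Pad every tensor to the fixed `pad_max_len` so the forward pass always sees
        # static shapes, which avoids graph recompilations on Gaudi.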
if self.config.pad_for_acceleration:
input_data["input_ids"] = torch.nn.functional.pad(
input_data["input_ids"],
(0, self.config.pad_max_len - input_data["input_ids"].shape[1]),
value=self.tokenizer.pad_token_id,
)
input_data["attention_mask"] = torch.nn.functional.pad(
input_data["attention_mask"],
(
0,
self.config.pad_max_len - input_data["attention_mask"].shape[1],
),
value=0,
)
if self.is_encoder_decoder:
input_data["decoder_input_ids"] = torch.nn.functional.pad(
input_data["decoder_input_ids"],
(
0,
self.config.pad_max_len - input_data["decoder_input_ids"].shape[1],
),
value=self.tokenizer.pad_token_id,
)
input_data["decoder_attention_mask"] = torch.nn.functional.pad(
input_data["decoder_attention_mask"],
(
0,
self.config.pad_max_len - input_data["decoder_attention_mask"].shape[1],
),
value=0,
)
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: PreTrainedModelWrapper,
queries: torch.Tensor,
responses: torch.Tensor,
model_inputs: dict,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
"""
Copied from PPOTrainer.batched_forward_pass: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L971
The only differences are:
        - input_kwargs and outputs need to be cloned with clone() to avoid being overridden on HPU
"""
bs = len(queries)
fbs = self.config.mini_batch_size
all_logprobs = []
all_logits = []
all_masks = []
all_values = []
model.eval()
for i in range(math.ceil(bs / fbs)):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs].clone() for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
logits, _, values = model(**input_kwargs)
if self.is_encoder_decoder:
input_ids = input_kwargs["decoder_input_ids"]
attention_mask = input_kwargs["decoder_attention_mask"]
else:
input_ids = input_kwargs["input_ids"]
attention_mask = input_kwargs["attention_mask"]
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]
for j in range(len(query_batch)):
if self.is_encoder_decoder:
                    # In encoder-decoder models, the decoder sequence always starts at index 1 (right after padding)
start = 1
end = attention_mask[j, :].sum() - 1
else:
start = len(query_batch[j]) - 1 # logprobs starts from the second query token
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
if response_masks is not None:
response_masks_batch[j] = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits.clone())
else:
del logits
all_values.append(values.clone())
all_logprobs.append(logprobs)
all_masks.append(masks)
return (
torch.cat(all_logprobs),
torch.cat(all_logits)[:, :-1] if return_logits else None,
torch.cat(all_values)[:, :-1],
torch.cat(all_masks)[:, :-1],
)
@PPODecorators.empty_device_cache()
def train_minibatch(
self,
old_logprobs: torch.FloatTensor,
values: torch.FloatTensor,
logprobs: torch.FloatTensor,
logits: torch.FloatTensor,
vpreds: torch.FloatTensor,
mask: torch.LongTensor,
advantages: torch.FloatTensor,
returns: torch.FloatTensor,
):
"""
        Copied from PPOTrainer.train_minibatch: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L1058
The only differences are:
- add htcore.mark_step
"""
self.model.train()
loss_p, loss_v, train_stats = self.loss(
old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
)
loss = loss_p + loss_v
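        # Capture the backward pass into an HPU graph on the first minibatch and replay it
        # afterwards; replay assumes the graph topology and tensor shapes stay identical
        # across minibatches.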
global _recorded_graph
if _recorded_graph is None:
_recorded_graph = ht.hpu.HPUGraph()
s = ht.hpu.default_stream()
with ht.hpu.stream(s):
_recorded_graph.capture_begin()
self.accelerator.backward(loss)
_recorded_graph.capture_end()
else:
_recorded_graph.replay()
if self.config.max_grad_norm is not None:
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
self.optimizer.step()
if self.config.use_habana:
self.htcore.mark_step()
# we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
# see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
self.optimizer.zero_grad()
return train_stats
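
    # The helpers below swap a model's `forward` between its original implementation and an
    # HPU-graph-wrapped version, caching both on the module (`orig_fw` / `hpu_graph_fw`) so
    # that wrapping happens only once.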
def wrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
model = self.accelerator.unwrap_model(model)
if hasattr(model, "hpu_graph_fw"):
model.forward = model.hpu_graph_fw
else:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model.orig_fw = model.forward
model = wrap_in_hpu_graph(model)
model.hpu_graph_fw = model.forward
def unwrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
model = self.accelerator.unwrap_model(model)
if hasattr(model, "orig_fw"):
model.forward = model.orig_fw
def wrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
model = self.accelerator.unwrap_model(model)
if getattr(model, "is_peft_model", False):
if hasattr(model.pretrained_model.base_model.model, "hpu_graph_fw"):
model.pretrained_model.base_model.model.forward = model.pretrained_model.base_model.model.hpu_graph_fw
else:
model.pretrained_model.base_model.model.orig_fw = model.pretrained_model.base_model.model.forward
model.pretrained_model.base_model.model = wrap_in_hpu_graph(model.pretrained_model.base_model.model)
model.pretrained_model.base_model.model.hpu_graph_fw = model.pretrained_model.base_model.model.forward
else:
if hasattr(model.pretrained_model, "hpu_graph_fw"):
model.pretrained_model.forward = model.pretrained_model.hpu_graph_fw
else:
model.pretrained_model.orig_fw = model.pretrained_model.forward
model.pretrained_model = wrap_in_hpu_graph(model.pretrained_model)
model.pretrained_model.hpu_graph_fw = model.pretrained_model.forward
def unwrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
model = self.accelerator.unwrap_model(model)
if getattr(model, "is_peft_model", False):
if hasattr(model.pretrained_model.base_model.model, "orig_fw"):
model.pretrained_model.base_model.model.forward = model.pretrained_model.base_model.model.orig_fw
else:
if hasattr(model.pretrained_model, "orig_fw"):
model.pretrained_model.forward = model.pretrained_model.orig_fw