optimum/habana/trl/trainer/ppo_trainer.py

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import typing
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Union

import habana_frameworks.torch as ht
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from datasets import Dataset
from torch.optim import Adam
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from trl import PPOTrainer
from trl.core import (
    WANDB_PADDING,
    PPODecorators,
    convert_to_scalar,
    logprobs_from_logits,
    stack_dicts,
    stats_to_np,
)
from trl.import_utils import is_torch_greater_2_0
from trl.models import (
    SUPPORTED_ARCHITECTURES,
    PreTrainedModelWrapper,
    create_reference_model,
    unwrap_model_for_generation,
)
from trl.trainer import (
    AdaptiveKLController,
    BaseTrainer,
    FixedKLController,
    RunningMoments,
)

from ...utils import HabanaGenerationTime, set_seed
from . import GaudiPPOConfig


_recorded_graph = None


class GaudiPPOTrainer(PPOTrainer):
    def __init__(
        self,
        config: Optional[GaudiPPOConfig] = None,
        model: Optional[PreTrainedModelWrapper] = None,
        ref_model: Optional[PreTrainedModelWrapper] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        data_collator: Optional[typing.Callable] = None,
        num_shared_layers: Optional[int] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        training_data_collator: Optional[typing.Callable] = None,
    ):
        """
        Copied from PPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L148
        The only differences are:
        - add new args for Gaudi in config
        - use GaudiAccelerator instead of Accelerator
        """
        BaseTrainer.__init__(self, config)

        # initial seed for reproducible experiments
        set_seed(config.seed)

        # Step 0: check positional arguments validity
        if not isinstance(config, GaudiPPOConfig):
            raise ValueError(f"config must be a GaudiPPOConfig, got {type(config)}")
        if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
            raise ValueError(
                f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
            )
        if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
            raise ValueError(
                f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
            )

        # Step 1: Initialize Accelerator
        if config.use_habana:
            from ...accelerate import GaudiAccelerator as Accelerator
        else:
            from accelerate import Accelerator
        self.accelerator = Accelerator(
            log_with=config.log_with,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            project_config=ProjectConfiguration(**config.project_kwargs),
            **config.accelerator_kwargs,
        )

        # Step 1.1 Runtime variables filled by the accelerator
        config.world_size = self.accelerator.num_processes
        config.global_backward_batch_size = config.backward_batch_size * config.world_size
        config.global_batch_size = config.batch_size * config.world_size

        self.model = model.to(self.accelerator.device.type)
        self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.is_peft_model = getattr(self.model, "is_peft_model", False)
        config.is_encoder_decoder = self.is_encoder_decoder
        config.is_peft_model = self.is_peft_model

        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        self.accelerator.init_trackers(
            config.tracker_project_name,
            config=({"trl_ppo_trainer_config": config.to_dict()} if not is_using_tensorboard else config.to_dict()),
            init_kwargs=config.tracker_kwargs,
        )
        self.is_using_text_environment = getattr(config, "use_text_environment", False)

        if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
            self.ref_model = ref_model.to(self.accelerator.device.type)
            if num_shared_layers is not None:
                warnings.warn(
                    "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
                    "model and the reference model and no layers are shared.",
                    UserWarning,
                )
        elif ref_model is None and not self.is_peft_model:
            self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
        elif self.is_peft_model:
            self.ref_model = None
        else:
            raise ValueError(
                f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
                f"architectures are: {SUPPORTED_ARCHITECTURES} "
            )
        self.optional_peft_ctx = (
            self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
            if self.is_peft_model
            else nullcontext
        )

        if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
            raise ValueError(
                "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
            )
        self.tokenizer = tokenizer

        if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
            raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
        elif dataset is None:
            warnings.warn(
                "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
                UserWarning,
            )
        self.dataset = dataset
        self._signature_columns = None
        if self.dataset is not None:
            self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
        elif self.dataset is None and self.accelerator.num_processes > 1:
            warnings.warn(
                "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
                " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
                " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please"
                " refer to the documentation for more details.",
                UserWarning,
            )
            self.dataloader = None
        else:
            self.dataloader = None

        # Step 3: Initialize optimizer and data collator
        if training_data_collator is None:
            self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        else:
            self.data_collator = training_data_collator
        if optimizer is None:
            self.optimizer = Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate,
            )
        else:
            self.optimizer = optimizer

        self.lr_scheduler = lr_scheduler
        if self.lr_scheduler is not None:
            lr_scheduler_class = (
                torch.optim.lr_scheduler._LRScheduler
                if not is_torch_greater_2_0()
                else torch.optim.lr_scheduler.LRScheduler
            )

            if not isinstance(self.lr_scheduler, lr_scheduler_class):
                raise ValueError(
                    "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
                )

        if self.config.adap_kl_ctrl:
            self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
        else:
            self.kl_ctl = FixedKLController(self.config.init_kl_coef)

        if self.accelerator.distributed_type == "MULTI_HPU":
            from accelerate.utils import DistributedDataParallelKwargs

            kwargs = {}
            kwargs["find_unused_parameters"] = True
            kwargs["gradient_as_bucket_view"] = True
            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        # Safety checkers for DS integration
        is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
            self.accelerator.state, "deepspeed_plugin"
        )

        if config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()
            else:
                # For backward compatibility with older versions of transformers
                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        (
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        )
        if is_deepspeed_used:
            # Quantized models are already set on the correct device
            if not self.is_peft_model and not (
                getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
                or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
            ):
                self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare(self.ref_model)

        # In a distributed setup, only logging needs to be performed on the main process
        # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
        # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
        self.is_distributed = self.accelerator.num_processes > 1

        # init the current step
        self.current_step = 0

        # init variables for pushing model to hub
        if config.push_to_hub_if_best_kwargs:
            if "repo_id" not in config.push_to_hub_if_best_kwargs:
                raise ValueError("You have to specify repo_id in order to push the model to the hub!")
            self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
            self.compare_step = 0
            self.highest_reward = torch.tensor(-float("inf"))

        # post process for PP
        if not getattr(self.model, "is_sequential_parallel", False):
            self.current_device = self.accelerator.device
        else:
            if self.accelerator.device.type == "hpu":
                self.current_device = torch.device("hpu")
            else:
                self.current_device = torch.device("cpu")
        PPODecorators.optimize_device_cache = self.config.optimize_device_cache

        self.running = RunningMoments(self.accelerator)
        if config.use_habana:
            import habana_frameworks.torch.core as htcore

            self.htcore = htcore

    def generate(
        self,
        query_tensor: Union[torch.Tensor, List[torch.Tensor]],
        length_sampler: Callable = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        generate_ref_response: bool = False,
        **generation_kwargs,
    ):
        """
        Copied from PPOTrainer.generate: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L455
        The only differences are:
        - add hpu graph for acceleration
        """
        if generate_ref_response:
            ref_model = self.model if self.is_peft_model else self.ref_model
        if isinstance(query_tensor, List):
            if self.config.use_habana:
                self.wrap_generation_for_hpu_graph_mode(self.model)
            response = self._generate_batched(
                self.model,
                query_tensor,
                length_sampler=length_sampler,
                batch_size=batch_size,
                return_prompt=return_prompt,
                **generation_kwargs,
            )
            if generate_ref_response:
                if self.config.use_habana:
                    self.wrap_generation_for_hpu_graph_mode(ref_model)
                ref_response = self._generate_batched(
                    ref_model,
                    query_tensor,
                    length_sampler=length_sampler,
                    batch_size=batch_size,
                    return_prompt=return_prompt,
                    **generation_kwargs,
                )
        else:
            if len(query_tensor.shape) == 2:
                raise ValueError(
                    "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
                )

            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()
            if self.config.use_habana:
                self.wrap_generation_for_hpu_graph_mode(self.model)
            with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
                response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)

            if generate_ref_response:
                if self.config.use_habana:
                    self.wrap_generation_for_hpu_graph_mode(ref_model)
                with unwrap_model_for_generation(
                    ref_model, self.accelerator, is_peft_model=self.is_peft_model
                ) as unwrapped_model:
                    ref_response = unwrapped_model.generate(
                        input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
                    )

            if not return_prompt and not self.is_encoder_decoder:
                response = response[:, query_tensor.shape[0] :]
                if generate_ref_response:
                    ref_response = ref_response[:, query_tensor.shape[0] :]

        if generate_ref_response:
            return response, ref_response
        return response

    def _generate_batched(
        self,
        model: PreTrainedModelWrapper,
        query_tensors: List[torch.Tensor],
        length_sampler: Optional[Callable] = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        remove_padding: bool = True,
        **generation_kwargs,
    ):
        """
        Copied from PPOTrainer._generate_batched: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L535
        The only differences are:
        - pad to pad_max_input_len to get static shape for generation acceleration
        - use lazy mode and hpu_graphs for generation in hpu
        """
        outputs = []
        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            self.tokenizer.padding_side = "left"

        # in case we have fewer examples than bs
        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()

            # prevent overflow if query tensors are not even multiple of bs
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}
            if self.config.pad_for_acceleration and self.config.pad_max_input_len > 0:
                padded_inputs = self.tokenizer.pad(
                    inputs,
                    padding="max_length",
                    max_length=self.config.pad_max_input_len,
                    pad_to_multiple_of=pad_to_multiple_of,
                    return_tensors="pt",
                ).to(self.current_device)
            else:
                padded_inputs = self.tokenizer.pad(
                    inputs,
                    padding=True,
                    max_length=None,
                    pad_to_multiple_of=pad_to_multiple_of,
                    return_tensors="pt",
                ).to(self.current_device)

            if self.config.use_habana:
                generation_kwargs["ignore_eos"] = False
                generation_kwargs["lazy_mode"] = True
                generation_kwargs["hpu_graphs"] = True

            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)

            for generation, mask in zip(generations, padded_inputs["attention_mask"]):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum() :]  # remove padding
                else:
                    output = generation

                if not return_prompt and not self.is_encoder_decoder:
                    output = output[(mask).sum() :]  # remove prompt

                if remove_padding and self.tokenizer.eos_token_id in output:
                    pad_mask = output == self.tokenizer.eos_token_id
                    pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
                    output = output[: pad_start + 1]  # keep the eos token at the end

                outputs.append(output)

        self.tokenizer.padding_side = padding_side_default
        return outputs

    @PPODecorators.empty_device_cache()
    def step(
        self,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        response_masks: Optional[List[torch.LongTensor]] = None,
    ):
        """
        Copied from PPOTrainer.step: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L647
        The only differences are:
        - use hpu_graphs for sampling and training
        - remove duplicated padding if padding is done in prepare_model_inputs
        """
        bs = self.config.batch_size

        queries, responses, scores, response_masks = self._step_safety_checker(
            bs, queries, responses, scores, response_masks
        )
        scores = torch.tensor(scores, device=self.current_device)
        if self.config.use_score_scaling:
            # Score scaling
            scores_mean, scores_std = self.running.update(scores)
            tensor_to_kwargs = {"dtype": scores.dtype, "device": scores.device}
            score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
            if self.config.use_score_norm:
                scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
            else:
                scores /= score_scaling_factor

        if self.config.score_clip is not None:
            # Score clipping
            scores_dtype = scores.dtype
            scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)

        # if we want to push best model to the hub
        if hasattr(self, "highest_reward"):
            if self.compare_step % self.config.compare_steps == 0:
                curr_mean_reward = scores.mean()
                # if the best reward ever seen
                if curr_mean_reward > self.highest_reward:
                    self.highest_reward = curr_mean_reward
                    # push model to hub
                    self.push_to_hub(**self.push_to_hub_kwargs)
            self.compare_step += 1

        timing = {}
        timer = HabanaGenerationTime()
        timer.start()

        model_inputs = self.prepare_model_inputs(queries, responses)

        if self.is_distributed and not self.config.pad_for_acceleration:
            pad_first = self.tokenizer.padding_side == "left"

            model_inputs["input_ids"] = self.accelerator.pad_across_processes(
                model_inputs["input_ids"],
                dim=1,
                pad_index=self.tokenizer.pad_token_id,
                pad_first=pad_first,
            )
            model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
                model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
            )
            if self.is_encoder_decoder:
                model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_input_ids"],
                    dim=1,
                    pad_index=self.tokenizer.pad_token_id,
                    pad_first=pad_first,
                )
                model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_attention_mask"],
                    dim=1,
                    pad_index=0,
                    pad_first=pad_first,
                )

        model_inputs_names = list(model_inputs.keys())

        full_kl_penalty = self.config.kl_penalty == "full"

        with torch.no_grad():
            if self.config.use_habana:
                self.unwrap_generation_for_hpu_graph_mode(self.model)
                self.wrap_fw_for_hpu_graph_mode(self.model)
                if self.ref_model is not None:
                    self.unwrap_generation_for_hpu_graph_mode(self.ref_model)
                    self.wrap_fw_for_hpu_graph_mode(self.ref_model)
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )
            with self.optional_peft_ctx():
                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                    self.model if self.is_peft_model else self.ref_model,
                    queries,
                    responses,
                    model_inputs,
                    return_logits=full_kl_penalty,
                )

        timer.step()
        timing["time/ppo/forward_pass"] = timer.last_duration

        with torch.no_grad():
            timer.step()
            if full_kl_penalty:
                active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
                ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

                rewards, non_score_reward, kls = self.compute_rewards(
                    scores, active_full_logprobs, ref_full_logprobs, masks
                )
            else:
                rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
            timer.step()
            timing["time/ppo/compute_rewards"] = timer.last_duration

            timer.step()
            values, advantages, returns = self.compute_advantages(values, rewards, masks)
            timer.step()
            timing["time/ppo/compute_advantages"] = timer.last_duration

        # upcast to float32 to avoid dataset issues
        batch_dict = {
            "queries": queries,
            "responses": responses,
            "logprobs": all_logprobs.to(torch.float32),
            "values": values.to(torch.float32),
            "masks": masks,
            "advantages": advantages,
            "returns": returns,
        }
        batch_dict.update(model_inputs)

        timer.step()
        all_stats = []
        early_stop = False

        if self.config.use_habana:
            self.unwrap_fw_for_hpu_graph_mode(self.model)
            import habana_frameworks.torch as ht

            model = self.accelerator.unwrap_model(self.model)
            if not hasattr(model, "wrap_train_in_graph"):
                model = ht.hpu.wrap_in_hpu_graph(model)
                setattr(model, "wrap_train_in_graph", model.forward)
            else:
                model.forward = getattr(model, "wrap_train_in_graph")

        for _ in range(self.config.ppo_epochs):
            if early_stop:
                break
            b_inds = np.random.permutation(bs)
            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
                backward_batch_end = backward_batch_start + self.config.backward_batch_size
                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]

                for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
                    mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
                    mini_batch_dict = {
                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
                        "values": batch_dict["values"][mini_batch_inds],
                        "masks": batch_dict["masks"][mini_batch_inds],
                        # hacks: the queries and responses are ragged.
                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
                        "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
                        "advantages": batch_dict["advantages"][mini_batch_inds],
                        "returns": batch_dict["returns"][mini_batch_inds],
                    }
                    for k in model_inputs_names:
                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                    with self.accelerator.accumulate(self.model):
                        model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}

                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
                            self.model,
                            mini_batch_dict["queries"],
                            mini_batch_dict["responses"],
                            model_inputs,
                            return_logits=True,
                        )
                        train_stats = self.train_minibatch(
                            mini_batch_dict["logprobs"],
                            mini_batch_dict["values"],
                            logprobs,
                            logits,
                            vpreds,
                            mini_batch_dict["masks"],
                            mini_batch_dict["advantages"],
                            mini_batch_dict["returns"],
                        )
                        all_stats.append(train_stats)

            # typically, early stopping is done at the epoch level
            if self.config.early_stopping:
                policykl = train_stats["policy/policykl"]
                early_stop = self._early_stop(policykl)
                if early_stop:
                    break

        timer.step()
        timing["time/ppo/optimize_step"] = timer.last_duration

        timer.step()
        train_stats = stack_dicts(all_stats)

        # reshape advantages/ratios such that they are not averaged.
        train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
        train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
        train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)

        stats = self.record_step_stats(
            scores=scores,
            logprobs=all_logprobs,
            ref_logprobs=ref_logprobs,
            non_score_reward=non_score_reward,
            train_stats=train_stats,
            kl_coef=self.kl_ctl.value,
            masks=masks,
            queries=queries,
            responses=responses,
            kls=kls,
        )
        # Gather/Reduce stats from all processes
        if self.is_distributed:
            stats = self.gather_stats(stats)
        stats = stats_to_np(stats)
        timer.step()
        timing["time/ppo/calc_stats"] = timer.last_duration
        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]

        # Update the KL control - multiply the batch_size by the number of processes
        self.kl_ctl.update(
            stats["objective/kl"],
            self.config.batch_size * self.accelerator.num_processes,
        )

        # Log the total ppo time
        timing["time/ppo/total"] = timer.total_time()
        stats.update(timing)

        # post-process stats for tensorboard and other loggers
        if self.config.log_with != "wandb":
            stats = convert_to_scalar(stats)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return stats

    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
        """
        Copied from PPOTrainer.prepare_model_inputs: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L949
        The only differences are:
        - add padding to model inputs for static shape support in forward
        """
        if self.is_encoder_decoder:
            input_data = self.data_collator(
                [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries]
            ).to(self.current_device)

            decoder_inputs = self.data_collator(
                [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses]
            ).to(self.current_device)

            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
        else:
            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
            input_data = self.data_collator(
                [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids]
            ).to(self.current_device)

        if self.config.pad_for_acceleration:
            input_data["input_ids"] = torch.nn.functional.pad(
                input_data["input_ids"],
                (0, self.config.pad_max_len - input_data["input_ids"].shape[1]),
                value=self.tokenizer.pad_token_id,
            )
            input_data["attention_mask"] = torch.nn.functional.pad(
                input_data["attention_mask"],
                (
                    0,
                    self.config.pad_max_len - input_data["attention_mask"].shape[1],
                ),
                value=0,
            )
            if self.is_encoder_decoder:
                input_data["decoder_input_ids"] = torch.nn.functional.pad(
                    input_data["decoder_input_ids"],
                    (
                        0,
                        self.config.pad_max_len - input_data["decoder_input_ids"].shape[1],
                    ),
                    value=self.tokenizer.pad_token_id,
                )
                input_data["decoder_attention_mask"] = torch.nn.functional.pad(
                    input_data["decoder_attention_mask"],
                    (
                        0,
                        self.config.pad_max_len - input_data["decoder_attention_mask"].shape[1],
                    ),
                    value=0,
                )

        input_data.pop("labels", None)  # we don't want to compute LM losses
        return input_data

    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: PreTrainedModelWrapper,
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: bool = False,
        response_masks: Optional[torch.Tensor] = None,
    ):
        """
        Copied from PPOTrainer.batched_forward_pass: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L971
        The only differences are:
        - input_kwargs/outputs need to be clone()d to avoid being overridden on HPU
        """
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        model.eval()

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs].clone() for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            logits, _, values = model(**input_kwargs)

            if self.is_encoder_decoder:
                input_ids = input_kwargs["decoder_input_ids"]
                attention_mask = input_kwargs["decoder_attention_mask"]
            else:
                input_ids = input_kwargs["input_ids"]
                attention_mask = input_kwargs["attention_mask"]

            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                if self.is_encoder_decoder:
                    # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                    start = 1
                    end = attention_mask[j, :].sum() - 1
                else:
                    start = len(query_batch[j]) - 1  # logprobs starts from the second query token
                    if attention_mask[j, 0] == 0:  # offset left padding
                        start += attention_mask[j, :].nonzero()[0]
                    end = start + len(response_batch[j])
                    if response_masks is not None:
                        response_masks_batch[j] = torch.cat(
                            (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                        )[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits.clone())
            else:
                del logits
            all_values.append(values.clone())
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    @PPODecorators.empty_device_cache()
    def train_minibatch(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    ):
        """
        Copied from PPOTrainer.train_minibatch: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L1058
        The only differences are:
        - add htcore.mark_step
        """
        self.model.train()
        loss_p, loss_v, train_stats = self.loss(
            old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
        )
        loss = loss_p + loss_v
        global _recorded_graph

        if _recorded_graph is None:
            # Capture the backward pass into an HPU graph once, then replay it on subsequent minibatches.
            _recorded_graph = ht.hpu.HPUGraph()
            s = ht.hpu.default_stream()

            with ht.hpu.stream(s):
                _recorded_graph.capture_begin()
                self.accelerator.backward(loss)
                _recorded_graph.capture_end()
        else:
            _recorded_graph.replay()

        if self.config.max_grad_norm is not None:
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm)
        self.optimizer.step()
        if self.config.use_habana:
            self.htcore.mark_step()
        # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
        # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
        self.optimizer.zero_grad()
        return train_stats

    def wrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        # Swap `forward` to an HPU-graph-wrapped version, caching it in `hpu_graph_fw`.
        model = self.accelerator.unwrap_model(model)
        if hasattr(model, "hpu_graph_fw"):
            model.forward = model.hpu_graph_fw
        else:
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

            model.orig_fw = model.forward
            model = wrap_in_hpu_graph(model)
            model.hpu_graph_fw = model.forward

    def unwrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        # Restore the original `forward` saved in `orig_fw`.
        model = self.accelerator.unwrap_model(model)
        if hasattr(model, "orig_fw"):
            model.forward = model.orig_fw

    def wrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        # Wrap the underlying pretrained model (or its PEFT base model) in an HPU graph for generation.
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        model = self.accelerator.unwrap_model(model)
        if getattr(model, "is_peft_model", False):
            if hasattr(model.pretrained_model.base_model.model, "hpu_graph_fw"):
                model.pretrained_model.base_model.model.forward = model.pretrained_model.base_model.model.hpu_graph_fw
            else:
                model.pretrained_model.base_model.model.orig_fw = model.pretrained_model.base_model.model.forward
                model.pretrained_model.base_model.model = wrap_in_hpu_graph(model.pretrained_model.base_model.model)
                model.pretrained_model.base_model.model.hpu_graph_fw = model.pretrained_model.base_model.model.forward
        else:
            if hasattr(model.pretrained_model, "hpu_graph_fw"):
                model.pretrained_model.forward = model.pretrained_model.hpu_graph_fw
            else:
                model.pretrained_model.orig_fw = model.pretrained_model.forward
                model.pretrained_model = wrap_in_hpu_graph(model.pretrained_model)
                model.pretrained_model.hpu_graph_fw = model.pretrained_model.forward

    def unwrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        # Restore the original `forward` of the wrapped pretrained model (or its PEFT base model).
        model = self.accelerator.unwrap_model(model)
        if getattr(model, "is_peft_model", False):
            if hasattr(model.pretrained_model.base_model.model, "orig_fw"):
                model.pretrained_model.base_model.model.forward = model.pretrained_model.base_model.model.orig_fw
        else:
            if hasattr(model.pretrained_model, "orig_fw"):
                model.pretrained_model.forward = model.pretrained_model.orig_fw
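

if __name__ == "__main__":
    # Illustrative usage sketch only: a minimal generate/step loop, assuming a Gaudi (HPU)
    # device, the `gpt2` checkpoint, and a placeholder length-based reward in place of a
    # real reward model. Config flags mirror the Gaudi-specific options used above
    # (use_habana, pad_for_acceleration, pad_max_input_len, pad_max_len).
    from transformers import AutoTokenizer
    from trl import AutoModelForCausalLMWithValueHead

    demo_config = GaudiPPOConfig(
        model_name="gpt2",  # hypothetical checkpoint for the sketch
        batch_size=2,
        mini_batch_size=1,
        use_habana=True,
        pad_for_acceleration=True,  # static shapes for HPU graphs
        pad_max_input_len=64,
        pad_max_len=128,
    )
    demo_tokenizer = AutoTokenizer.from_pretrained(demo_config.model_name)
    demo_tokenizer.pad_token = demo_tokenizer.eos_token
    demo_model = AutoModelForCausalLMWithValueHead.from_pretrained(demo_config.model_name)

    # Tiny in-memory dataset of tokenized prompts.
    prompts = ["Habana Gaudi accelerates", "PPO fine-tuning works by"]
    demo_dataset = Dataset.from_dict(
        {"input_ids": [demo_tokenizer(p, return_tensors="pt")["input_ids"][0] for p in prompts]}
    )
    demo_dataset.set_format(type="torch")

    # Queries are ragged, so collate into plain lists instead of stacked tensors.
    def list_collator(batch):
        return {key: [sample[key] for sample in batch] for key in batch[0]}

    trainer = GaudiPPOTrainer(
        config=demo_config,
        model=demo_model,
        tokenizer=demo_tokenizer,
        dataset=demo_dataset,
        data_collator=list_collator,
    )

    for batch in trainer.dataloader:
        queries = list(batch["input_ids"])
        responses = trainer.generate(
            queries,
            return_prompt=False,
            max_new_tokens=16,
            pad_token_id=demo_tokenizer.eos_token_id,
        )
        # Placeholder reward: longer responses score higher (stand-in for a reward model).
        rewards = [torch.tensor(float(len(r))) for r in responses]
        stats = trainer.step(queries, responses, rewards)
        print({k: stats[k] for k in ("ppo/learning_rate", "time/ppo/total") if k in stats})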