# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import typing
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Union

import habana_frameworks.torch as ht
import numpy as np
import torch
from accelerate.utils import ProjectConfiguration
from datasets import Dataset
from torch.optim import Adam
from transformers import (
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from trl import PPOTrainer
from trl.core import (
    WANDB_PADDING,
    PPODecorators,
    convert_to_scalar,
    logprobs_from_logits,
    stack_dicts,
    stats_to_np,
)
from trl.import_utils import is_torch_greater_2_0
from trl.models import (
    SUPPORTED_ARCHITECTURES,
    PreTrainedModelWrapper,
    create_reference_model,
    unwrap_model_for_generation,
)
from trl.trainer import (
    AdaptiveKLController,
    BaseTrainer,
    FixedKLController,
    RunningMoments,
)

from ...utils import HabanaGenerationTime, set_seed
from . import GaudiPPOConfig


# Module-level placeholder initialised to None. It is not read or written in the visible part
# of this file. NOTE(review): presumably meant to cache a recorded HPU graph — confirm usage.
_recorded_graph = None


class GaudiPPOTrainer(PPOTrainer):
    def __init__(
        self,
        config: Optional[GaudiPPOConfig] = None,
        model: Optional[PreTrainedModelWrapper] = None,
        ref_model: Optional[PreTrainedModelWrapper] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        data_collator: Optional[typing.Callable] = None,
        num_shared_layers: Optional[int] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        training_data_collator: Optional[typing.Callable] = None,
    ):
        """
        Build the trainer: validate arguments, create the accelerator, and prepare model,
        reference model, optimizer, data collator, dataloader and scheduler.

        Copied from PPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L148
        The only differences are:
        - add new args for Gaudi in config
        - use GaudiAccelerator instead of Accelerator
        - arguments are validated before `config` is first dereferenced, so an invalid or
          missing `config` raises the intended `ValueError` instead of an `AttributeError`

        Args:
            config: Must be a `GaudiPPOConfig` instance (despite the `None` default).
            model: Policy model; must be an instance of one of `SUPPORTED_ARCHITECTURES`.
            ref_model: Reference model. When `None` and the policy is not a PEFT model, one is
                created with `create_reference_model`; for PEFT models it is left as `None`.
            tokenizer: Must be a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`.
            dataset: Optional dataset; when given, a dataloader is built with
                `prepare_dataloader(dataset, data_collator)`.
            optimizer: Optional optimizer; defaults to `Adam` over the trainable parameters
                with `config.learning_rate`.
            data_collator: Collator forwarded to `prepare_dataloader`.
            num_shared_layers: Forwarded to `create_reference_model`; ignored (with a warning)
                when `ref_model` is provided.
            lr_scheduler: Optional learning-rate scheduler, type-checked against the torch
                scheduler base class.
            training_data_collator: Collator stored as `self.data_collator`; defaults to
                `DataCollatorForLanguageModeling(tokenizer, mlm=False)`.

        Raises:
            ValueError: If `config`, `tokenizer`, `model`, `ref_model`, `dataset` or
                `lr_scheduler` has an unsupported type, or if
                `config.push_to_hub_if_best_kwargs` is set without a `repo_id`.
        """
        BaseTrainer.__init__(self, config)

        # Step 0: check positional arguments validity
        # This must happen before anything reads attributes of `config` (e.g. `config.seed`).
        if not isinstance(config, GaudiPPOConfig):
            raise ValueError(f"config must be a PPOConfig, got {type(config)}")
        if not isinstance(tokenizer, (PreTrainedTokenizerBase)):
            raise ValueError(
                f"tokenizer must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(tokenizer)}"
            )
        if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
            raise ValueError(
                f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
            )

        # initial seed for reproducible experiments
        set_seed(config.seed)

        # Step 1: Initialize Accelerator
        if config.use_habana:
            from ...accelerate import GaudiAccelerator as Accelerator
        else:
            from accelerate import Accelerator
        self.accelerator = Accelerator(
            log_with=config.log_with,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            project_config=ProjectConfiguration(**config.project_kwargs),
            **config.accelerator_kwargs,
        )

        # Step 1.1 Runtime variables filled by the accelerator
        config.world_size = self.accelerator.num_processes
        config.global_backward_batch_size = config.backward_batch_size * config.world_size
        config.global_batch_size = config.batch_size * config.world_size

        self.model = model.to(self.accelerator.device.type)
        # NOTE: `filter` returns a one-shot iterator; it is exhausted after a single pass.
        self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
        self.is_peft_model = getattr(self.model, "is_peft_model", False)
        config.is_encoder_decoder = self.is_encoder_decoder
        config.is_peft_model = self.is_peft_model

        # Tensorboard receives the flat config dict; other trackers get it nested under one key.
        is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
        self.accelerator.init_trackers(
            config.tracker_project_name,
            config=({"trl_ppo_trainer_config": config.to_dict()} if not is_using_tensorboard else config.to_dict()),
            init_kwargs=config.tracker_kwargs,
        )
        self.is_using_text_environment = getattr(config, "use_text_environment", False)

        # Step 2: Resolve the reference model
        if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
            self.ref_model = ref_model.to(self.accelerator.device.type)
            if num_shared_layers is not None:
                warnings.warn(
                    "num_shared_layers is ignored when ref_model is provided. Two different models are used for the "
                    "model and the reference model and no layers are shared.",
                    UserWarning,
                )
        elif ref_model is None and not self.is_peft_model:
            self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers)
        elif self.is_peft_model:
            self.ref_model = None
        else:
            raise ValueError(
                f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported "
                f"architectures are: {SUPPORTED_ARCHITECTURES} "
            )
        # For PEFT models the reference pass reuses the policy model inside this context
        # (adapter disabled); otherwise the context is a no-op.
        self.optional_peft_ctx = (
            self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter
            if self.is_peft_model
            else nullcontext
        )

        if not (isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast)):
            raise ValueError(
                "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast"
            )
        self.tokenizer = tokenizer

        if dataset is not None and not (isinstance(dataset, torch.utils.data.Dataset) or isinstance(dataset, Dataset)):
            raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset")
        elif dataset is None:
            warnings.warn(
                "No dataset is provided. Make sure to set config.batch_size to the correct value before training.",
                UserWarning,
            )
        self.dataset = dataset
        self._signature_columns = None
        if self.dataset is not None:
            self.dataloader = self.prepare_dataloader(self.dataset, data_collator)
        elif self.dataset is None and self.accelerator.num_processes > 1:
            warnings.warn(
                "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should"
                " prepare your dataloader yourself with `dataloader = ppo_trainer.accelerator.prepare(dataloader)`"
                " and using `torch.utils.data.DataLoader`, or pass a dataset to the `PPOTrainer`. Please "
                " refer to the documentation for more details.",
                UserWarning,
            )
            self.dataloader = None
        else:
            self.dataloader = None

        # Step 3: Initialize optimizer and data collator
        if training_data_collator is None:
            self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        else:
            self.data_collator = training_data_collator
        if optimizer is None:
            self.optimizer = Adam(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate,
            )
        else:
            self.optimizer = optimizer

        self.lr_scheduler = lr_scheduler
        if self.lr_scheduler is not None:
            # The scheduler base class was renamed in torch 2.0.
            lr_scheduler_class = (
                torch.optim.lr_scheduler._LRScheduler
                if not is_torch_greater_2_0()
                else torch.optim.lr_scheduler.LRScheduler
            )

            if not isinstance(self.lr_scheduler, lr_scheduler_class):
                raise ValueError(
                    "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)"
                )

        if self.config.adap_kl_ctrl:
            self.kl_ctl = AdaptiveKLController(self.config.init_kl_coef, self.config.target, self.config.horizon)
        else:
            self.kl_ctl = FixedKLController(self.config.init_kl_coef)

        # Multi-HPU runs get explicit DDP kwargs installed on the accelerator before `prepare`.
        if self.accelerator.distributed_type == "MULTI_HPU":
            from accelerate.utils import DistributedDataParallelKwargs

            kwargs = {}
            kwargs["find_unused_parameters"] = True
            kwargs["gradient_as_bucket_view"] = True
            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)

        # Safety checkers for DS integration
        is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
            self.accelerator.state, "deepspeed_plugin"
        )

        if config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()

            if hasattr(self.model, "enable_input_require_grads"):
                self.model.enable_input_require_grads()
            else:
                # For backward compatibility with older versions of transformers
                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        (
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.data_collator,
            self.dataloader,
            self.lr_scheduler,
        )
        if is_deepspeed_used:
            # Quantized models are already set on the correct device
            if not self.is_peft_model and not (
                getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False)
                or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False)
            ):
                self.ref_model = self._prepare_deepspeed(self.ref_model)
        else:
            self.ref_model = self.accelerator.prepare(self.ref_model)

        # In a distributed setup, only logging needs to be performed on the main process
        # check: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
        # or: https://discuss.pytorch.org/t/use-distributed-data-parallel-correctly/82500/11
        self.is_distributed = self.accelerator.num_processes > 1

        # init the current step
        self.current_step = 0

        # init variables for pushing model to hub
        if config.push_to_hub_if_best_kwargs:
            if "repo_id" not in config.push_to_hub_if_best_kwargs:
                raise ValueError("You have to specify repo_id in order to push the model to the hub!")
            self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
            self.compare_step = 0
            self.highest_reward = torch.tensor(-float("inf"))

        # post process for PP
        if not getattr(self.model, "is_sequential_parallel", False):
            self.current_device = self.accelerator.device
        else:
            if self.accelerator.device.type == "hpu":
                self.current_device = torch.device("hpu")
            else:
                self.current_device = torch.device("cpu")

        PPODecorators.optimize_device_cache = self.config.optimize_device_cache

        self.running = RunningMoments(self.accelerator)
        if config.use_habana:
            # Imported lazily so the module is only required when running on Habana devices.
            import habana_frameworks.torch.core as htcore

            self.htcore = htcore

    def generate(
        self,
        query_tensor: Union[torch.Tensor, List[torch.Tensor]],
        length_sampler: Callable = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        generate_ref_response: bool = False,
        **generation_kwargs,
    ):
        """
        Generate a response for one query tensor or for a list of query tensors.

        Copied from PPOTrainer.generate: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L455
        The only differences are:
        - add hpu graph for acceleration

        Args:
            query_tensor: Either a 1-D tensor of shape (`seq_len`) or a list of such tensors.
                A list is dispatched to `_generate_batched`.
            length_sampler: Optional callable; when given, its result is used as
                `max_new_tokens`.
            batch_size: Batch size forwarded to `_generate_batched` (list input only).
            return_prompt: When False (and the model is not encoder-decoder), the prompt
                tokens are stripped from the returned sequences.
            generate_ref_response: When True, a second generation is run with the reference
                model (or with the policy model itself for PEFT models).
            **generation_kwargs: Forwarded to the underlying `generate` call.

        Returns:
            The response, or a `(response, ref_response)` tuple when
            `generate_ref_response` is True.

        Raises:
            ValueError: If a non-list `query_tensor` is 2-dimensional.
        """

        def _enable_hpu_graphs(target):
            # Only wrap for HPU graph mode when running on Habana hardware.
            if self.config.use_habana:
                self.wrap_generation_for_hpu_graph_mode(target)

        ref_model = None
        if generate_ref_response:
            ref_model = self.ref_model if not self.is_peft_model else self.model

        # ---- list input: delegate to the batched implementation -------------------------
        if isinstance(query_tensor, List):
            batched_kwargs = {
                "length_sampler": length_sampler,
                "batch_size": batch_size,
                "return_prompt": return_prompt,
            }

            _enable_hpu_graphs(self.model)
            response = self._generate_batched(self.model, query_tensor, **batched_kwargs, **generation_kwargs)
            if not generate_ref_response:
                return response

            _enable_hpu_graphs(ref_model)
            ref_response = self._generate_batched(ref_model, query_tensor, **batched_kwargs, **generation_kwargs)
            return response, ref_response

        # ---- single tensor input --------------------------------------------------------
        if len(query_tensor.shape) == 2:
            raise ValueError(
                "query_tensor must be a tensor of shape (`seq_len`) or a list of tensors of shape (`seq_len`)"
            )

        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        _enable_hpu_graphs(self.model)
        with unwrap_model_for_generation(self.model, self.accelerator) as policy:
            response = policy.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)

        ref_response = None
        if generate_ref_response:
            _enable_hpu_graphs(ref_model)
            with unwrap_model_for_generation(
                ref_model, self.accelerator, is_peft_model=self.is_peft_model
            ) as reference:
                ref_response = reference.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)

        # Decoder-only models echo the prompt; drop it when the caller asked for that.
        if not (return_prompt or self.is_encoder_decoder):
            prompt_length = query_tensor.shape[0]
            response = response[:, prompt_length:]
            if generate_ref_response:
                ref_response = ref_response[:, prompt_length:]

        if generate_ref_response:
            return response, ref_response
        return response

    def _generate_batched(
        self,
        model: PreTrainedModelWrapper,
        query_tensors: List[torch.Tensor],
        length_sampler: Optional[Callable] = None,
        batch_size: int = 4,
        return_prompt: bool = True,
        pad_to_multiple_of: Optional[int] = None,
        remove_padding: bool = True,
        **generation_kwargs,
    ):
        """
        Copied from PPOTrainer._generate_batched: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L535
        The only differences are:
        - pad to pad_max_input_len to get static shape for generation acceleration
        - use lazy mode and hpu_graphs for generation in hpu
        """
        outputs = []

        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            self.tokenizer.padding_side = "left"

        # in case we have fewer examples than bs
        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()

            # prevent overflow if query tensors are not even multiple of bs
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}

            if self.config.pad_for_acceleration and self.config.pad_max_input_len > 0:
                padded_inputs = self.tokenizer.pad(
                    inputs,
                    padding="max_length",
                    max_length=self.config.pad_max_input_len,
                    pad_to_multiple_of=pad_to_multiple_of,
                    return_tensors="pt",
                ).to(self.current_device)
            else:
                padded_inputs = self.tokenizer.pad(
                    inputs,
                    padding=True,
                    max_length=None,
                    pad_to_multiple_of=pad_to_multiple_of,
                    return_tensors="pt",
                ).to(self.current_device)

            if self.config.use_habana:
                generation_kwargs["ignore_eos"] = False
                generation_kwargs["lazy_mode"] = True
                generation_kwargs["hpu_graphs"] = True

            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)

            for generation, mask in zip(generations, padded_inputs["attention_mask"]):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum() :]  # remove padding
                else:
                    output = generation

                if not return_prompt and not self.is_encoder_decoder:
                    output = output[(mask).sum() :]  # remove prompt

                if remove_padding and self.tokenizer.eos_token_id in output:
                    pad_mask = output == self.tokenizer.eos_token_id
                    pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
                    output = output[: pad_start + 1]  # keep the eos token at the end

                outputs.append(output)

        self.tokenizer.padding_side = padding_side_default
        return outputs

    @PPODecorators.empty_device_cache()
    def step(
        self,
        queries: List[torch.LongTensor],
        responses: List[torch.LongTensor],
        scores: List[torch.FloatTensor],
        response_masks: Optional[List[torch.LongTensor]] = None,
    ):
        """
        Run one PPO optimisation step on a batch of queries, responses and scores.

        Copied from PPOTrainer.step: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L647
        The only differences are:
        - use hpu_graphs for sampling and training
        - remove duplicated padding if padding is done in prepare_model_inputs

        The step proceeds in phases: validate inputs, optionally scale/clip the scores,
        optionally push the best model to the hub, run forward passes of the policy and the
        reference model without gradients, compute rewards and advantages, then iterate
        `config.ppo_epochs` times over shuffled mini-batches calling `train_minibatch`.
        Wall-clock timings of each phase are added to the returned stats.

        Args:
            queries: List of query tensors; validated by `_step_safety_checker` against
                `config.batch_size`.
            responses: List of response tensors, one per query.
            scores: List of score tensors, one per query.
            response_masks: Optional list of masks, forwarded to the policy forward pass.

        Returns:
            A dict of training statistics and timings. Values are converted to scalars
            unless `config.log_with` is "wandb".
        """
        bs = self.config.batch_size

        queries, responses, scores, response_masks = self._step_safety_checker(
            bs, queries, responses, scores, response_masks
        )
        scores = torch.tensor(scores, device=self.current_device)
        if self.config.use_score_scaling:
            # Score scaling
            scores_mean, scores_std = self.running.update(scores)
            tensor_to_kwargs = {"dtype": scores.dtype, "device": scores.device}
            # eps is added so the division below never divides by zero
            score_scaling_factor = self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps
            if self.config.use_score_norm:
                scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor
            else:
                scores /= score_scaling_factor

        if self.config.score_clip is not None:
            # Score clipping
            scores_dtype = scores.dtype
            scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)

        # if we want to push best model to the hub
        # (`highest_reward` only exists when `push_to_hub_if_best_kwargs` was configured in __init__)
        if hasattr(self, "highest_reward"):
            if self.compare_step % self.config.compare_steps == 0:
                curr_mean_reward = scores.mean()
                # if the best reward ever seen
                if curr_mean_reward > self.highest_reward:
                    self.highest_reward = curr_mean_reward
                    # push model to hub
                    self.push_to_hub(**self.push_to_hub_kwargs)
            self.compare_step += 1

        # `timing` collects per-phase durations; `timer.step()` marks phase boundaries.
        timing = {}
        timer = HabanaGenerationTime()
        timer.start()

        model_inputs = self.prepare_model_inputs(queries, responses)

        # With `pad_for_acceleration`, prepare_model_inputs already padded to a fixed length,
        # so cross-process padding is only needed when that option is off.
        if self.is_distributed and not self.config.pad_for_acceleration:
            pad_first = self.tokenizer.padding_side == "left"

            model_inputs["input_ids"] = self.accelerator.pad_across_processes(
                model_inputs["input_ids"],
                dim=1,
                pad_index=self.tokenizer.pad_token_id,
                pad_first=pad_first,
            )
            model_inputs["attention_mask"] = self.accelerator.pad_across_processes(
                model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first
            )
            if self.is_encoder_decoder:
                model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_input_ids"],
                    dim=1,
                    pad_index=self.tokenizer.pad_token_id,
                    pad_first=pad_first,
                )
                model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes(
                    model_inputs["decoder_attention_mask"],
                    dim=1,
                    pad_index=0,
                    pad_first=pad_first,
                )

        # Remember the input keys now; `model_inputs` is rebound inside the training loop below.
        model_inputs_names = list(model_inputs.keys())

        # With the "full" KL penalty, rewards are computed from full logits rather than
        # from the gathered log-probs.
        full_kl_penalty = self.config.kl_penalty == "full"

        with torch.no_grad():
            if self.config.use_habana:
                # Switch the models from generation-graph wrapping to forward-graph wrapping.
                self.unwrap_generation_for_hpu_graph_mode(self.model)
                self.wrap_fw_for_hpu_graph_mode(self.model)
                if self.ref_model is not None:
                    self.unwrap_generation_for_hpu_graph_mode(self.ref_model)
                    self.wrap_fw_for_hpu_graph_mode(self.ref_model)
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )
            # For PEFT models the reference pass reuses the policy model inside
            # `optional_peft_ctx` (set to `disable_adapter` in __init__).
            with self.optional_peft_ctx():
                ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                    self.model if self.is_peft_model else self.ref_model,
                    queries,
                    responses,
                    model_inputs,
                    return_logits=full_kl_penalty,
                )
        timer.step()
        timing["time/ppo/forward_pass"] = timer.last_duration

        with torch.no_grad():
            timer.step()
            if full_kl_penalty:
                active_full_logprobs = logprobs_from_logits(logits_or_none, None, gather=False)
                ref_full_logprobs = logprobs_from_logits(ref_logits_or_none, None, gather=False)

                rewards, non_score_reward, kls = self.compute_rewards(
                    scores, active_full_logprobs, ref_full_logprobs, masks
                )
            else:
                rewards, non_score_reward, kls = self.compute_rewards(scores, all_logprobs, ref_logprobs, masks)
            timer.step()
            timing["time/ppo/compute_rewards"] = timer.last_duration

            timer.step()
            values, advantages, returns = self.compute_advantages(values, rewards, masks)
            timer.step()
            timing["time/ppo/compute_advantages"] = timer.last_duration

        # upcast to float32 to avoid dataset issues
        batch_dict = {
            "queries": queries,
            "responses": responses,
            "logprobs": all_logprobs.to(torch.float32),
            "values": values.to(torch.float32),
            "masks": masks,
            "advantages": advantages,
            "returns": returns,
        }
        batch_dict.update(model_inputs)

        timer.step()
        all_stats = []
        early_stop = False
        if self.config.use_habana:
            # Drop the inference-time forward wrapping and use a training graph wrapper instead.
            self.unwrap_fw_for_hpu_graph_mode(self.model)
            import habana_frameworks.torch as ht

            model = self.accelerator.unwrap_model(self.model)
            # The wrapped forward is cached on the model under `wrap_train_in_graph` so the
            # HPU graph wrapping happens only on the first call; later calls restore it.
            if not hasattr(model, "wrap_train_in_graph"):
                model = ht.hpu.wrap_in_hpu_graph(model)
                setattr(model, "wrap_train_in_graph", model.forward)
            else:
                model.forward = getattr(model, "wrap_train_in_graph")

        for _ in range(self.config.ppo_epochs):
            if early_stop:
                break
            # Fresh random order of the batch for every PPO epoch.
            b_inds = np.random.permutation(bs)
            for backward_batch_start in range(0, bs, self.config.backward_batch_size):
                backward_batch_end = backward_batch_start + self.config.backward_batch_size
                backward_batch_inds = b_inds[backward_batch_start:backward_batch_end]

                for mini_batch_start in range(0, self.config.backward_batch_size, self.config.mini_batch_size):
                    mini_batch_end = mini_batch_start + self.config.mini_batch_size
                    mini_batch_inds = backward_batch_inds[mini_batch_start:mini_batch_end]
                    mini_batch_dict = {
                        "logprobs": batch_dict["logprobs"][mini_batch_inds],
                        "values": batch_dict["values"][mini_batch_inds],
                        "masks": batch_dict["masks"][mini_batch_inds],
                        # hacks: the queries and responses are ragged.
                        "queries": [batch_dict["queries"][i] for i in mini_batch_inds],
                        "responses": [batch_dict["responses"][i] for i in mini_batch_inds],
                        "advantages": batch_dict["advantages"][mini_batch_inds],
                        "returns": batch_dict["returns"][mini_batch_inds],
                    }
                    for k in model_inputs_names:
                        mini_batch_dict[k] = batch_dict[k][mini_batch_inds]
                    with self.accelerator.accumulate(self.model):
                        # Rebinds `model_inputs` to the mini-batch slice of the model inputs.
                        model_inputs = {k: mini_batch_dict[k] for k in model_inputs_names}

                        logprobs, logits, vpreds, _ = self.batched_forward_pass(
                            self.model,
                            mini_batch_dict["queries"],
                            mini_batch_dict["responses"],
                            model_inputs,
                            return_logits=True,
                        )
                        train_stats = self.train_minibatch(
                            mini_batch_dict["logprobs"],
                            mini_batch_dict["values"],
                            logprobs,
                            logits,
                            vpreds,
                            mini_batch_dict["masks"],
                            mini_batch_dict["advantages"],
                            mini_batch_dict["returns"],
                        )
                        all_stats.append(train_stats)

            # typically, early stopping is done at the epoch level
            # (`train_stats` here holds the stats of the last mini-batch of this epoch)
            if self.config.early_stopping:
                policykl = train_stats["policy/policykl"]
                early_stop = self._early_stop(policykl)
                if early_stop:
                    break

        timer.step()
        timing["time/ppo/optimize_step"] = timer.last_duration

        timer.step()
        train_stats = stack_dicts(all_stats)

        # reshape advantages/ratios such that they are not averaged.
        train_stats["policy/advantages"] = torch.flatten(train_stats["policy/advantages"]).unsqueeze(0)
        train_stats["policy/advantages"] = torch.nan_to_num(train_stats["policy/advantages"], WANDB_PADDING)
        train_stats["policy/ratio"] = torch.flatten(train_stats["policy/ratio"]).unsqueeze(0)

        stats = self.record_step_stats(
            scores=scores,
            logprobs=all_logprobs,
            ref_logprobs=ref_logprobs,
            non_score_reward=non_score_reward,
            train_stats=train_stats,
            kl_coef=self.kl_ctl.value,
            masks=masks,
            queries=queries,
            responses=responses,
            kls=kls,
        )
        # Gather/Reduce stats from all processes
        if self.is_distributed:
            stats = self.gather_stats(stats)
        stats = stats_to_np(stats)
        timer.step()
        timing["time/ppo/calc_stats"] = timer.last_duration
        stats["ppo/learning_rate"] = self.optimizer.param_groups[0]["lr"]

        # Update the KL control - multiply the batch_size by the number of processes
        self.kl_ctl.update(
            stats["objective/kl"],
            self.config.batch_size * self.accelerator.num_processes,
        )

        # Log the total ppo time
        timing["time/ppo/total"] = timer.total_time()
        stats.update(timing)

        # post-process stats for tensorboard and other loggers
        if self.config.log_with != "wandb":
            stats = convert_to_scalar(stats)

        # The scheduler is advanced once per call to `step`.
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return stats

    def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor):
        """
        Copied from PPOTrainer.prepare_model_inputs: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L949
        The only differences are:
        - add padding to model inputs for static shape support in forward

        Collate the query/response token tensors into one batch for the model.

        For encoder-decoder models the queries become `input_ids` /
        `attention_mask` and the responses become `decoder_input_ids` /
        `decoder_attention_mask`. Otherwise each query is concatenated with its
        response and the result is collated as `input_ids` / `attention_mask`.

        When `self.config.pad_for_acceleration` is set, every one of those
        tensors is right-padded along dim 1 up to `self.config.pad_max_len`:
        ids with `self.tokenizer.pad_token_id`, attention masks with 0.

        Args:
            queries: per-sample query token tensors.
            responses: per-sample response token tensors, aligned with `queries`.

        Returns:
            The batch produced by `self.data_collator`, moved to
            `self.current_device`, with any `labels` entry removed.
        """

        def collate(sequences):
            # Un-padded sequences contain only real tokens, hence the all-ones mask.
            features = [{"input_ids": seq, "attention_mask": torch.ones_like(seq)} for seq in sequences]
            return self.data_collator(features).to(self.current_device)

        if self.is_encoder_decoder:
            input_data = collate(queries)
            decoder_batch = collate(responses)
            input_data["decoder_input_ids"] = decoder_batch["input_ids"]
            input_data["decoder_attention_mask"] = decoder_batch["attention_mask"]
        else:
            input_data = collate([torch.cat([q, r]) for q, r in zip(queries, responses)])

        if self.config.pad_for_acceleration:
            target_len = self.config.pad_max_len
            pad_id = self.tokenizer.pad_token_id
            # (key, fill value) for every tensor that must reach the static length.
            to_pad = [("input_ids", pad_id), ("attention_mask", 0)]
            if self.is_encoder_decoder:
                to_pad += [("decoder_input_ids", pad_id), ("decoder_attention_mask", 0)]

            for key, fill in to_pad:
                tensor = input_data[key]
                # NOTE(review): if the collated width exceeds pad_max_len this pad
                # amount is negative — confirm callers never exceed pad_max_len.
                input_data[key] = torch.nn.functional.pad(
                    tensor,
                    (0, target_len - tensor.shape[1]),
                    value=fill,
                )

        input_data.pop("labels", None)  # we don't want to compute LM losses
        return input_data

    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
        model: PreTrainedModelWrapper,
        queries: torch.Tensor,
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: bool = False,
        response_masks: Optional[torch.Tensor] = None,
    ):
        """
        Copied from PPOTrainer.batched_forward_pass: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L971
        The only differences are:
        - input_kwargs/output need to clone() to avoid being overridden in hpu

        Run `model` (switched to eval mode) over the batch in slices of
        `self.config.mini_batch_size` and gather log-probabilities, values and
        response masks.

        Args:
            model: model called as `model(**inputs)`; it must return a
                3-tuple whose first item is the logits and last item the values.
            queries: per-sample query token tensors.
            responses: per-sample response token tensors.
            model_inputs: mapping of keyword name to batched tensor, sliced
                along dim 0 for every mini batch.
            return_logits: when True the logits are kept and returned.
            response_masks: optional per-sample masks multiplied into the
                response span of the computed masks.

        Returns:
            A tuple `(logprobs, logits, values, masks)` concatenated over all
            mini batches. `logits` is `None` unless `return_logits` is set;
            `logits`, `values` and `masks` have their last position along
            dim 1 dropped.
        """
        batch_size = len(queries)
        step = self.config.mini_batch_size
        logprob_chunks, logit_chunks, mask_chunks, value_chunks = [], [], [], []

        model.eval()

        for chunk_idx in range(math.ceil(batch_size / step)):
            window = slice(chunk_idx * step, (chunk_idx + 1) * step)
            # clone() so the slices handed to the model are not views of the full batch
            input_kwargs = {name: tensor[window].clone() for name, tensor in model_inputs.items()}
            query_batch = queries[window]
            response_batch = responses[window]
            if response_masks is not None:
                response_masks_batch = response_masks[window]
            logits, _, values = model(**input_kwargs)

            if self.is_encoder_decoder:
                ids_key, mask_key = "decoder_input_ids", "decoder_attention_mask"
            else:
                ids_key, mask_key = "input_ids", "attention_mask"
            input_ids = input_kwargs[ids_key]
            attention_mask = input_kwargs[mask_key]

            # Position t of the logits is scored against token t + 1.
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            # Shift the attention mask left by one to line up with the logprobs.
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for row, query in enumerate(query_batch):
                if self.is_encoder_decoder:
                    # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                    start = 1
                    end = attention_mask[row, :].sum() - 1
                else:
                    start = len(query) - 1  # logprobs starts from the second query token
                    if attention_mask[row, 0] == 0:  # offset left padding
                        start += attention_mask[row, :].nonzero()[0]
                    end = start + len(response_batch[row])
                    if response_masks is not None:
                        # Prefix zeros for the query tokens, then drop one
                        # position to match the shifted logprobs.
                        response_masks_batch[row] = torch.cat(
                            (torch.zeros_like(query), response_masks_batch[row])
                        )[1:]

                # Keep only the [start, end) span of this row.
                masks[row, :start] = 0
                masks[row, end:] = 0
                if response_masks is not None:
                    masks[row, start:end] = masks[row, start:end] * response_masks_batch[row][start:end]

            if return_logits:
                logit_chunks.append(logits.clone())
            else:
                del logits
            value_chunks.append(values.clone())
            logprob_chunks.append(logprobs)
            mask_chunks.append(masks)

        logprobs_out = torch.cat(logprob_chunks)
        logits_out = torch.cat(logit_chunks)[:, :-1] if return_logits else None
        values_out = torch.cat(value_chunks)[:, :-1]
        masks_out = torch.cat(mask_chunks)[:, :-1]
        return (logprobs_out, logits_out, values_out, masks_out)

    @PPODecorators.empty_device_cache()
    def train_minibatch(
        self,
        old_logprobs: torch.FloatTensor,
        values: torch.FloatTensor,
        logprobs: torch.FloatTensor,
        logits: torch.FloatTensor,
        vpreds: torch.FloatTensor,
        mask: torch.LongTensor,
        advantages: torch.FloatTensor,
        returns: torch.FloatTensor,
    ):
        """
        Copied from PPOTrainer.train_minibatch: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/ppo_trainer.py#L1058
        The only differences are:
        - add htcore.mark_step

        Run one optimisation step on a mini batch.

        The loss is computed through `self.loss`. On the first call the
        backward pass is captured into the module-level `_recorded_graph`
        (an `ht.hpu.HPUGraph`); every later call replays that graph instead of
        calling `self.accelerator.backward` again.

        Args:
            old_logprobs, values, logprobs, logits, vpreds, mask, advantages,
            returns: tensors forwarded to `self.loss` (note that `self.loss`
            receives `logits, vpreds, logprobs` in that order).

        Returns:
            The training statistics returned by `self.loss`.
        """
        global _recorded_graph

        self.model.train()
        policy_loss, value_loss, train_stats = self.loss(
            old_logprobs, values, logits, vpreds, logprobs, mask, advantages, returns
        )
        loss = policy_loss + value_loss

        if _recorded_graph is not None:
            # NOTE(review): replay re-runs the graph captured on the first call
            # rather than backpropagating this `loss` object directly — confirm
            # the tensors involved are static across calls.
            _recorded_graph.replay()
        else:
            _recorded_graph = ht.hpu.HPUGraph()
            with ht.hpu.stream(ht.hpu.default_stream()):
                _recorded_graph.capture_begin()
                self.accelerator.backward(loss)
                _recorded_graph.capture_end()

        clip_norm = self.config.max_grad_norm
        if clip_norm is not None and self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(self.model_params, clip_norm)

        self.optimizer.step()
        if self.config.use_habana:
            self.htcore.mark_step()
        # we call optimizer.zero_grad() every time and let `accelerator` handle accumulation
        # see https://huggingface.co/docs/accelerate/usage_guides/gradient_accumulation#the-finished-code
        self.optimizer.zero_grad()
        return train_stats

    def wrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        """
        Switch the unwrapped model's `forward` to its HPU-graph version.

        If the model already carries an `hpu_graph_fw` attribute (set by an
        earlier call), that callable is re-installed as `forward`. Otherwise
        the current `forward` is saved as `orig_fw`, the model is passed
        through `wrap_in_hpu_graph`, and the resulting `forward` is saved as
        `hpu_graph_fw` for later calls.

        Args:
            model: model to switch; it is first unwrapped with
                `self.accelerator.unwrap_model`.
        """
        unwrapped = self.accelerator.unwrap_model(model)

        if hasattr(unwrapped, "hpu_graph_fw"):
            # Already wrapped once: reuse the stored graph forward.
            unwrapped.forward = unwrapped.hpu_graph_fw
            return

        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        # Keep the eager forward so it can be restored later.
        unwrapped.orig_fw = unwrapped.forward
        graphed = wrap_in_hpu_graph(unwrapped)
        graphed.hpu_graph_fw = graphed.forward

    def unwrap_fw_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        """
        Restore the unwrapped model's saved `orig_fw` as its `forward`.

        Does nothing when the model has no `orig_fw` attribute.

        Args:
            model: model to restore; it is first unwrapped with
                `self.accelerator.unwrap_model`.
        """
        unwrapped = self.accelerator.unwrap_model(model)
        if not hasattr(unwrapped, "orig_fw"):
            return
        unwrapped.forward = unwrapped.orig_fw

    def wrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        """
        Switch the generation module's `forward` to its HPU-graph version.

        The module is `pretrained_model.base_model.model` when the unwrapped
        model reports `is_peft_model`, and `pretrained_model` otherwise. If the
        module already carries `hpu_graph_fw`, that callable is re-installed as
        `forward`. Otherwise the current `forward` is saved as `orig_fw`, the
        module is replaced by the result of `wrap_in_hpu_graph`, and the new
        `forward` is saved as `hpu_graph_fw`.

        Args:
            model: model to switch; it is first unwrapped with
                `self.accelerator.unwrap_model`.
        """
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        unwrapped = self.accelerator.unwrap_model(model)

        # Address the module as (holder, attribute name) so it can be re-assigned below.
        if getattr(unwrapped, "is_peft_model", False):
            holder, slot = unwrapped.pretrained_model.base_model, "model"
        else:
            holder, slot = unwrapped, "pretrained_model"

        target = getattr(holder, slot)
        if hasattr(target, "hpu_graph_fw"):
            # Already wrapped once: reuse the stored graph forward.
            target.forward = target.hpu_graph_fw
            return

        # Keep the eager forward so it can be restored later.
        target.orig_fw = target.forward
        setattr(holder, slot, wrap_in_hpu_graph(target))
        # Re-read through the holder, as the stored object is what later calls will see.
        rewrapped = getattr(holder, slot)
        rewrapped.hpu_graph_fw = rewrapped.forward

    def unwrap_generation_for_hpu_graph_mode(self, model: PreTrainedModelWrapper):
        """
        Restore the generation module's saved `orig_fw` as its `forward`.

        The module is `pretrained_model.base_model.model` when the unwrapped
        model reports `is_peft_model`, and `pretrained_model` otherwise.
        Does nothing when that module has no `orig_fw` attribute.

        Args:
            model: model to restore; it is first unwrapped with
                `self.accelerator.unwrap_model`.
        """
        unwrapped = self.accelerator.unwrap_model(model)

        if getattr(unwrapped, "is_peft_model", False):
            target = unwrapped.pretrained_model.base_model.model
        else:
            target = unwrapped.pretrained_model

        if hasattr(target, "orig_fw"):
            target.forward = target.orig_fw
