def forward()

in chatlearn/models/megatron/ops/policy_gradient.py [0:0]


    def forward(ctx, vocab_parallel_logits, cliprange, action_ids, old_logprobs, advantages, loss_mask, stats):
        """Compute the PPO clipped policy-gradient loss from vocab-parallel logits.

        Each tensor-parallel (TP) rank holds one shard of the vocabulary dimension
        of the logits. The log-probability of each taken action is assembled across
        the TP group with all-reduces. It is then combined with ``old_logprobs`` and
        ``advantages`` into the clipped PPO objective.

        Args:
            ctx: autograd context; tensors needed by ``backward`` are saved on it.
            vocab_parallel_logits: this rank's logits shard. The last dim is the
                partition vocab size (comments below assume ``[b, s, vp]``).
            cliprange: clipping epsilon; the ratio is clamped to
                ``[1 - cliprange, 1 + cliprange]``.
            action_ids: global token ids of the taken actions (assumed ``[b, s]``).
            old_logprobs: previous log-probabilities of ``action_ids``, same shape
                as ``action_ids``.
            advantages: per-token advantages, same shape as ``action_ids``.
            loss_mask: per-token mask. It multiplies the log-ratio and the loss, and
                weights the logged statistics.
            stats: dict updated in place with ``"policy/approx_kl"`` and
                ``"policy/pg_clipfrac"`` (Python floats, both masked means over
                valid tokens).

        Returns:
            Per-token clipped policy-gradient loss multiplied by ``loss_mask``.
        """
        # Work on private copies: the logits are modified in place below, and the
        # other inputs are saved for backward (cloning avoids view errors).
        vocab_parallel_logits = vocab_parallel_logits.clone()
        action_ids = action_ids.clone()
        old_logprobs = old_logprobs.clone()
        advantages = advantages.clone()
        loss_mask = loss_mask.clone()

        numerical_stable = get_args().numerical_stable
        tp_group = mpu.get_tensor_model_parallel_group()

        # Global max of the logits along the vocab dim across all TP ranks.
        logits_shift = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_shift, op=torch.distributed.ReduceOp.MAX, group=tp_group)
        if numerical_stable:
            # Shift by the midpoint of the global [min, max] logit range instead of the max.
            logits_min = torch.min(vocab_parallel_logits, dim=-1)[0]
            torch.distributed.all_reduce(logits_min, op=torch.distributed.ReduceOp.MIN, group=tp_group)
            logits_shift = (logits_min + logits_shift) / 2
        vocab_parallel_logits.sub_(logits_shift.unsqueeze(dim=-1))

        # Vocab id range [vocab_start_index, vocab_end_index) owned by this rank.
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size,
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_tensor_model_parallel_world_size())

        # True where the action id lies outside this rank's shard.
        target_mask = (action_ids < vocab_start_index) | (action_ids >= vocab_end_index)
        # Shard-relative ids. Out-of-shard entries point at local index 0 as a
        # placeholder; their gathered logits are zeroed right after the gather.
        masked_actionids = action_ids - vocab_start_index
        masked_actionids[target_mask] = 0

        # Gather logits[target] on a flattened [n, vp] view of the logits.
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_actionids_1d = masked_actionids.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
        # Advanced indexing returns a fresh tensor, so in-place edits below do not touch the logits.
        action_logits = logits_2d[arange_1d, masked_actionids_1d].view_as(action_ids)
        action_logits[target_mask] = 0.0
        # Out-of-shard entries are zero, so summing over ranks recovers each action's logit.
        torch.distributed.all_reduce(action_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        # Softmax denominator: sum of exp(logits) over the full vocab.
        # exp is applied in place, so exp_logits aliases vocab_parallel_logits.
        exp_logits = vocab_parallel_logits.exp_()
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group)

        # Log-softmax at the taken action: logit - log(sum(exp(logits))).
        action_logprob = action_logits - torch.log(sum_exp_logits + 1e-10)
        assert not torch.isnan(action_logprob).any(), f"action_logprob {action_logprob}"

        if numerical_stable:
            # Clamp the log-ratio and use a second-order Taylor expansion of exp().
            logprob_diff = torch.clamp(action_logprob - old_logprobs, min=-1e5, max=1e5)
            log_ratio = logprob_diff * loss_mask
            ratio = 1 + log_ratio + torch.square(log_ratio) / 2
        else:
            log_ratio = (action_logprob - old_logprobs) * loss_mask
            ratio = torch.exp(log_ratio)
        assert not torch.isnan(ratio).any(), f"ratio {ratio} old_logprobs {old_logprobs}"

        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
        loss = torch.max(pg_loss1, pg_loss2) * loss_mask  # [b, s]

        with torch.no_grad():
            # Statistics are masked means over valid tokens. Averaging over padded
            # positions would bias them toward zero. The clamp guards an all-masked batch.
            num_valid = loss_mask.float().sum().clamp(min=1.0)
            # Unbiased KL-div estimate (`k3`): (r - 1) - log r.
            # Ref: http://joschu.net/blog/kl-approx.html
            approx_kl = torch.sum(((ratio - 1) - log_ratio) * loss_mask) / num_valid
            # Fraction of valid tokens on which the clipped surrogate term is active.
            pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * loss_mask) / num_valid

        torch.distributed.all_reduce(approx_kl, op=torch.distributed.ReduceOp.AVG, group=tp_group)
        torch.distributed.all_reduce(pg_clipfrac, op=torch.distributed.ReduceOp.AVG, group=tp_group)
        all_average_approx_kl, all_average_pg_clipfrac = average_losses_across_data_parallel_group(
            [approx_kl, pg_clipfrac])
        stats["policy/approx_kl"] = all_average_approx_kl.item()
        stats["policy/pg_clipfrac"] = all_average_pg_clipfrac.item()

        # Turn exp_logits into softmax probabilities in place for the backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))  # [b, s, vp]
        selected_action_softmax_1d = torch.exp(action_logprob).view(-1)  # [n]

        ctx.save_for_backward(exp_logits, masked_actionids_1d, selected_action_softmax_1d, old_logprobs, advantages,
                              ratio, loss_mask, target_mask)
        ctx.cliprange = cliprange

        return loss  # [b, response size]