chatlearn/models/megatron/ops/policy_gradient.py (127 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Policy Gradient Loss"""
import torch
from chatlearn.utils.megatron_import_helper import average_losses_across_data_parallel_group
from chatlearn.utils.megatron_import_helper import get_args
from chatlearn.utils.megatron_import_helper import mpu
from chatlearn.utils.megatron_import_helper import VocabUtility
# pylint: disable=arguments-differ,abstract-method
class PolicyGradientLoss(torch.autograd.Function):
    """
    PPO clipped policy-gradient loss computed on vocab-parallel logits.

    The logits are sharded along their last (vocab) dimension across the
    tensor-model-parallel group. ``forward`` reconstructs the log-probability
    of every taken action with all-reduces over that group, forms the
    importance ratio against ``old_logprobs`` and returns the per-token
    clipped surrogate ``max(-A * ratio, -A * clamp(ratio, 1 - eps, 1 + eps)) * mask``.
    ``backward`` is hand-written so that only this rank's softmax shard has to
    be kept for the gradient.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits, cliprange, action_ids, old_logprobs, advantages, loss_mask, stats):
        """
        Compute the per-token clipped policy-gradient loss.

        Args:
            ctx: autograd context.
            vocab_parallel_logits: this rank's vocab shard of the logits; the
                last dimension is the per-partition vocab size.
            cliprange: PPO clip epsilon.
            action_ids: global vocab ids of the taken actions, shaped like the
                logits without their last dimension.
            old_logprobs: log-probabilities of ``action_ids`` recorded earlier;
                same shape as ``action_ids``.
            advantages: per-token advantages; same shape as ``action_ids``.
            loss_mask: per-token mask; same shape as ``action_ids``.
                NOTE(review): presumably 0/1 -- confirm with callers.
            stats: dict updated in place with ``"policy/approx_kl"`` and
                ``"policy/pg_clipfrac"`` (Python floats, averaged over the
                tensor-parallel group and via
                ``average_losses_across_data_parallel_group``).

        Returns:
            Masked per-token loss with the shape of ``action_ids``.
        """
        numerical_stable = get_args().numerical_stable
        tp_group = mpu.get_tensor_model_parallel_group()

        # Work on private copies: several tensors are modified in place below
        # (and cloning avoids in-place-on-view errors).
        vocab_parallel_logits = vocab_parallel_logits.clone()
        action_ids = action_ids.clone()
        old_logprobs = old_logprobs.clone()
        advantages = advantages.clone()
        loss_mask = loss_mask.clone()

        # Shift the logits by a per-token value shared by every TP rank before
        # exponentiating: the global max, or (numerical_stable) the midpoint of
        # the global max and min.
        logits_shift = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_shift,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=tp_group)
        if numerical_stable:
            logits_min = torch.min(vocab_parallel_logits, dim=-1)[0]
            torch.distributed.all_reduce(logits_min,
                                         op=torch.distributed.ReduceOp.MIN,
                                         group=tp_group)
            logits_shift = (logits_min + logits_shift) / 2
        vocab_parallel_logits.sub_(logits_shift.unsqueeze(dim=-1))

        # Global vocab range [start, end) owned by this rank.
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_per_partition_vocab_size(
            partition_vocab_size,
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_tensor_model_parallel_world_size())

        # True where the action id lives on another rank's shard.
        target_mask = (action_ids < vocab_start_index) | (action_ids >= vocab_end_index)
        # Local column of the action; out-of-shard rows point at column 0 and
        # are zeroed out below.
        masked_actionids = action_ids - vocab_start_index
        masked_actionids[target_mask] = 0

        # Gather logits[row, action] on a flattened [rows, vocab_shard] view.
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_actionids_1d = masked_actionids.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
        action_logits = logits_2d[arange_1d, masked_actionids_1d].clone().contiguous().view_as(action_ids)
        action_logits[target_mask] = 0.0
        # Exactly one rank holds the real logit of each action; the rest add 0.
        torch.distributed.all_reduce(action_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=tp_group)

        # exp() in place and reduce the partition function across the TP group.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=tp_group)
        # log softmax of the taken action: logit - log(sum(exp(logits))).
        action_logprob = action_logits - torch.log(sum_exp_logits + 1e-10)
        assert not torch.isnan(action_logprob).any(), f"action_logprob {action_logprob}"

        if numerical_stable:
            # Clamp the exponent and use a 2nd-order Taylor expansion of exp().
            logprob_diff = torch.clamp(action_logprob - old_logprobs, min=-1e5, max=1e5)
            log_ratio = logprob_diff * loss_mask
            ratio = 1 + log_ratio + torch.square(log_ratio) / 2
        else:
            log_ratio = (action_logprob - old_logprobs) * loss_mask
            ratio = torch.exp(log_ratio)
        assert not torch.isnan(ratio).any(), f"ratio {ratio} old_logprobs {old_logprobs}"

        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(
            ratio,
            1.0 - cliprange,
            1.0 + cliprange,
        )
        loss = torch.max(pg_loss1, pg_loss2) * loss_mask  # [b, s]

        with torch.no_grad():
            # Statistics are averaged over unmasked tokens only; clamp avoids
            # division by zero for an all-masked batch.
            num_valid = loss_mask.sum().clamp(min=1)
            # Unbiased KL-div estimate (`k3`). Ref: http://joschu.net/blog/kl-approx.html
            approx_kl = torch.sum(((ratio - 1) - log_ratio) * loss_mask) / num_valid
            # Fraction of tokens where the clipped branch was selected.
            pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * loss_mask) / num_valid
        torch.distributed.all_reduce(approx_kl,
                                     op=torch.distributed.ReduceOp.AVG,
                                     group=tp_group)
        torch.distributed.all_reduce(pg_clipfrac,
                                     op=torch.distributed.ReduceOp.AVG,
                                     group=tp_group)
        all_average_approx_kl, all_average_pg_clipfrac = average_losses_across_data_parallel_group(
            [approx_kl, pg_clipfrac])
        stats["policy/approx_kl"] = all_average_approx_kl.item()
        stats["policy/pg_clipfrac"] = all_average_pg_clipfrac.item()

        # Turn exp_logits into this shard's softmax, in place, for backward.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, masked_actionids_1d, log_ratio, advantages,
                              ratio, loss_mask, target_mask)
        ctx.cliprange = cliprange
        return loss  # [b, response size]

    @staticmethod
    def backward(ctx, grad_output):
        """
        Gradient of the clipped surrogate w.r.t. this rank's logits shard.

        Uses d log p(a) / d logits = onehot(a) - softmax, with the one-hot term
        only on the rank that owns ``a``, scaled by
        ``-advantages * exp(log_ratio) * loss_mask``. Tokens for which forward's
        ``torch.max`` picked the clipped (constant) branch get zero gradient.

        ``exp(log_ratio)`` is evaluated directly instead of as
        ``p(a) * exp(-old_logprobs)``: the latter overflows for very negative
        ``old_logprobs`` and yields ``inf * 0 = NaN`` at masked tokens.

        NOTE(review): in ``numerical_stable`` mode forward's ratio is the
        Taylor expansion of exp, yet this gradient uses the exp derivative
        ``exp(log_ratio)`` (always sign-preserving). Confirm this is intended.

        Args:
            ctx: autograd context filled by ``forward``.
            grad_output: upstream gradient with the shape of forward's output.

        Returns:
            Gradient for ``vocab_parallel_logits`` followed by ``None`` for the
            six other forward inputs.
        """
        softmax, masked_actionids_1d, log_ratio, advantages, ratio, loss_mask, target_mask = ctx.saved_tensors
        cliprange = ctx.cliprange

        vocab = softmax.size(-1)
        softmax_2d = softmax.view(-1, vocab)  # [n, vp]
        rows = torch.arange(softmax_2d.size(0), device=softmax_2d.device)
        ratio_1d = ratio.reshape(-1)
        advantages_1d = advantages.reshape(-1)

        # d log p(a) / d logits = onehot(a) - softmax; a single [n, vp] buffer.
        grad_input = -softmax_2d
        on_this_shard = (~target_mask.reshape(-1)).to(grad_input.dtype)
        grad_input[rows, masked_actionids_1d] += on_this_shard

        # Chain rule through ratio and the per-token loss mask.
        coeff = -advantages_1d * torch.exp(log_ratio.reshape(-1)) * loss_mask.reshape(-1)
        grad_input.mul_(coeff.unsqueeze(-1))

        # No gradient where the clipped branch won the max():
        # ratio < 1 - eps with A < 0, or ratio > 1 + eps with A > 0.
        clipped = (((ratio_1d < 1.0 - cliprange) & (advantages_1d < 0))
                   | ((ratio_1d > 1.0 + cliprange) & (advantages_1d > 0)))
        grad_input[clipped] = 0

        grad_input = grad_input.view_as(softmax)
        grad_input.mul_(grad_output.unsqueeze(-1))
        return grad_input.contiguous(), None, None, None, None, None, None
def tensor_decomp_pg_loss(config, action_token_logits, action_ids,
                          action_loss_mask, old_logprobs, advantages, stats):
    """Compute the per-token PPO policy-gradient loss on vocab-parallel logits.

    Thin wrapper around ``PolicyGradientLoss.apply`` that first asserts every
    input agrees on its dimension-1 size.

    Args:
        config: object exposing ``cliprange`` (the PPO clip epsilon).
        action_token_logits: vocab-parallel logits for the action tokens.
        action_ids: ids of the taken actions.
        action_loss_mask: per-token loss mask.
        old_logprobs: log-probabilities of the actions recorded earlier.
        advantages: per-token advantages.
        stats: dict that the loss updates in place with policy statistics.

    Returns:
        The per-token loss produced by ``PolicyGradientLoss``.

    Raises:
        AssertionError: if the inputs disagree on ``size(1)``.
    """
    checked = (action_token_logits, action_ids, action_loss_mask, old_logprobs, advantages)
    assert len({tensor.size(1) for tensor in checked}) == 1, \
        f"{checked[0].size(1)}, {checked[1].size(1)},{checked[2].size(1)}, {checked[3].size(1)},{checked[4].size(1)}"
    cliprange = config.cliprange
    return PolicyGradientLoss.apply(
        action_token_logits,
        cliprange,
        action_ids,
        old_logprobs,
        advantages,
        action_loss_mask,
        stats,
    )  # [b, response_size]
# pylint: enable=arguments-differ,abstract-method