megatron_patch/template/helper.py
# Copyright (c) 2025 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain GPT."""
import os
import torch
import inspect
from functools import partial
from megatron.core import mpu
from megatron.training import get_args, get_timers
from megatron.training.utils import (
average_losses_across_data_parallel_group,
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
from megatron.core.models.gpt import GPTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron_patch.data.utils import (
get_batch_on_this_tp_rank_original,
get_batch_on_this_tp_rank_idxmap_sft,
get_position_id_on_this_tp_rank_idxmap_sft_packing
)


def get_batch(data_iterator):
"""Generate a batch."""
args = get_args()
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
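        # Intermediate pipeline stages consume activations handed over from the previous
        # stage rather than a token/label batch, so everything except packed_seq_params is
        # returned as None; packed_seq_params is still built so attention can respect the
        # packed-sequence boundaries when sequence packing is enabled.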
packed_seq_params = None
if args.dataset == 'MMAP' and args.train_mode == "finetune" and args.reset_position_ids:
position_ids = get_position_id_on_this_tp_rank_idxmap_sft_packing(data_iterator)
position_ids = position_ids[0] # shape: [seq_length]
start_indices = (position_ids == 0).nonzero(as_tuple=True)[0]
seqlens = start_indices[1:] - start_indices[:-1]
# NOTE: cu_seqlens: [0, A1, A1+A2, A1+A2+A3, ..., seq_len]
cu_seqlens = torch.zeros(start_indices.shape[0] + 1, device=position_ids.device, dtype=torch.int)
cu_seqlens[1:-1] = torch.cumsum(seqlens, dim=0)
cu_seqlens[-1] = position_ids.shape[0]
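            # Worked example of the construction above:
            #   position_ids  = [0, 1, 2, 0, 1, 0, 1, 2, 3]
            #   start_indices = [0, 3, 5], seqlens = [3, 2]
            #   cu_seqlens    = [0, 3, 5, 9]  (last entry is the full packed length)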
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
qkv_format='thd'
)
return None, None, None, None, None, None, packed_seq_params
if args.dataset == 'JSON-SFT':
if args.train_mode == "pretrain":
raise ValueError('The JSON-SFT dataset should only be used for finetuning!')
# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank_original(data_iterator, per_seq_average=True)
# slice batch along sequence dimension for context parallelism
num_seqs = batch.pop('num_seqs')
batch = get_batch_on_this_cp_rank(batch)
return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
num_seqs,
None
)
elif args.dataset == 'MMAP':
# get batches based on the TP rank you are on
if args.train_mode == "pretrain":
batch = get_batch_on_this_tp_rank(data_iterator)
else:
batch = get_batch_on_this_tp_rank_idxmap_sft(data_iterator, per_seq_average=True)
packed_seq_params = None
if args.reset_position_ids:
# sequence-packing, build cu_seqlens
position_ids = batch.get('position_ids', None)
if position_ids is not None:
                # micro-batch size is 1: take the single packed sequence
                position_ids = position_ids[0]  # shape: [seq_length]
start_indices = (position_ids == 0).nonzero(as_tuple=True)[0]
seqlens = start_indices[1:] - start_indices[:-1]
# NOTE: cu_seqlens: [0, A1, A1+A2, A1+A2+A3, ..., seq_len]
cu_seqlens = torch.zeros(start_indices.shape[0] + 1, device=position_ids.device, dtype=torch.int)
cu_seqlens[1:-1] = torch.cumsum(seqlens, dim=0)
cu_seqlens[-1] = position_ids.shape[0]
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
qkv_format='thd'
)
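        # Packed (THD) sequences and context parallelism would both re-partition the
        # sequence dimension, while the cu_seqlens built above describe the full,
        # unsliced sequence, so the two features are rejected in combination here.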
if packed_seq_params is not None and args.context_parallel_size > 1:
            raise ValueError('Sequence packing is not supported when CP > 1!')
# slice batch along sequence dimension for context parallelism
num_seqs = batch.pop('num_seqs', None)
batch = get_batch_on_this_cp_rank(batch)
return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
num_seqs,
packed_seq_params
)
    else:
        raise ValueError('Please set --dataset to a supported value (JSON-SFT or MMAP)!')


def loss_func(loss_mask: torch.Tensor, num_seqs: torch.Tensor, output_tensor: torch.Tensor):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        num_seqs (torch.Tensor): Number of sequences contributing to the loss; None when the loss is averaged over tokens
        output_tensor (torch.Tensor): The tensor with the losses
    """
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
# NOTE: for each seq, sum(loss_mask) == 1 if num_seqs is not None,
# otherwise sum(loss_mask) == n_tokens
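    # e.g. with per-sequence averaging, a sequence whose answer spans 4 tokens carries a
    # weight of 0.25 on each of those tokens, so every sequence contributes exactly 1 to
    # loss_mask.sum() and loss[1] counts sequences rather than tokens.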
loss = torch.stack([torch.sum(losses.view(-1) * loss_mask), loss_mask.sum()])
if args.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss.isnan().any(), (
f"Rank {global_rank}: found NaN in local forward loss calculation. "
f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
)
averaged_loss = average_losses_across_data_parallel_group(loss)
averaged_loss = averaged_loss[0] / averaged_loss[1]
    # NOTE: The grad will be scaled down by CP size later, so this multiplication factor must not be removed.
    # LINK: https://github.com/NVIDIA/Megatron-LM/issues/906
    # The issue has been fixed upstream since 0926.
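    # Two return conventions: with token-level averaging, return the already-normalized
    # scalar loss; with per-sequence averaging, return the summed loss together with the
    # local sequence count so the training loop can normalize across micro-batches.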
if num_seqs is None:
# average on token-level
return loss[0] / loss[1] * args.context_parallel_size, {"lm loss": averaged_loss}
return loss[0] * args.context_parallel_size, num_seqs.sum(), {"lm loss": averaged_loss}


def forward_step(data_iterator, model):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
timers = get_timers()
args = get_args()
# Get the batch.
timers("batch-generator", log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids, num_seqs, packed_seq_params = get_batch(data_iterator)
timers("batch-generator").stop()
if 'loss_mask' in inspect.signature(GPTModel.forward).parameters:
# NOTE: MTP-head (since 0328) requires loss_mask to compute correct loss scale.
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params, loss_mask=loss_mask)
else:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)
return output_tensor, partial(loss_func, loss_mask, num_seqs)
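

# Example wiring (a sketch; assumes `train_valid_test_datasets_provider` and
# `model_provider` are defined by the training entrypoint script that imports this module):
#
#   from megatron.training import pretrain
#   from megatron.core.enums import ModelType
#
#   pretrain(
#       train_valid_test_datasets_provider,
#       model_provider,
#       ModelType.encoder_or_decoder,
#       forward_step,
#   )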