megatron_patch/data/utils.py:

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron.core import mpu

try:
    from megatron import get_args
except ImportError:
    from megatron.training import get_args

from megatron_patch.tokenizer import get_tokenizer


def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
                                     eod_mask_loss,
                                     create_attention_mask: bool = True):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1

    if create_attention_mask:
        attention_mask = torch.tril(
            torch.ones((att_mask_batch, seq_length, seq_length),
                       device=data.device)).view(att_mask_batch, 1, seq_length, seq_length)
    else:
        attention_mask = None

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask and attention_mask is not None:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    if attention_mask is not None:
        # Convert attention mask to binary:
        attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
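
# Illustrative sketch (added comment, not part of the original source): with EOD id 2
# and both reset flags enabled, positions restart after every EOD and the loss mask
# zeroes the EOD slots. The values below are worked out by hand for a small CPU tensor:
#
#   data = torch.tensor([[5, 9, 2, 7, 8, 2, 4, 6]])
#   _, loss_mask, pos = get_ltor_masks_and_position_ids(
#       data, eod_token=2, reset_position_ids=True,
#       reset_attention_mask=True, eod_mask_loss=True)
#   # pos       -> [[0, 1, 2, 0, 1, 2, 0, 1]]
#   # loss_mask -> 1.0 everywhere except indices 2 and 5, which become 0.0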


def get_ltor_position_ids_packed_seq(data):
    """
    Given input_seqs from the custom mmap dataset, generate position_ids by
    searching for negative tokens that mark sequence boundaries.
    """
    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    position_ids = position_ids.clone()

    # Loop through the batches:
    for b in range(micro_batch_size):
        # Find indices where the EOD token is.
        eod_index = position_ids[b, data[b] < 0]
        # Detach indices from positions if going to modify positions.
        eod_index = eod_index.clone()

        # Loop through EOD indices:
        prev_index = 0
        for j in range(eod_index.size()[0]):
            i = eod_index[j]
            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
            prev_index = i + 1

    return position_ids
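
# Illustrative sketch (added comment, not part of the original source): the packed SFT
# path appears to encode the end of each packed sub-sequence as a negative token
# (t -> -1 - t, mapped back later via tokens[tokens < 0] = -tokens[tokens < 0] - 1),
# and this helper uses those negatives as position-reset points. Hand-worked example:
#
#   tokens = torch.tensor([[5, 9, -3, 7, 8, -1, 4, 6]])
#   get_ltor_position_ids_packed_seq(tokens)
#   # -> [[0, 1, 2, 0, 1, 2, 0, 1]]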


def get_batch_on_this_tp_rank_original(data_iterator, per_seq_average=False):
    args = get_args()
    tokenizer = get_tokenizer()

    def _broadcast(item):
        if item is None:
            return
        torch.distributed.broadcast(item,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        if isinstance(data_iterator, dict):
            data = data_iterator
        else:
            data = next(data_iterator)

        tokens_ = data['input_ids'].long()
        labels_ = data['labels'].long()
        tokens = tokens_[:, :-1].contiguous()
        labels = labels_[:, 1:].contiguous()
        # core/tensor_parallel/cross_entropy.py:
        #   target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        # labels[labels == tokenizer.eos_token_id] = -100
        # NOTE: if eos == pad, we map <eos> to -1 - eos_id; map these tokens back here.
        tokens[tokens < 0] = -1 - tokens[tokens < 0]
        eos_indices = (labels < 0).nonzero()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[eos_indices[:, 0], eos_indices[:, 1]] = \
            -1 - labels[eos_indices[:, 0], eos_indices[:, 1]]
        attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
            labels,
            -100,
            args.reset_position_ids,
            args.reset_attention_mask,
            args.eod_mask_loss)

        num_seqs = None
        if per_seq_average:
            # NOTE: the raw dataset does not support sequence packing.
            num_seqs = torch.ones(position_ids.shape[0],
                                  device=torch.cuda.current_device(),
                                  dtype=torch.int64)
            loss_mask = loss_mask / loss_mask.sum(dim=-1, keepdims=True)  # [mbs]

        batch = {
            'tokens': tokens.cuda(non_blocking=True),
            'labels': labels.cuda(non_blocking=True),
            'loss_mask': loss_mask.cuda(non_blocking=True),
            'attention_mask': attention_mask.cuda(non_blocking=True),
            'position_ids': position_ids.cuda(non_blocking=True),
            'num_seqs': num_seqs.cuda(non_blocking=True) if num_seqs is not None else None
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
            _broadcast(batch['num_seqs'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['position_ids'])
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate
            # embeddings. Currently the MTP layers are fixed on the last stage, so we need to
            # broadcast tokens and position_ids to all of the tensor parallel ranks on the
            # last stage.
            if getattr(args, 'mtp_num_layers', None) is not None:
                _broadcast(batch['tokens'])
                _broadcast(batch['position_ids'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['num_seqs'])
    else:
        tokens = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64, device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length),
                                dtype=torch.float32, device=torch.cuda.current_device())
        mbs = args.micro_batch_size if args.reset_attention_mask else 1
        attention_mask = torch.empty((mbs, 1, args.seq_length, args.seq_length),
                                     dtype=torch.bool, device=torch.cuda.current_device())
        position_ids = torch.empty((args.micro_batch_size, args.seq_length),
                                   dtype=torch.int64, device=torch.cuda.current_device())
        num_seqs = None
        if per_seq_average:
            num_seqs = torch.empty((args.micro_batch_size,),
                                   dtype=torch.int64, device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(position_ids)
            _broadcast(num_seqs)
        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            num_seqs = None
            _broadcast(tokens)
            _broadcast(attention_mask)
            _broadcast(position_ids)
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate
            # embeddings. Currently the MTP layers are fixed on the last stage, so we need to
            # broadcast tokens and position_ids to all of the tensor parallel ranks on the
            # last stage.
            if getattr(args, 'mtp_num_layers', None) is not None:
                _broadcast(tokens)
                _broadcast(position_ids)
            else:
                tokens = None
                position_ids = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(num_seqs)

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'num_seqs': num_seqs
        }

    return batch
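
# Illustrative usage sketch (added comment, not part of the original source): a caller's
# get_batch for the non-packed ("original") dataset would typically look like the lines
# below; the function name get_batch and the return order are assumptions for the sake of
# the example, not the repository's actual training loop.
#
#   def get_batch(data_iterator):
#       batch = get_batch_on_this_tp_rank_original(data_iterator, per_seq_average=True)
#       return (batch['tokens'], batch['labels'], batch['loss_mask'],
#               batch['attention_mask'], batch['position_ids'])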


def get_position_id_on_this_tp_rank_idxmap_sft_packing(data_iterator):
    args = get_args()
    tokenizer = get_tokenizer()

    def _broadcast(item):
        if item is None:
            return
        torch.distributed.broadcast(item,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        if isinstance(data_iterator, dict):
            data = data_iterator
        else:
            data = next(data_iterator)
        actual_seqlen = args.seq_length
        data['tokens'] = data['tokens'].long()
        tokens = data['tokens'][..., :actual_seqlen]
        position_ids = get_ltor_position_ids_packed_seq(tokens).cuda(non_blocking=True)
    else:
        # dtype: long
        position_ids = torch.empty((args.micro_batch_size, args.seq_length),
                                   dtype=torch.int64, device=torch.cuda.current_device())

    _broadcast(position_ids)
    return position_ids


def get_batch_on_this_tp_rank_idxmap_sft(data_iterator, per_seq_average=False):
    args = get_args()
    tokenizer = get_tokenizer()

    def _broadcast(item):
        if item is None:
            return
        torch.distributed.broadcast(item,
                                    mpu.get_tensor_model_parallel_src_rank(),
                                    group=mpu.get_tensor_model_parallel_group())

    if mpu.get_tensor_model_parallel_rank() == 0:
        if isinstance(data_iterator, dict):
            data = data_iterator
        else:
            data = next(data_iterator)

        # sanity check
        assert data['tokens'].shape[-1] == 2 * args.seq_length
        actual_seqlen = args.seq_length
        data['tokens'] = data['tokens'].long()
        tokens = data['tokens'][..., :actual_seqlen]
        labels = data['tokens'][..., actual_seqlen:]
        loss_mask = (labels != -100).float()
        if args.reset_position_ids:
            attention_mask = None
            position_ids = get_ltor_position_ids_packed_seq(tokens)
            has_pad = tokens[:, -1] >= 0
            tokens[tokens < 0] = -tokens[tokens < 0] - 1
        else:
            tokens[tokens < 0] = -tokens[tokens < 0] - 1
            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                tokens,
                tokenizer.eod,
                args.reset_position_ids,
                args.reset_attention_mask,
                False,
                args.create_attention_mask_in_dataloader)

        num_seqs = None
        if per_seq_average:
            num_seqs = torch.ones(position_ids.shape[0],
                                  device=torch.cuda.current_device(),
                                  dtype=torch.int64)
            if args.reset_position_ids:
                for b in range(position_ids.shape[0]):
                    p = position_ids[b]
                    start_indices = (p == 0).nonzero(as_tuple=True)[0]
                    seqlens = start_indices[1:] - start_indices[:-1]
                    seqlens = seqlens.cpu().numpy().tolist() + [p.shape[0] - start_indices[-1].item()]
                    subseqs = torch.split(loss_mask[b], seqlens)
                    num_seqs[b] = len(seqlens) - int(has_pad[b])
                    for subseq_idx, (start_idx, seqlen, subseq) in enumerate(
                            zip(start_indices, seqlens, subseqs)):
                        if subseq_idx == num_seqs[b]:
                            # NOTE: do not process the pad sequence.
                            continue
                        assert subseq.sum() > 0
                        loss_mask[b, start_idx: start_idx + seqlen] /= subseq.sum()
            else:
                loss_mask = loss_mask / loss_mask.sum(dim=-1, keepdims=True)  # [mbs]

        # dtype: long, long, float, bool, long
        batch = {
            'tokens': tokens.cuda(non_blocking=True),
            'labels': labels.cuda(non_blocking=True),
            'loss_mask': loss_mask.cuda(non_blocking=True),
            'attention_mask': attention_mask.cuda(non_blocking=True) if attention_mask is not None else None,
            'position_ids': position_ids.cuda(non_blocking=True),
            'num_seqs': num_seqs.cuda(non_blocking=True) if num_seqs is not None else None
        }

        if args.pipeline_model_parallel_size == 1:
            _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['num_seqs'])
        elif mpu.is_pipeline_first_stage():
            _broadcast(batch['tokens'])
            _broadcast(batch['attention_mask'])
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate
            # embeddings. Currently the MTP layers are fixed on the last stage, so we need to
            # broadcast tokens and position_ids to all of the tensor parallel ranks on the
            # last stage.
            if getattr(args, 'mtp_num_layers', None) is not None:
                _broadcast(batch['tokens'])
            _broadcast(batch['labels'])
            _broadcast(batch['loss_mask'])
            _broadcast(batch['attention_mask'])
            _broadcast(batch['num_seqs'])

        _broadcast(batch['position_ids'])
    else:
        # dtype: long, long, float, bool, long
        tokens = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64, device=torch.cuda.current_device())
        labels = torch.empty((args.micro_batch_size, args.seq_length),
                             dtype=torch.int64, device=torch.cuda.current_device())
        loss_mask = torch.empty((args.micro_batch_size, args.seq_length),
                                dtype=torch.float32, device=torch.cuda.current_device())
        attention_mask = None
        if args.create_attention_mask_in_dataloader:
            mbs = args.micro_batch_size if args.reset_attention_mask else 1
            attention_mask = torch.empty((mbs, 1, args.seq_length, args.seq_length),
                                         dtype=torch.bool, device=torch.cuda.current_device())
        position_ids = torch.empty((args.micro_batch_size, args.seq_length),
                                   dtype=torch.int64, device=torch.cuda.current_device())
        num_seqs = None
        if per_seq_average:
            num_seqs = torch.empty((args.micro_batch_size,),
                                   dtype=torch.int64, device=torch.cuda.current_device())

        if args.pipeline_model_parallel_size == 1:
            _broadcast(tokens)
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(num_seqs)
        elif mpu.is_pipeline_first_stage():
            labels = None
            loss_mask = None
            num_seqs = None
            _broadcast(tokens)
            _broadcast(attention_mask)
        elif mpu.is_pipeline_last_stage():
            # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate
            # embeddings. Currently the MTP layers are fixed on the last stage, so we need to
            # broadcast tokens and position_ids to all of the tensor parallel ranks on the
            # last stage.
            if getattr(args, 'mtp_num_layers', None) is not None:
                _broadcast(tokens)
            else:
                tokens = None
            _broadcast(labels)
            _broadcast(loss_mask)
            _broadcast(attention_mask)
            _broadcast(num_seqs)

        _broadcast(position_ids)

        batch = {
            'tokens': tokens,
            'labels': labels,
            'loss_mask': loss_mask,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'num_seqs': num_seqs
        }

    return batch
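
# Illustrative sketch (added comment, not part of the original source): the idxmap SFT
# dataset is expected to emit rows of length 2 * seq_length, with input tokens in the
# first half and shifted labels (already -100 on non-target positions) in the second
# half. Calling the helper also assumes an initialized Megatron args/parallel state;
# with a hypothetical seq_length of 4:
#
#   data = {'tokens': torch.tensor([[11, 12, 13, -2,    12, 13, -100, -100]])}
#   batch = get_batch_on_this_tp_rank_idxmap_sft(data, per_seq_average=False)
#   # batch['tokens'] -> first half with negatives mapped back (-2 -> 1)
#   # batch['labels'] -> second half, used directly as the shifted targets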