in src/hyperpod_nemo_adapter/utils/train_utils.py
import torch
import torch.sagemaker as tsm  # SMP v2 runtime; exposes the context-parallel degree/rank via tsm.state


def get_batch_for_cp_rank(batch):
    # Based on https://github.com/NVIDIA/NeMo/blob/58d6bcee313a44d926a54e51c69222ddae20f070/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L840
    cp_size = tsm.state.cp_size
    cp_rank = tsm.state.cp_rank
    if cp_size > 1:
        return_batch = []
        for val in batch:
            if val is not None:
                seq_dim = 1
                # Split the sequence dimension into 2 * cp_size equal chunks
                # (requires the sequence length to be divisible by 2 * cp_size).
                val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                # Each rank keeps chunk cp_rank plus its mirror chunk
                # (2 * cp_size - cp_rank - 1); pairing a front chunk with a back
                # chunk balances causal-attention work across CP ranks.
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
                val = val.index_select(seq_dim, index)
                # Merge the two selected chunks back into one sequence dimension.
                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
            # None entries are appended unchanged so the batch stays positionally stable.
            return_batch.append(val)
        return_batch = tuple(return_batch)
    else:
        return_batch = batch
    return return_batch
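
For intuition, here is a minimal, self-contained sketch of the same chunking scheme. The helper name shard_for_cp_rank is hypothetical, and cp_size/cp_rank are passed in explicitly rather than read from tsm.state. With a sequence of length 8 and cp_size = 2, the sequence splits into 4 chunks of 2 tokens; rank 0 receives chunks 0 and 3, rank 1 receives chunks 1 and 2:

import torch

def shard_for_cp_rank(val, cp_size, cp_rank, seq_dim=1):
    # Hypothetical standalone version of the chunking done in get_batch_for_cp_rank.
    # Assumes val.shape[seq_dim] is divisible by 2 * cp_size.
    val = val.view(
        *val.shape[:seq_dim],
        2 * cp_size,
        val.shape[seq_dim] // (2 * cp_size),
        *val.shape[seq_dim + 1 :],
    )
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
    val = val.index_select(seq_dim, index)
    return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2 :])

tokens = torch.arange(8).unsqueeze(0)                    # shape [1, 8], tokens 0..7
print(shard_for_cp_rank(tokens, cp_size=2, cp_rank=0))   # tensor([[0, 1, 6, 7]])
print(shard_for_cp_rank(tokens, cp_size=2, cp_rank=1))   # tensor([[2, 3, 4, 5]])

The mirrored selection is why each rank ends up with one "early" and one "late" chunk: under a causal mask, late tokens attend to many more tokens than early ones, so giving every rank one of each keeps attention FLOPs roughly even across the CP group.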