def get_batch_for_cp_rank()

in src/hyperpod_nemo_adapter/utils/train_utils.py
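
Shards each tensor in a batch along the sequence dimension for the local context-parallel (CP) rank, following the NeMo load-balancing scheme referenced in the code: the sequence is split into 2 * cp_size chunks, and rank i keeps chunks i and 2 * cp_size - 1 - i. Pairing an early chunk with its mirror from the end gives every rank a comparable mix of cheap and expensive positions under causal attention. With cp_size == 1 the batch passes through unchanged, and None entries (e.g. an absent labels tensor) are kept as-is.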


import torch
import torch.sagemaker as tsm  # module alias assumed from the SMP v2 convention; tsm.state holds CP degree/rank


def get_batch_for_cp_rank(batch):
    # Based on https://github.com/NVIDIA/NeMo/blob/58d6bcee313a44d926a54e51c69222ddae20f070/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L840
    cp_size = tsm.state.cp_size
    cp_rank = tsm.state.cp_rank
    if cp_size > 1:
        return_batch = []
        for val in batch:
            if val is not None:
                # Sequence length is assumed to sit on dim 1 (batch-first layout).
                seq_dim = 1
                # Split the sequence dimension into 2 * cp_size equal chunks:
                # (..., seq_len, ...) -> (..., 2 * cp_size, seq_len // (2 * cp_size), ...).
                val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                # Rank i keeps chunk i and its mirror chunk 2 * cp_size - 1 - i,
                # which balances per-rank work under causal attention.
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
                val = val.index_select(seq_dim, index)
                # Merge the two selected chunks back into a single sequence dimension.
                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
            return_batch.append(val)
        return_batch = tuple(return_batch)
    else:
        # No context parallelism: return the batch unchanged.
        return_batch = batch
    return return_batch
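
A minimal standalone sketch of the same slicing, useful for checking the chunk layout. shard_for_cp_rank is a hypothetical helper (not part of the repo) that takes cp_size and cp_rank explicitly instead of reading them from tsm.state:

    import torch

    def shard_for_cp_rank(val, cp_size, cp_rank, seq_dim=1):
        # Same view -> index_select -> view sequence as get_batch_for_cp_rank,
        # applied to a single tensor with explicit CP parameters.
        val = val.view(
            *val.shape[:seq_dim],
            2 * cp_size,
            val.shape[seq_dim] // (2 * cp_size),
            *val.shape[seq_dim + 1 :],
        )
        index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1], device=val.device)
        val = val.index_select(seq_dim, index)
        return val.view(*val.shape[:seq_dim], -1, *val.shape[seq_dim + 2 :])

    # An 8-token sequence sharded across cp_size=2 ranks (4 chunks of 2 tokens):
    tokens = torch.arange(8).unsqueeze(0)                    # [[0, 1, 2, 3, 4, 5, 6, 7]]
    print(shard_for_cp_rank(tokens, cp_size=2, cp_rank=0))   # tensor([[0, 1, 6, 7]])
    print(shard_for_cp_rank(tokens, cp_size=2, cp_rank=1))   # tensor([[2, 3, 4, 5]])

Note that each rank ends up with one chunk from the front and one from the back of the sequence, which is exactly the causal-attention load balancing the NeMo reference describes. This also implies the sequence length must be divisible by 2 * cp_size.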