in picotron/data.py [0:0]
def collate_batch(self, batch):
    """Collate tokenized samples into this context-parallel rank's micro-batch.

    The ``input_ids`` of every sample are stacked along a new leading (batch)
    dimension. This rank keeps the column window ``[start_idx, end_idx)`` with
    ``start_idx = cp_rank * self.seq_length_per_gpu`` and
    ``end_idx = start_idx + self.seq_length_per_gpu``. Targets are the same
    window offset by one column (``target_ids[:, t]`` is the token that follows
    ``input_ids[:, t]``). Every sample must therefore contain at least
    ``end_idx + 1`` tokens.

    Args:
        batch: Sequence of mappings, each with an ``'input_ids'`` entry
            (list / array / tensor of token ids). All entries must have the
            same length, as ``torch.stack`` requires.

    Returns:
        dict with keys:
          - ``"input_ids"``: this rank's window of tokens.
          - ``"target_ids"``: the window shifted by one token.
          - ``"position_ids"``: global positions ``start_idx .. end_idx - 1``,
            repeated for every row of the batch.
          - ``"hidden_states"``: always ``None`` here (NOTE(review): presumably
            a placeholder filled in elsewhere — confirm against callers).

    Raises:
        ValueError: If the samples are too short to supply this rank's inputs
            plus the one extra token needed for the shifted targets.
    """
    # as_tensor avoids the copy-construct warning (and a redundant copy) when
    # items are already tensors; torch.stack produces the final copy anyway.
    batch_input_ids = torch.stack(
        [torch.as_tensor(item['input_ids']) for item in batch]
    )
    batch_size = batch_input_ids.size(0)
    total_len = batch_input_ids.size(1)

    cp_rank = pgm.process_group_manager.cp_rank
    start_idx = cp_rank * self.seq_length_per_gpu
    end_idx = start_idx + self.seq_length_per_gpu

    # The target slice reaches one token past this rank's window. Without this
    # check a too-short sample makes target_ids silently one column narrower
    # than input_ids, and the mismatch only shows up later in the loss.
    if total_len < end_idx + 1:
        raise ValueError(
            f"collate_batch: samples have {total_len} tokens but cp_rank "
            f"{cp_rank} needs at least {end_idx + 1} (window "
            f"[{start_idx}, {end_idx}) plus one shifted target token)"
        )

    input_ids = batch_input_ids[:, start_idx:end_idx].contiguous()
    target_ids = batch_input_ids[:, start_idx + 1:end_idx + 1].contiguous()
    # Positions are global (offset by start_idx), not restarted at 0 per rank;
    # expand() shares storage, so contiguous() materializes a real tensor.
    position_ids = (
        torch.arange(start_idx, end_idx, dtype=torch.long)
        .unsqueeze(0)
        .expand(batch_size, -1)
        .contiguous()
    )
    return {
        "input_ids": input_ids,
        "target_ids": target_ids,
        "position_ids": position_ids,
        "hidden_states": None,
    }