in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]
def flush_left(
    mask: torch.Tensor, *tensors: torch.Tensor
) -> "torch.Tensor | tuple[torch.Tensor, ...]":
    """Shift each row left to its first non-zero mask entry, then trim empty trailing columns.

    For every row ``i``, ``mask[i]`` and row ``i`` of each extra tensor are rolled
    circularly to the left along the position axis. The shift is the index of the
    first non-zero entry of ``mask[i]``; a row whose mask is entirely zero is left
    unshifted. Then every column from the first one whose mask is zero in all rows
    onward is dropped from the mask and from every tensor.

    Example::

        >>> mask = torch.tensor([[0, 0, 1, 1, 1],
        ...                      [0, 1, 1, 0, 0]])
        >>> tensor = torch.tensor([[9, 9, 2, 3, 4],
        ...                        [9, 5, 6, 9, 9]])
        >>> new_mask, new_tensor = flush_left(mask, tensor)
        >>> new_mask
        tensor([[1, 1, 1],
                [1, 1, 0]])
        >>> new_tensor
        tensor([[2, 3, 4],
                [5, 6, 9]])

    Args:
        mask: Tensor indexed as ``(row, position)``; non-zero entries mark the
            positions to keep. Assumed 2-D (not checked).
        *tensors: Tensors shifted and trimmed alongside ``mask``. Their first two
            dimensions are assumed to match ``mask`` (not checked); any trailing
            dimensions are carried along unchanged.

    Returns:
        The flushed mask alone when no ``tensors`` are given, otherwise the tuple
        ``(mask, *tensors)``. The inputs are cloned and never modified in place.
    """
    # Work on copies so the caller's tensors are never modified.
    mask = mask.clone()
    tensors = [t.clone() for t in tensors]

    # Per-row left shift = index of the first non-zero mask entry. argmax returns the
    # first maximal index, and yields 0 (no shift) for an all-zero row instead of
    # raising. A single .tolist() replaces one device sync per row.
    if mask.numel() > 0:
        shifts = (mask != 0).to(torch.int64).argmax(dim=1).tolist()
    else:
        shifts = [0] * mask.size(0)

    for i, shift in enumerate(shifts):
        if shift == 0:
            continue
        # dims=0 rolls along the row's position axis; without it torch.roll flattens
        # the row first, which mis-shifts tensors that have trailing dimensions.
        mask[i] = torch.roll(mask[i], shifts=-shift, dims=0)
        for tensor in tensors:
            tensor[i] = torch.roll(tensor[i], shifts=-shift, dims=0)

    # Drop every column from the first one that is zero in all rows onward.
    empty_cols = ~(mask != 0).any(dim=0)
    first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
    mask = mask[:, :first_empty_col]
    tensors = [t[:, :first_empty_col] for t in tensors]

    if not tensors:
        return mask
    return (mask, *tensors)