def flush_left()

in src/hyperpod_nemo_adapter/utils/dpo_utils.py [0:0]


def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> "torch.Tensor | tuple[torch.Tensor, ...]":
    """Shift each row's non-zero mask entries to the left and trim trailing all-zero columns.

    For every row ``i`` of ``mask``, the row is rolled left so that its first non-zero
    entry lands in column 0. The same roll is applied to row ``i`` of every tensor in
    ``tensors``, along their second (sequence) dimension. Afterwards, every column from
    the first column that is zero in all rows onward is dropped from ``mask`` and from
    every tensor.

    Inputs are never modified; the work is done on clones.

    Args:
        mask: 2-D tensor ``(batch, seq_len)``; non-zero entries mark positions to keep.
            NOTE(review): rows presumably hold one contiguous run of non-zeros (an
            attention mask). If a row has gaps and some column is zero in every row,
            the columns after it are dropped even where they contain non-zeros.
        *tensors: Tensors whose first two dimensions match ``mask`` (``batch`` and
            ``seq_len``). Any trailing dimensions are carried along unchanged.

    Returns:
        The flushed mask alone when no ``tensors`` are given; otherwise a tuple
        ``(mask, *tensors)`` of the flushed tensors, in input order.
    """
    flushed_mask = mask.clone()
    flushed = [t.clone() for t in tensors]

    # Roll each row left so that its first non-zero mask entry sits at column 0.
    for row in range(flushed_mask.size(0)):
        nonzero_idx = torch.nonzero(flushed_mask[row])
        if nonzero_idx.numel() == 0:
            # An all-zero (or zero-length) row has nothing to flush. Leave it unchanged
            # instead of raising IndexError.
            continue
        shift = -int(nonzero_idx[0, 0])
        flushed_mask[row] = torch.roll(flushed_mask[row], shifts=shift, dims=0)
        for t in flushed:
            # Roll along the sequence axis only. Without `dims`, torch.roll flattens the
            # row first, which mis-shifts tensors that have trailing dimensions.
            t[row] = torch.roll(t[row], shifts=shift, dims=0)

    # Find the first column that is zero in every row and drop it and all later columns.
    empty_cols = ~(flushed_mask != 0).any(dim=0)
    if empty_cols.any():
        first_empty_col = int(torch.nonzero(empty_cols)[0, 0])
    else:
        first_empty_col = flushed_mask.size(1)
    flushed_mask = flushed_mask[:, :first_empty_col]
    flushed = [t[:, :first_empty_col] for t in flushed]

    if not flushed:
        return flushed_mask
    return (flushed_mask, *flushed)