def start_and_end_token_handling()

in src/flint/torch_util.py [0:0]


def start_and_end_token_handling(inputs, lengths, sos_index=1, eos_index=2, pad_index=0,
                                 op=None):
    """
    Optionally drop the start and/or end token from a batch of padded sequences.

    The time dimension is never shortened. Removing the start token shifts every row
    one step to the left and appends one column filled with ``pad_index``. Removing the
    end token only decrements the returned lengths; the end token itself is left in
    the tensor. The input tensor is not modified in place.

    :param inputs: [B, T]
    :param lengths: [B]
    :param sos_index: accepted for interface compatibility; not used by this function
    :param eos_index: accepted for interface compatibility; not used by this function
    :param pad_index: value written into the column appended after a left shift
    :param op: ``None`` (or any falsy value) to return the arguments unchanged, or one
        of ``'rm_start'``, ``'rm_end'``, ``'rm_both'``
    :return: tuple ``(inputs, lengths)`` after the requested operation; lengths are
        reduced by the number of removed tokens (1, 1 and 2 respectively)
    :raises ValueError: if ``op`` is a non-empty value that is not supported
    """
    if not op:
        return inputs, lengths

    if op not in ('rm_start', 'rm_end', 'rm_both'):
        # Previously an unknown op fell through and implicitly returned None, which only
        # surfaced later as an unpacking error at the call site.
        raise ValueError(
            "Unsupported op {!r}; expected None, 'rm_start', 'rm_end' or 'rm_both'".format(op))

    if op == 'rm_end':
        # NOTE: the end token is deliberately not overwritten with pad_index (an in-place
        # write into `inputs` was considered risky); only the length is reduced.
        return inputs, lengths - 1

    # 'rm_start' and 'rm_both' share the same left shift: drop column 0 and append one
    # padding column so the [B, T] shape is preserved.
    batch_size = inputs.size(0)
    pad_column = Variable(inputs.data.new(batch_size, 1).fill_(pad_index))
    shifted = torch.cat([inputs[:, 1:], pad_column], dim=1)

    removed = 1 if op == 'rm_start' else 2
    return shifted, lengths - removed