in src/flint/torch_util.py [0:0]
def start_and_end_token_handling(inputs, lengths, sos_index=1, eos_index=2, pad_index=0,
op=None):
    """Strip SOS/EOS tokens from a batch of left-aligned, padded sequences.

    :param inputs: [B, T] tensor of token ids; sequence i occupies
        ``inputs[i, :lengths[i]]`` and the rest of the row is padding.
    :param lengths: [B] tensor of sequence lengths (counting SOS/EOS).
    :param sos_index: start-of-sequence token id (unused; kept for
        interface compatibility — removal is positional, not by value).
    :param eos_index: end-of-sequence token id (unused; same reason).
    :param pad_index: id written into positions vacated by the removal.
    :param op: one of ``None`` (no-op), ``'rm_start'``, ``'rm_end'``,
        ``'rm_both'``.
    :return: ``(inputs, lengths)`` with the requested token(s) removed.
        For any non-None ``op`` a new tensor is returned; the caller's
        ``inputs`` is never mutated in place.
    :raises ValueError: if ``op`` is not a recognized value.
    """
    batch_size = inputs.size(0)
    if not op:
        return inputs, lengths
    if op == 'rm_start':
        # Drop the first (SOS) column by shifting left one position and
        # padding the vacated last column with pad_index.
        pad_col = inputs.new_full((batch_size, 1), pad_index)
        return torch.cat([inputs[:, 1:], pad_col], dim=1), lengths - 1
    if op == 'rm_end':
        # Overwrite the last real token (EOS) of each row with padding.
        # Clone first so the caller's tensor is not mutated in place —
        # the in-place write was the "potential problem" that left the
        # old implementation disabled.
        inputs = inputs.clone()
        for i in range(batch_size):
            inputs[i, lengths[i] - 1] = pad_index
        return inputs, lengths - 1
    if op == 'rm_both':
        # Pad out the EOS first, then shift left to drop the SOS column.
        inputs = inputs.clone()
        for i in range(batch_size):
            inputs[i, lengths[i] - 1] = pad_index
        pad_col = inputs.new_full((batch_size, 1), pad_index)
        return torch.cat([inputs[:, 1:], pad_col], dim=1), lengths - 2
    raise ValueError(
        "op must be one of None, 'rm_start', 'rm_end', 'rm_both'; got {!r}".format(op))