in drqa/reader/vector.py [0:0]
def batchify(batch):
    """Gather a batch of individual examples into one batch.

    Each example is a tuple ``(document, features, question, [start, end,]
    id)``. Documents and questions are right-padded with zeros to the longest
    sequence in the batch; the matching byte masks hold 0 on real tokens and
    1 on padding positions.
    """
    NUM_INPUTS = 3
    NUM_TARGETS = 2
    NUM_EXTRA = 1

    # Reject malformed examples up front: valid shapes are inputs+id
    # (prediction time) or inputs+targets+id (training time).
    n_fields = len(batch[0])
    has_targets = n_fields == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS
    if not has_targets and n_fields != NUM_INPUTS + NUM_EXTRA:
        raise RuntimeError('Incorrect number of inputs per example.')

    ids = [ex[-1] for ex in batch]
    doc_tensors = [ex[0] for ex in batch]
    feat_tensors = [ex[1] for ex in batch]
    q_tensors = [ex[2] for ex in batch]

    # Batch documents (and per-token features, when present).
    n = len(doc_tensors)
    doc_len = max(d.size(0) for d in doc_tensors)
    x1 = torch.LongTensor(n, doc_len).zero_()
    x1_mask = torch.ByteTensor(n, doc_len).fill_(1)
    if feat_tensors[0] is None:
        x1_f = None
    else:
        x1_f = torch.zeros(n, doc_len, feat_tensors[0].size(1))
    for row, doc in enumerate(doc_tensors):
        length = doc.size(0)
        x1[row, :length].copy_(doc)
        x1_mask[row, :length].fill_(0)
        if x1_f is not None:
            x1_f[row, :length].copy_(feat_tensors[row])

    # Batch questions.
    q_len = max(q.size(0) for q in q_tensors)
    x2 = torch.LongTensor(len(q_tensors), q_len).zero_()
    x2_mask = torch.ByteTensor(len(q_tensors), q_len).fill_(1)
    for row, q in enumerate(q_tensors):
        length = q.size(0)
        x2[row, :length].copy_(q)
        x2_mask[row, :length].fill_(0)

    if not has_targets:
        # No answer spans supplied — return inputs and ids only.
        return x1, x1_f, x1_mask, x2, x2_mask, ids

    # Answer spans: tensor targets are concatenated across the batch;
    # list targets (multiple gold spans per example) pass through as-is.
    if torch.is_tensor(batch[0][3]):
        y_s = torch.cat([ex[3] for ex in batch])
        y_e = torch.cat([ex[4] for ex in batch])
    else:
        y_s = [ex[3] for ex in batch]
        y_e = [ex[4] for ex in batch]
    return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids