# dvd_codebase/data/dataset.py
def collate_fn(data, vocab):
    """Collate a list of per-sample dicts into one padded ``Batch``.

    Args:
        data: list of sample dicts; every dict is assumed to share the same
            keys ('history', 'question', 'turns', 'answer', 'curr_vft',
            'vft_sizes', 'program', 'state', 'vid_split', 'vid', 'qa_id')
            — TODO confirm against the Dataset that produces them.
        vocab: token-to-id mapping; must contain the '<blank>' pad token.

    Returns:
        A ``Batch`` (project type) holding padded/stacked tensors plus the
        per-sample metadata lists.
    """
    def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
        # Pad a list of arrays along their first axis to the longest length.
        # For video features (is_vft) the trailing feature dims are kept;
        # a 4-D input is treated as a spatio-temporal feature map.
        lengths = [s.shape[0] for s in seqs]
        max_length = max(lengths)
        output = []
        for seq in seqs:
            if is_vft:
                if len(seq.shape) == 4:  # spatio-temporal feature
                    result = np.ones(
                        (max_length, seq.shape[1], seq.shape[2], seq.shape[3]),
                        dtype=seq.dtype) * pad_token
                else:
                    result = np.ones(
                        (max_length, seq.shape[-1]), dtype=seq.dtype) * pad_token
            else:
                result = np.ones(max_length, dtype=seq.dtype) * pad_token
            result[:seq.shape[0]] = seq
            output.append(result)
        if return_lens:
            return lengths, output
        return output

    def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
        # Pad a list of lists of arrays (e.g. dialogue turns) into a
        # uniform (max_turns, max_turn_len[, feat]) grid per sample.
        lens1 = [len(s) for s in seqs]
        max_len1 = max(lens1)
        all_seqs = []
        for seq in seqs:
            all_seqs.extend(seq)
        lens2 = [len(s) for s in all_seqs]
        max_len2 = max(lens2)
        output = []
        all_lens = []
        for seq in seqs:
            if is_vft:
                result = np.ones((max_len1, max_len2, seq[0].shape[-1])) * pad_token
            else:
                result = np.ones((max_len1, max_len2)) * pad_token
            # np.int was removed in NumPy 1.24; np.int64 keeps the intent
            # (integer turn lengths) and works on current NumPy.
            turn_lens = np.ones(max_len1, dtype=np.int64)
            for turn_idx, turn in enumerate(seq):
                result[turn_idx, :turn.shape[0]] = turn
                turn_lens[turn_idx] = turn.shape[0]
            output.append(result)
            all_lens.append(turn_lens)
        all_lens = np.asarray(all_lens)
        if return_lens:
            return lens1, all_lens, output
        return output

    def prepare_data(seqs, is_float=False):
        # Stack equal-shaped padded arrays into a single torch tensor.
        if is_float:
            return torch.from_numpy(np.asarray(seqs)).float()
        return torch.from_numpy(np.asarray(seqs)).long()

    # Transpose list-of-dicts into dict-of-lists.
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]

    pad_token = vocab['<blank>']
    h_lens, h_padded = pad_seq(item_info['history'], pad_token, return_lens=True)
    h_batch = prepare_data(h_padded)
    q_lens, q_padded = pad_seq(item_info['question'], pad_token, return_lens=True)
    q_batch = prepare_data(q_padded)
    # History followed by question, concatenated per sample. The original
    # unpacked zip(history, question) into swapped names (q, h); only the
    # names are corrected here — the value order is unchanged.
    hq = [np.concatenate([h, q])
          for h, q in zip(item_info['history'], item_info['question'])]
    hq_lens, hq_padded = pad_seq(hq, pad_token, return_lens=True)
    hq_batch = prepare_data(hq_padded)
    dial_lens, turn_lens, turns_padded = pad_2d_seq(
        item_info['turns'], pad_token, return_lens=True)
    turns_batch = prepare_data(turns_padded)
    a_batch = prepare_data(item_info['answer'])
    # Video features are padded with literal 0, not the vocab blank token.
    vft_lens, vft_padded = pad_seq(
        item_info['curr_vft'], 0, return_lens=True, is_vft=True)
    vft_batch = prepare_data(vft_padded, is_float=True)
    # Sanity check: padded lengths must match the sizes recorded upstream.
    assert vft_lens == item_info['vft_sizes']
    p_lens, p_padded = pad_seq(item_info['program'], pad_token, return_lens=True)
    p_batch = prepare_data(p_padded)
    s_lens, s_padded = pad_seq(item_info['state'], pad_token, return_lens=True)
    s_batch = prepare_data(s_padded)
    return Batch(vft_batch,
                 h_batch, q_batch, hq_batch, turns_batch, a_batch,
                 item_info['vid_split'], item_info['vid'], item_info['qa_id'],
                 q_lens, h_lens, hq_lens, vft_lens,
                 dial_lens, turn_lens,
                 p_batch, p_lens, s_batch, s_lens,
                 vocab)