in dialogue_personalization/utils/data_reader.py [0:0]
def collate_fn(data):
    """Collate a list of example dicts into a single padded batch dict.

    Examples are sorted by source length (descending), each sequence field is
    right-padded to the batch maximum with pad index 1, and the padded batches
    are returned time-major, i.e. shaped (max_len, batch_size).

    NOTE(review): pad value 1 is hard-coded here — presumably the PAD token
    index; confirm against the vocabulary definition.
    """

    def _pad(sequences):
        # Right-pad every sequence to the longest one in the batch.
        lengths = [len(s) for s in sequences]
        padded = torch.full((len(sequences), max(lengths)), 1, dtype=torch.long)
        for row, seq in enumerate(sequences):
            padded[row, : lengths[row]] = seq[: lengths[row]]
        return padded, lengths

    # Sort descending by source length (in place), as packed-RNN encoders expect.
    data.sort(key=lambda x: len(x["input_batch"]), reverse=True)

    # Transpose the list of dicts into a dict of lists, one list per field.
    by_key = {key: [example[key] for example in data] for key in data[0].keys()}

    input_batch, input_lengths = _pad(by_key['input_batch'])
    target_batch, target_lengths = _pad(by_key['target_batch'])

    # (batch, time) -> (time, batch).
    input_batch = input_batch.transpose(0, 1)
    target_batch = target_batch.transpose(0, 1)
    input_lengths = torch.LongTensor(input_lengths)
    target_lengths = torch.LongTensor(target_lengths)

    if config.USE_CUDA:
        input_batch = input_batch.cuda()
        target_batch = target_batch.cuda()
        input_lengths = input_lengths.cuda()
        target_lengths = target_lengths.cuda()

    d = {
        "input_batch": input_batch,
        "target_batch": target_batch,
        "input_lengths": input_lengths,
        "target_lengths": target_lengths,
        "input_txt": by_key["input_txt"],
        "target_txt": by_key["target_txt"],
        "cand_txt": by_key["cand_txt"],
        "cand_index": by_key["cand_index"],
        "persona_txt": by_key["persona_txt"],
    }

    # Extended-vocabulary copies (pointer-generator style decoding), if present.
    if 'input_ext_vocab_batch' in by_key:
        input_ext_vocab_batch, _ = _pad(by_key['input_ext_vocab_batch'])
        target_ext_vocab_batch, _ = _pad(by_key['target_ext_vocab_batch'])
        input_ext_vocab_batch = input_ext_vocab_batch.transpose(0, 1)
        target_ext_vocab_batch = target_ext_vocab_batch.transpose(0, 1)
        if config.USE_CUDA:
            input_ext_vocab_batch = input_ext_vocab_batch.cuda()
            target_ext_vocab_batch = target_ext_vocab_batch.cuda()
        d["input_ext_vocab_batch"] = input_ext_vocab_batch
        d["target_ext_vocab_batch"] = target_ext_vocab_batch
        if "article_oovs" in by_key:
            d["article_oovs"] = by_key["article_oovs"]
            d["max_art_oovs"] = max(map(len, by_key["article_oovs"]))

    return d