# collate_fn — batch collation helper, excerpted from dvd_codebase/data/dataset.py

def collate_fn(data, vocab):
    """Collate a list of per-sample dicts into one padded ``Batch``.

    Args:
        data: list of sample dicts; every dict is expected to share the same
            keys (``history``, ``question``, ``turns``, ``answer``,
            ``curr_vft``, ``program``, ``state``, ``vid_split``, ``vid``,
            ``qa_id``, ``vft_sizes``) holding numpy arrays / scalars.
            (Schema inferred from the keys read below — confirm against the
            Dataset that produces these samples.)
        vocab: token-to-id mapping; ``vocab['<blank>']`` is used as the
            padding id for all text sequences.

    Returns:
        A ``Batch`` (project type, defined elsewhere) holding padded,
        tensor-converted fields plus the original per-sample metadata.
    """
    def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
        """Right-pad a list of arrays to the max length along axis 0.

        Text sequences are 1-D int arrays; video features (``is_vft``) are
        2-D ``(T, D)`` or 4-D ``(T, C, H, W)`` float arrays.
        """
        lengths = [s.shape[0] for s in seqs]
        max_length = max(lengths)
        output = []
        for seq in seqs:
            if is_vft:
                if len(seq.shape) == 4:  # spatio-temporal feature (T, C, H, W)
                    result = np.ones((max_length,) + seq.shape[1:], dtype=seq.dtype) * pad_token
                else:  # frame-level feature (T, D)
                    result = np.ones((max_length, seq.shape[-1]), dtype=seq.dtype) * pad_token
            else:
                result = np.ones(max_length, dtype=seq.dtype) * pad_token
            result[:seq.shape[0]] = seq
            output.append(result)
        if return_lens:
            return lengths, output
        return output

    def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
        """Pad a list of dialogues (each a list of turn arrays) into
        ``(max_dial_len, max_turn_len[, D])`` arrays.

        Returns per-dialogue lengths and per-turn lengths when
        ``return_lens`` is set. Unused turn slots keep length 1 so that
        downstream length-based ops never see a zero.
        """
        dial_lens = [len(s) for s in seqs]
        max_dial_len = max(dial_lens)
        # Flatten all turns across dialogues to find the global max turn length.
        flat_turns = []
        for seq in seqs:
            flat_turns.extend(seq)
        max_turn_len = max(len(t) for t in flat_turns)
        output = []
        all_turn_lens = []
        for seq in seqs:
            if is_vft:
                # NOTE(review): no explicit dtype here (defaults to float64),
                # unlike pad_seq — confirm whether that is intentional.
                result = np.ones((max_dial_len, max_turn_len, seq[0].shape[-1])) * pad_token
            else:
                result = np.ones((max_dial_len, max_turn_len)) * pad_token
            # np.int was removed in NumPy 1.24; int64 matches the later
            # .long() tensor conversion.
            turn_lens = np.ones(max_dial_len, dtype=np.int64)
            for turn_idx, turn in enumerate(seq):
                result[turn_idx, :turn.shape[0]] = turn
                turn_lens[turn_idx] = turn.shape[0]
            output.append(result)
            all_turn_lens.append(turn_lens)
        all_turn_lens = np.asarray(all_turn_lens)
        if return_lens:
            return dial_lens, all_turn_lens, output
        return output

    def prepare_data(seqs, is_float=False):
        """Convert a (possibly nested) list of arrays to a float or long tensor."""
        if is_float:
            return torch.from_numpy(np.asarray(seqs)).float()
        return torch.from_numpy(np.asarray(seqs)).long()

    # Transpose list-of-dicts -> dict-of-lists.
    item_info = {}
    for key in data[0].keys():
        item_info[key] = [d[key] for d in data]
    pad_token = vocab['<blank>']

    h_lens, h_padded = pad_seq(item_info['history'], pad_token, return_lens=True)
    h_batch = prepare_data(h_padded)
    q_lens, q_padded = pad_seq(item_info['question'], pad_token, return_lens=True)
    q_batch = prepare_data(q_padded)

    # Concatenate history followed by question for each sample.
    # (The original code bound history to a variable named `q` and question
    # to `h`; the order — history first — is unchanged here.)
    hq = [np.concatenate([h, q]) for h, q in zip(item_info['history'], item_info['question'])]
    hq_lens, hq_padded = pad_seq(hq, pad_token, return_lens=True)
    hq_batch = prepare_data(hq_padded)

    dial_lens, turn_lens, turns_padded = pad_2d_seq(item_info['turns'], pad_token, return_lens=True)
    turns_batch = prepare_data(turns_padded)
    a_batch = prepare_data(item_info['answer'])

    # Video features are padded with 0, not the text pad token.
    vft_lens, vft_padded = pad_seq(item_info['curr_vft'], 0, return_lens=True, is_vft=True)
    vft_batch = prepare_data(vft_padded, is_float=True)

    # Sanity check: padded lengths must agree with the sizes recorded by the
    # dataset. (NOTE(review): stripped under `python -O`; raise if this must
    # hold in production.)
    assert vft_lens == item_info['vft_sizes']

    p_lens, p_padded = pad_seq(item_info['program'], pad_token, return_lens=True)
    p_batch = prepare_data(p_padded)

    s_lens, s_padded = pad_seq(item_info['state'], pad_token, return_lens=True)
    s_batch = prepare_data(s_padded)

    batch = Batch(vft_batch,
                  h_batch, q_batch, hq_batch, turns_batch, a_batch,
                  item_info['vid_split'], item_info['vid'], item_info['qa_id'],
                  q_lens, h_lens, hq_lens, vft_lens,
                  dial_lens, turn_lens,
                  p_batch, p_lens, s_batch, s_lens,
                  vocab)
    return batch