# collate_func
# from captioning/data/dataloader_show_control_tell.py


    def collate_func(self, batch, split):
        """Collate a list of dataset samples into a single batch dict.

        Each element of ``batch`` is an 11-tuple produced by the dataset:
        ``(fc_feat, att_feat, trace_feat, box_feat, seq, ix, it_pos_now,
        wrapped, show_seq, show_trace_feat, show_gate_label)``.

        Variable-length ``att``/``box``/``trace`` features are zero-padded to
        the longest entry in the batch, with matching 0/1 float masks.
        Labels get one extra column on each side (BOS/EOS slots, hence the
        ``+ 2``).  The show-control-tell entries are stored *per caption*
        (5 captions per image), not per image.

        Returns a dict whose ndarray values are converted to torch tensors;
        'gts', 'bounds' and 'infos' stay plain Python objects.
        """
        seq_per_img = self.seq_per_img

        fc_batch = []
        att_batch = []
        label_batch = []
        trace_batch = []
        box_batch = []

        # show-control-tell accumulators: one entry per caption sentence
        show_trace_feat_batch = []
        show_label_batch = []
        show_gate_label_batch = []

        wrapped = False

        infos = []
        gts = []

        for sample in batch:
            # fetch one image's worth of data
            (tmp_fc, tmp_att, tmp_trace, tmp_box, tmp_seq,
             ix, it_pos_now, tmp_wrapped, tmp_show_seq,
             tmp_show_trace_feat, tmp_show_gate_label_orig) = sample
            if tmp_wrapped:
                wrapped = True

            fc_batch.append(tmp_fc)
            att_batch.append(tmp_att)
            trace_batch.append(tmp_trace)
            box_batch.append(tmp_box)
            # show-control-tell: append the trace feats of each caption
            # sentence individually (iterating an ndarray yields its rows)
            show_trace_feat_batch.extend(tmp_show_trace_feat)

            # labels padded with a zero column at both ends (BOS/EOS slots)
            tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype='int')
            if hasattr(self, 'h5_label_file'):
                # if there is ground truth
                tmp_label[:, 1: self.seq_length + 1] = tmp_seq
            label_batch.append(tmp_label)

            # show-control-tell labels: fixed 5 captions per image
            tmp_show_label = np.zeros([5, self.show_seq_length + 2], dtype='int')
            tmp_show_label[:, 1: self.show_seq_length + 1] = tmp_show_seq
            show_label_batch.append(tmp_show_label)

            # gate labels, truncated to 5 captions x show_seq_length tokens
            tmp_show_gate_label = np.zeros([5, self.show_seq_length + 2], dtype='int')
            tmp_show_gate_label[:, 1: self.show_seq_length + 1] = \
                tmp_show_gate_label_orig[:5, :self.show_seq_length]
            show_gate_label_batch.append(tmp_show_gate_label)

            # Used for reward evaluation (SCST-style): all raw ground-truth
            # captions of this image, or [] when no label file is attached.
            if hasattr(self, 'h5_label_file'):
                gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            else:
                gts.append([])

            # record associated info as well
            infos.append({
                'ix': ix,
                'id': self.info['images'][ix]['id'],
                'file_path': self.info['images'][ix].get('file_path', ''),
            })

        data = {}
        data['fc_feats'] = np.stack(fc_batch)

        # merge att_feats / box_feats, zero-padded to the longest att length
        max_att_len = max(_.shape[0] for _ in att_batch)
        assert att_batch[0].shape[0] == box_batch[0].shape[0], 'box should have same shape[0] with att'
        data['att_feats'] = np.zeros([len(att_batch), max_att_len, att_batch[0].shape[1]], dtype='float32')
        data['box_feats'] = np.zeros([len(box_batch), max_att_len, box_batch[0].shape[1]], dtype='float32')
        data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
        for i, (att, box) in enumerate(zip(att_batch, box_batch)):
            data['att_feats'][i, :att.shape[0]] = att
            data['box_feats'][i, :box.shape[0]] = box
            data['att_masks'][i, :att.shape[0]] = 1
        # att_masks deliberately kept even when all lengths are equal

        # merge trace_feats, zero-padded, with matching masks
        max_trace_len = max(_.shape[0] for _ in trace_batch)
        data['trace_feats'] = np.zeros([len(trace_batch), max_trace_len, trace_batch[0].shape[1]], dtype='float32')
        data['trace_masks'] = np.zeros(data['trace_feats'].shape[:2], dtype='float32')
        for i, tr in enumerate(trace_batch):
            data['trace_feats'][i, :tr.shape[0]] = tr
            data['trace_masks'][i, :tr.shape[0]] = 1

        # merge show-control-tell trace feats (per caption, 5 per image)
        max_show_trace_len = max(_.shape[0] for _ in show_trace_feat_batch)
        data['show_trace_feats'] = np.zeros(
            [len(show_trace_feat_batch), max_show_trace_len, show_trace_feat_batch[0].shape[1]],
            dtype='float32')
        data['show_trace_masks'] = np.zeros(data['show_trace_feats'].shape[:2], dtype='float32')
        for i, tr in enumerate(show_trace_feat_batch):
            data['show_trace_feats'][i, :tr.shape[0]] = tr
            data['show_trace_masks'][i, :tr.shape[0]] = 1
        # a negative first coordinate appears to mark an invalid/padded trace
        # step (TODO confirm with the dataset): mask those positions out,
        # then clip the feats into [0, 1].  (Vectorized; replaces the old
        # per-element Python double loop.)
        data['show_trace_masks'][data['show_trace_feats'][:, :, 0] < 0] = 0
        data['show_trace_feats'] = np.clip(data['show_trace_feats'], 0., 1.)

        data['labels'] = np.vstack(label_batch)
        data['show_labels'] = np.expand_dims(np.vstack(show_label_batch), 1)
        data['show_gate_labels'] = np.expand_dims(np.vstack(show_gate_label_batch), 1)

        # generate mask: each row covers its nonzero tokens plus the BOS and
        # EOS slots (hence + 2).  Broadcasting replaces the per-row loop.
        nonzeros = (data['labels'] != 0).sum(axis=1) + 2
        data['masks'] = (np.arange(self.seq_length + 2) < nonzeros[:, None]).astype('float32')
        data['labels'] = data['labels'].reshape(len(batch), seq_per_img, -1)
        data['masks'] = data['masks'].reshape(len(batch), seq_per_img, -1)

        # generate mask for show-control-tell (labels have an extra middle axis)
        show_nonzeros = (data['show_labels'] != 0).sum(axis=(1, 2)) + 2
        show_mask = (np.arange(self.show_seq_length + 2) < show_nonzeros[:, None]).astype('float32')
        data['show_masks'] = np.expand_dims(show_mask, 1)

        data['gts'] = gts  # all ground truth captions of each image
        data['bounds'] = {'it_pos_now': it_pos_now,  # the it_pos_now of the last sample
                          'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
        data['infos'] = infos

        # Turn all ndarray values into torch tensors; leave the rest as-is
        data = {k: torch.from_numpy(v) if type(v) is np.ndarray else v for k, v in data.items()}

        return data