def _gc_dataloader()

in baseline_model/data_utils/ggnn_utils.py [0:0]
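Builds the train and validation DataLoaders for the graph-classification setup. Each padded sample is trimmed back to its real length, turned into a per-sample dict, and collated into a single batched dgl.DGLGraph with typed (optionally mirrored) edges and per-node annotations. Returns the two loaders together with the padded source length and the largest token id seen in the features.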


import dgl
import numpy as np
import torch
from torch.utils.data import DataLoader


def _gc_dataloader(train_size, valid_size, in_f, in_g, in_a, batch_size, reverse_edge, hid_dim):
    # Note: hid_dim is currently unused; it is kept for call-site compatibility.
    def _collate_fn(batch):
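        # Each sample dict carries:
        #   'f': trimmed token-id list (one id per node),
        #   'g': trimmed edge list of (src, edge_type, dst) triples,
        #   'a': per-node annotation matrix; its row count sets the node count.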
        graphs = []
        for d in batch:
            edges = d['g']
            feat_dict = d['f']
            g = dgl.DGLGraph()
            # One node per row of the annotation matrix 'a'.
            num_nodes = d['a'].shape[0]
            g.add_nodes(num_nodes)
            g.ndata['node_id'] = torch.arange(num_nodes, dtype=torch.long)

            # Add typed edges; edge types listed in reverse_edge also get a
            # mirrored edge so messages can flow against the edge direction.
            edge_types = []
            for s, e, t in edges:
                g.add_edge(s, t)
                edge_types.append(e)
                if e in reverse_edge:
                    g.add_edge(t, s)
                    edge_types.append(reverse_edge[e])
            g.edata['type'] = torch.tensor(edge_types, dtype=torch.long)

            # Node features: one token id per node (length must match num_nodes).
            g.ndata['annotation'] = torch.tensor(feat_dict, dtype=torch.long)
            graphs.append(g)
        batch_graph = dgl.batch(graphs)
        return batch_graph

    def _get_dataloader(in_f, in_g, in_a, train_size, valid_size, shuffle):
        in_f_list = in_f.tolist()
        in_g_list = np.int_(in_g).tolist()
        src_token_len = 0

        def _trim(i):
            # Strip the zero padding: feature rows are padded with 0 and
            # edge lists with [0, 0, 0].
            try:
                in_f_tmp = in_f_list[i][:in_f_list[i].index(0)]
            except ValueError:
                in_f_tmp = in_f_list[i]
            try:
                in_g_tmp = in_g_list[i][:in_g_list[i].index([0, 0, 0])]
            except ValueError:
                in_g_tmp = in_g_list[i]
            return {'f': in_f_tmp, 'g': in_g_tmp, 'a': in_a[i]}

        train_dict = []
        valid_dict = []
        for i in range(train_size):
            src_token_len = max(src_token_len, max(in_f_list[i]))
            train_dict.append(_trim(i))
        for i in range(train_size, train_size + valid_size):
            src_token_len = max(src_token_len, max(in_f_list[i]))
            valid_dict.append(_trim(i))

        train_dataloader = DataLoader(dataset=train_dict, batch_size=batch_size,
                                      shuffle=shuffle, collate_fn=_collate_fn)
        dev_dataloader = DataLoader(dataset=valid_dict, batch_size=batch_size,
                                    shuffle=shuffle, collate_fn=_collate_fn)
        return train_dataloader, dev_dataloader, src_token_len

    # shuffle=False keeps the train/valid sample order deterministic.
    train_dataloader, dev_dataloader, src_token_len = _get_dataloader(in_f, in_g, in_a, train_size, valid_size, False)
    max_len_src = in_f.shape[1]  # padded source sequence length

    return train_dataloader, dev_dataloader, max_len_src, src_token_len
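
For reference, a minimal sketch of how this loader might be driven. All input shapes, the edge-type mapping, and the sample counts below are assumptions for illustration, not part of the original code; it also assumes a DGL version where the legacy incremental DGLGraph API (DGLGraph(), add_nodes, add_edge) is still available, e.g. DGL < 0.5.

import numpy as np

# Hypothetical toy inputs: 4 train + 2 valid samples, padded to 6 tokens,
# 5 edge slots, and 6 nodes per sample (all shapes are assumptions).
n, max_tok, max_edge, n_nodes = 6, 6, 5, 6
in_f = np.random.randint(1, 10, size=(n, max_tok))   # token ids; 0 marks padding
in_g = np.zeros((n, max_edge, 3), dtype=np.int64)    # (src, edge_type, dst); [0,0,0] marks padding
in_g[:, 0] = [0, 1, 1]                               # one real edge per sample
in_a = np.zeros((n, n_nodes, 8))                     # per-node annotation matrix

train_loader, dev_loader, max_len_src, src_token_len = _gc_dataloader(
    train_size=4, valid_size=2, in_f=in_f, in_g=in_g, in_a=in_a,
    batch_size=2, reverse_edge={1: 2}, hid_dim=8)

for bg in train_loader:
    # With batch_size=2: 12 nodes and 4 edges (each edge plus its mirror).
    print(bg.number_of_nodes(), bg.number_of_edges())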