in baseline_model/data_utils/ggnn_utils.py [0:0]
import dgl
import numpy as np
import torch
from torch.utils.data import DataLoader


def _gc_dataloader(train_size, valid_size, in_f, in_g, in_a, batch_size, reverse_edge, hid_dim):
    """Build train/dev DataLoaders that collate padded arrays into batched DGL graphs.

    `in_f` holds zero-padded token-id rows, `in_g` zero-padded (src, edge_type, dst)
    triples, and `in_a` the per-example node annotations. `reverse_edge` maps an
    edge type to the type assigned to its reverse edge. `hid_dim` is unused here
    and kept only for signature compatibility.
    """
    def _collate_fn(batch):
        graphs = []
        for d in batch:
            edges = d['g']      # truncated list of (src, edge_type, dst) triples
            feat_dict = d['f']  # truncated token-id sequence for this example
            g = dgl.DGLGraph()
            # One node per row of the annotation matrix.
            g.add_nodes(d['a'].shape[0])
            g.ndata['node_id'] = torch.arange(d['a'].shape[0], dtype=torch.long)
            edge_types = []
            for s, e, t in edges:
                g.add_edge(s, t)
                edge_types.append(e)
                # Mirror the edge with its reverse type, if this type has one.
                if e in reverse_edge:
                    g.add_edge(t, s)
                    edge_types.append(reverse_edge[e])
            g.edata['type'] = torch.tensor(edge_types, dtype=torch.long)
            # Assumes one token id per node, i.e. len(feat_dict) == g.number_of_nodes().
            g.ndata['annotation'] = torch.tensor(feat_dict, dtype=torch.long)
            graphs.append(g)
        return dgl.batch(graphs)
    def _get_dataloader(in_f, in_g, in_a, train_size, valid_size, shuffle):
        in_f_list = in_f.tolist()
        in_g_list = np.int_(in_g).tolist()

        def _make_example(i):
            # Strip the zero padding: cut the token row at its first 0 (so token
            # id 0 must be reserved for padding) and the edge list at its first
            # (0, 0, 0) triple. list.index raises ValueError when there is no
            # padding, in which case the full row is kept.
            try:
                f = in_f_list[i][: in_f_list[i].index(0)]
            except ValueError:
                f = in_f_list[i]
            try:
                g = in_g_list[i][: in_g_list[i].index([0, 0, 0])]
            except ValueError:
                g = in_g_list[i]
            return {'f': f, 'g': g, 'a': in_a[i]}

        # Largest token id across the train + valid rows; padding zeros cannot
        # raise the maximum, so this is safe to compute on the padded rows.
        src_token_len = max((max(in_f_list[i]) for i in range(train_size + valid_size)), default=0)
        train_dict = [_make_example(i) for i in range(train_size)]
        valid_dict = [_make_example(i) for i in range(train_size, train_size + valid_size)]
        train_dataloader = DataLoader(dataset=train_dict, batch_size=batch_size,
                                      shuffle=shuffle, collate_fn=_collate_fn)
        dev_dataloader = DataLoader(dataset=valid_dict, batch_size=batch_size,
                                    shuffle=shuffle, collate_fn=_collate_fn)
        return train_dataloader, dev_dataloader, src_token_len
    train_dataloader, dev_dataloader, src_token_len = _get_dataloader(
        in_f, in_g, in_a, train_size, valid_size, False)
    max_len_src = in_f.shape[1]  # padded row length of the token matrix
    return train_dataloader, dev_dataloader, max_len_src, src_token_len
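

# A minimal usage sketch, not part of the original module: the toy shapes, the
# reverse_edge mapping {1: 2}, and the token values below are all assumptions,
# chosen only to exercise _gc_dataloader end to end against the classic
# dgl.DGLGraph API used above.
if __name__ == '__main__':
    n_examples = 4
    # Three real tokens per row, then zero padding; token id 0 is padding only.
    in_f = np.array([[5, 7, 9, 0, 0, 0]] * n_examples, dtype=np.int64)
    # One real (src, edge_type, dst) edge per example, then (0, 0, 0) padding.
    in_g = np.array([[[0, 1, 1], [0, 0, 0], [0, 0, 0]]] * n_examples, dtype=np.int64)
    # Three nodes per example, matching the three unpadded tokens above.
    in_a = np.zeros((n_examples, 3), dtype=np.int64)

    train_dl, dev_dl, max_len_src, src_token_len = _gc_dataloader(
        train_size=3, valid_size=1, in_f=in_f, in_g=in_g, in_a=in_a,
        batch_size=2, reverse_edge={1: 2}, hid_dim=16)
    for bg in train_dl:
        # Each batch is one dgl.batch-ed graph; the edge count doubles for
        # edge types listed in reverse_edge.
        print(bg.batch_size, bg.number_of_nodes(), bg.number_of_edges())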