in python/dgllife/model/model_zoo/jtvae.py [0:0]
def forward(self, cand_batch, tree_mess, device='cpu'):
    """Encode a batch of candidate graphs with tree-conditioned message passing.

    Directed bond messages are propagated over every candidate graph. A bond
    whose two atoms are linked to two different tree nodes additionally
    receives the matching message from ``tree_mess``. Afterwards, the final
    atom hidden states are averaged per candidate.

    Parameters
    ----------
    cand_batch : iterable of 3-tuples ``(mol, all_nodes, _)``
        ``mol`` exposes ``GetNumAtoms`` / ``GetAtoms`` / ``GetBonds``
        (presumably an RDKit molecule -- confirm against caller). An atom map
        number ``k > 0`` links the atom to tree node ``all_nodes[k - 1]['idx']``.
        A map number of 0 means the atom is not linked to any tree node. The
        third tuple element is ignored.
    tree_mess : mapping
        Maps a pair of tree-node ids ``(src_idx, dst_idx)`` to a message
        vector. Each vector is stacked with a zero vector of length
        ``self.hidden_size``, so it must have that shape. It is presumably
        already on ``device`` -- TODO confirm.
    device : str or torch.device
        Device on which the feature tensors, index tables and output live.

    Returns
    -------
    torch.Tensor
        One row per candidate in ``cand_batch``. Each row is the mean over that
        candidate's atoms of ``relu(self.W_o([atom_features, summed_messages]))``.
    """
    fatoms, fbonds = [], []
    in_bonds, all_bonds = [], []
    # Message table layout: [zero vector, tree messages..., bond messages...].
    # Index 0 is the zero vector, so it also serves as padding in the
    # agraph/bgraph index tables below.
    mess_dict, all_mess = {}, [torch.zeros(self.hidden_size).to(device)]
    total_atoms = 0
    scope = []
    for e, vec in tree_mess.items():
        mess_dict[e] = len(all_mess)
        all_mess.append(vec)

    for mol, all_nodes, _ in cand_batch:
        n_atoms = mol.GetNumAtoms()
        for atom in mol.GetAtoms():
            fatoms.append(torch.Tensor(self.atom_featurizer(atom)))
            in_bonds.append([])

        for bond in mol.GetBonds():
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            # Atom indices are global across the whole batch.
            x = a1.GetIdx() + total_atoms
            y = a2.GetIdx() + total_atoms
            # Map numbers may be 0, meaning "not linked to a tree node".
            x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
            x_bid = all_nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
            y_bid = all_nodes[y_nid - 1]['idx'] if y_nid > 0 else -1

            bfeature = torch.Tensor(self.bond_featurizer(bond))
            # Each bond yields two directed messages. Their ids are offset by
            # len(all_mess) so they index past the tree messages in the table.
            b = len(all_mess) + len(all_bonds)
            all_bonds.append((x, y))
            fbonds.append(torch.cat([fatoms[x], bfeature], 0))
            in_bonds[y].append(b)

            b = len(all_mess) + len(all_bonds)
            all_bonds.append((y, x))
            fbonds.append(torch.cat([fatoms[y], bfeature], 0))
            in_bonds[x].append(b)

            # A bond spanning two different tree nodes also receives the tree
            # message flowing between those nodes, if one was provided.
            if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
                if (x_bid, y_bid) in mess_dict:
                    in_bonds[y].append(mess_dict[(x_bid, y_bid)])
                if (y_bid, x_bid) in mess_dict:
                    in_bonds[x].append(mess_dict[(y_bid, x_bid)])

        scope.append((total_atoms, n_atoms))
        total_atoms += n_atoms

    total_mess = len(all_mess)
    fatoms = torch.stack(fatoms, 0).to(device)
    fbonds = torch.stack(fbonds, 0).to(device)
    tree_message = torch.stack(all_mess, dim=0)

    def _pad(row):
        # Keep at most MAX_NB entries and right-pad with 0 (the zero message).
        row = row[:MAX_NB]
        return row + [0] * (MAX_NB - len(row))

    # Build both index tables on the host and transfer each one in a single
    # copy. Filling a device-resident tensor element by element would issue
    # one device operation per entry.
    # agraph[a]: ids of all messages (bond or tree) entering atom a.
    agraph = torch.tensor(
        [_pad(nbrs) for nbrs in in_bonds], dtype=torch.long
    ).reshape(-1, MAX_NB).to(device)
    # bgraph[b]: ids of messages feeding directed bond b = (src, dst). These
    # are the messages entering src, minus the reverse bond dst -> src.
    # Tree messages are always kept. An excluded slot stays 0, as before.
    bgraph = torch.tensor(
        [
            _pad([
                b2 if b2 < total_mess or all_bonds[b2 - total_mess][0] != dst else 0
                for b2 in in_bonds[src][:MAX_NB]
            ])
            for src, dst in all_bonds
        ],
        dtype=torch.long,
    ).reshape(-1, MAX_NB).to(device)

    binput = self.W_i(fbonds)
    graph_message = F.relu(binput)
    for _ in range(self.depth - 1):
        message = torch.cat([tree_message, graph_message], dim=0)
        # Gather the neighbour messages listed in bgraph and sum over the
        # MAX_NB slots. This assumes index_select_ND returns
        # (rows, MAX_NB, hidden) -- confirm.
        nei_message = index_select_ND(message, 0, bgraph)
        nei_message = nei_message.sum(dim=1)
        nei_message = self.W_h(nei_message)
        graph_message = F.relu(binput + nei_message)

    message = torch.cat([tree_message, graph_message], dim=0)
    nei_message = index_select_ND(message, 0, agraph)
    nei_message = nei_message.sum(dim=1)
    ainput = torch.cat([fatoms, nei_message], dim=1)
    atom_hiddens = F.relu(self.W_o(ainput))

    # Average the atom hidden states of each candidate over its own atom span.
    mol_vecs = [atom_hiddens.narrow(0, st, le).sum(dim=0) / le for st, le in scope]
    return torch.stack(mol_vecs, dim=0)