in python/dgllife/model/model_zoo/jtvae.py [0:0]
def decode(self, mol_vec, prob_decode):
    """Decode a junction tree from a latent tree vector.

    A root clique is predicted first, then the tree is grown depth-first
    using an explicit stack. At every step the node on top of the stack
    either gets a new child clique ("forward"), or stops expanding and
    sends a message to its parent before being popped ("backtrack").
    Decoding ends when the root backtracks or after ``MAX_DECODE_LEN``
    steps.

    Parameters
    ----------
    mol_vec : torch.Tensor
        Latent tree vector. It is concatenated along dim 1 with tensors of
        shape ``(1, hidden_size)``, so presumably of shape
        ``(1, latent_size)`` -- TODO confirm against callers.
    prob_decode : bool
        If True, sample the stop decision (Bernoulli) and the candidate
        cliques (multinomial); otherwise decode greedily.

    Returns
    -------
    root
        The root node built by ``mol_tree_node`` (index 0).
    all_nodes : list
        Every node created during decoding, in creation order, starting
        with ``root``.
    """
    device = mol_vec.device
    stack = []
    init_hidden = torch.zeros(1, self.hidden_size, device=device)
    zero_pad = torch.zeros(1, 1, self.hidden_size, device=device)

    def stack_messages(messages):
        # Pack incoming messages as (1, num_messages, hidden_size). With no
        # messages, use a single all-zero message so the GRU still has input.
        if messages:
            return torch.stack(messages, dim=0).view(1, -1, self.hidden_size)
        return zero_pad

    # Root prediction
    root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
    root_hidden = F.relu(self.W(root_hidden))
    root_score = self.W_o(root_hidden)
    _, root_wid = torch.max(root_score, dim=1)
    root_wid = root_wid.item()
    root = mol_tree_node(smiles=self.vocab.get_smiles(root_wid), wid=root_wid, idx=0)
    stack.append((root, self.vocab.get_slots(root['wid'])))
    all_nodes = [root]
    # h[(i, j)] holds the message sent from the node with idx i to the node with idx j.
    h = {}

    for step in range(MAX_DECODE_LEN):
        node_x, fa_slot = stack[-1]
        cur_h_nei = stack_messages(
            [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']])
        cur_x = self.embedding(torch.tensor([node_x['wid']], dtype=torch.long, device=device))

        # Predict whether to stop expanding node_x
        cur_h = cur_h_nei.sum(dim=1)
        stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
        stop_hidden = F.relu(self.U(stop_hidden))
        stop_score = torch.sigmoid(self.U_s(stop_hidden) * 20).squeeze()
        if prob_decode:
            # stop_score is 0-dim after squeeze(), so the sample must be read
            # with .item(); indexing a 0-dim tensor with [0] raises IndexError.
            backtrack = torch.bernoulli(1.0 - stop_score.detach()).item() == 1
        else:
            backtrack = stop_score.item() < 0.5

        if not backtrack:  # Forward: predict the next clique
            new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
                                   self.gru_message.U_r, self.gru_update.W_h)
            pred_hidden = torch.cat([new_h, mol_vec], dim=1)
            pred_hidden = F.relu(self.W(pred_hidden))
            pred_score = torch.softmax(self.W_o(pred_hidden) * 20, dim=1)
            # view(-1) rather than squeeze() so a single-entry vocab stays 1-D.
            probs = pred_score.detach().view(-1)
            if prob_decode:
                # multinomial without replacement cannot draw more samples than
                # there are non-zero categories, and softmax(logits * 20) often
                # underflows to exact zeros.
                num_samples = min(5, int((probs > 0).sum().item()))
                sort_wid = torch.multinomial(probs, num_samples)
            else:
                _, sort_wid = torch.sort(probs, descending=True)
            # Plain ints, matching root_wid, so every node's 'wid' has one type.
            candidates = sort_wid[:5].tolist()

            next_wid = None
            next_slots = None
            for wid in candidates:
                slots = self.vocab.get_slots(wid)
                candidate = mol_tree_node(smiles=self.vocab.get_smiles(wid))
                if have_slots(fa_slot, slots) and can_assemble(node_x, candidate):
                    next_wid = wid
                    next_slots = slots
                    break

            if next_wid is None:
                backtrack = True  # No candidate can be attached to node_x
            else:
                node_y = mol_tree_node(smiles=self.vocab.get_smiles(next_wid),
                                       wid=next_wid, idx=step + 1, nbrs=[node_x])
                h[(node_x['idx'], node_y['idx'])] = new_h[0]
                stack.append((node_y, next_slots))
                all_nodes.append(node_y)

        # Deliberately `if` rather than `else`: a failed forward step above
        # sets backtrack = True and must backtrack in this same step.
        if backtrack:
            if len(stack) == 1:
                break  # The root backtracked: decoding is finished
            node_fa, _ = stack[-2]
            # The message from node_x to its parent aggregates the messages
            # node_x received from every neighbor except that parent.
            cur_h_nei = stack_messages(
                [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']
                 if node_y['idx'] != node_fa['idx']])
            new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
                                   self.gru_message.U_r, self.gru_update.W_h)
            h[(node_x['idx'], node_fa['idx'])] = new_h[0]
            node_fa['neighbors'].append(node_x)
            stack.pop()

    return root, all_nodes