# in backup/models.py [0:0]
def forward(self, dials, dial_length, response):
    """Score a (dialogue, response) pair with a sigmoid in [0, 1].

    Encodes each dialogue turn with sentence-BERT, runs the padded
    sequence through an LSTM (handling variable lengths via
    pack/pad), max-pools over time, applies a learned bilinear map
    ``self.W``, and concatenates InferSent-style features
    ``[h, r, |h - r|, h * r]`` before the final decoder.

    Args:
        dials: iterable of dialogue batches; each is fed to
            ``self.extract_sentence_bert``. Assumed to yield equal-size
            encodings stackable along dim 0 — TODO confirm with caller.
        dial_length: sequence of true (unpadded) lengths, one per dialogue.
        response: candidate response, encoded via sentence-BERT and then
            projected with ``self.data.pca_predict``.

    Returns:
        Tensor of per-pair match scores after ``self.sigmoid``.
    """
    # NOTE(review): removed a leftover `import ipdb; ipdb.set_trace()`
    # debugger breakpoint that halted every forward pass.
    inp_vec = torch.stack(
        [self.extract_sentence_bert(batch) for batch in dials], dim=0
    )
    # Element-wise down-scaling; presumably a learned/static projection mask.
    inp_vec = torch.mul(inp_vec, self.down)
    inp_vec, _ = batchify(inp_vec, vector_mode=True)
    dials = inp_vec.to(self.args.device)

    # Encode and PCA-project the candidate response.
    y = self.data.pca_predict([self.extract_sentence_bert(response)])[0]
    response = torch.stack(y, dim=0)
    response = response.to(self.args.device)

    # Sort by descending length as required by pack_padded_sequence;
    # keep the permutation so we can unsort the outputs afterwards.
    dial_length = np.array(dial_length)
    inp_len_sorted, idx_sort = np.sort(dial_length)[::-1], np.argsort(-dial_length)
    inp_len_sorted = inp_len_sorted.copy()  # drop negative stride for torch
    idx_unsort = np.argsort(idx_sort)
    idx_sort = torch.from_numpy(idx_sort).to(dials.device)
    dials = dials.index_select(0, idx_sort)

    # Zero-length dialogues cannot be packed; run the LSTM only on the rest.
    inp_len_sorted_nonzero_idx = np.nonzero(inp_len_sorted)[0]
    inp_len_sorted_nonzero_idx = torch.from_numpy(inp_len_sorted_nonzero_idx).to(
        dials.device
    )
    inp_len_sorted = torch.from_numpy(inp_len_sorted).to(dials.device)
    non_zero_data = dials.index_select(0, inp_len_sorted_nonzero_idx)
    # Lengths must be a 1-D CPU int64 tensor (PyTorch >= 1.7 raises on
    # CUDA lengths), hence the explicit .cpu().
    data_pack = pack_padded_sequence(
        non_zero_data,
        inp_len_sorted[inp_len_sorted_nonzero_idx].cpu(),
        batch_first=True,
    )
    outp, hidden_rep = self.lstm(data_pack)
    outp, _ = pad_packed_sequence(outp, batch_first=True)
    outp = outp.contiguous()

    # Scatter LSTM outputs back into a full-batch tensor; rows for
    # zero-length dialogues stay all-zero.
    outp_l = torch.zeros((dials.size(0), dials.size(1), outp.size(2))).to(
        outp.device
    )
    outp_l[inp_len_sorted_nonzero_idx] = outp
    # Undo the length sort so rows align with the original batch order.
    idx_unsort = torch.from_numpy(idx_unsort).to(outp_l.device)
    outp_l = outp_l.index_select(0, idx_unsort)

    # Max-pool over the time dimension, then apply the bilinear map W.
    hidden_rep = torch.max(outp_l, 1)[0]
    hidden_rep = torch.bmm(
        hidden_rep.unsqueeze(1),
        self.W.unsqueeze(0).repeat(hidden_rep.size(0), 1, 1),
    )
    hidden_rep = hidden_rep.squeeze(1)
    response = response.squeeze(1)

    # InferSent-style feature combination of dialogue and response vectors.
    rep = torch.cat(
        [
            hidden_rep,
            response,
            torch.abs(hidden_rep - response),
            hidden_rep * response,
        ],
        dim=1,
    )
    return self.sigmoid(self.decoder(rep))