def forward()

in backup/models.py [0:0]


    def forward(self, dials, dial_length, response):
        """Score a (dialogue, response) pair; returns a sigmoid score per pair."""
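        # Encode every utterance in each dialogue with Sentence-BERT and
        # stack the per-utterance vectors into a single tensor.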
        inp_vec = torch.stack(
            [self.extract_sentence_bert(batch) for batch in dials], dim=0
        )
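        # Scale the encodings elementwise by the learned `down` weights,
        # then pad them into a fixed-size batch.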
        inp_vec = torch.mul(inp_vec, self.down)
        inp_vec, _ = batchify(inp_vec, vector_mode=True)
        dials = inp_vec.to(self.args.device)
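        # Encode the candidate response and reduce it with the fitted PCA.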
        y = self.data.pca_predict([self.extract_sentence_bert(response)])[0]
        response = torch.stack(y, dim=0)
        response = response.to(self.args.device)
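        # Sort dialogues by length in descending order, as pack_padded_sequence
        # expects, and remember the permutation so the batch can be unsorted later.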
        dial_length = np.array(dial_length)
        inp_len_sorted, idx_sort = np.sort(dial_length)[::-1], np.argsort(-dial_length)
        inp_len_sorted = inp_len_sorted.copy()  # [::-1] yields a negative-stride view, which torch.from_numpy rejects
        idx_unsort = np.argsort(idx_sort)

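        # Reorder the batch to match the sorted lengths and keep only the
        # dialogues with nonzero length for packing.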
        idx_sort = torch.from_numpy(idx_sort).to(dials.device)
        dials = dials.index_select(0, idx_sort)
        inp_len_sorted_nonzero_idx = np.nonzero(inp_len_sorted)[0]
        inp_len_sorted_nonzero_idx = torch.from_numpy(inp_len_sorted_nonzero_idx).to(
            dials.device
        )
        inp_len_sorted = torch.from_numpy(inp_len_sorted).to(dials.device)
        non_zero_data = dials.index_select(0, inp_len_sorted_nonzero_idx)
        data_pack = pack_padded_sequence(
            non_zero_data,
            # lengths must be a 1D CPU tensor in recent PyTorch versions
            inp_len_sorted[inp_len_sorted_nonzero_idx].cpu(),
            batch_first=True,
        )
        outp, hidden_rep = self.lstm(data_pack)
        # total_length keeps the padded time dimension aligned with dials,
        # so the scatter below cannot hit a shape mismatch.
        outp, _ = pad_packed_sequence(
            outp, batch_first=True, total_length=dials.size(1)
        )
        outp = outp.contiguous()
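        # Scatter the LSTM outputs back into a full-size tensor; rows for
        # zero-length dialogues remain zero.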
        outp_l = torch.zeros((dials.size(0), dials.size(1), outp.size(2))).to(
            outp.device
        )
        outp_l[inp_len_sorted_nonzero_idx] = outp
        # unsort: restore the original batch order
        idx_unsort = torch.from_numpy(idx_unsort).to(outp_l.device)
        outp_l = outp_l.index_select(0, idx_unsort)

        # max-pool the LSTM outputs over the time dimension to get a
        # fixed-size dialogue representation
        hidden_rep = torch.max(outp_l, 1)[0]
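        # Apply the learned linear map W (equivalent to hidden_rep @ self.W).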
        hidden_rep = torch.bmm(
            hidden_rep.unsqueeze(1),
            self.W.unsqueeze(0).repeat(hidden_rep.size(0), 1, 1),
        )
        hidden_rep = hidden_rep.squeeze(1)
        response = response.squeeze(1)
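        # InferSent-style matching features: the two representations, their
        # absolute difference, and their elementwise product.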
        rep = torch.cat(
            [
                hidden_rep,
                response,
                torch.abs(hidden_rep - response),
                hidden_rep * response,
            ],
            dim=1,
        )
        return self.sigmoid(self.decoder(rep))
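
The manual sort → pack → unsort dance above predates `enforce_sorted=False`;
newer PyTorch can delegate the sorting and unsorting to `pack_padded_sequence`
itself. Below is a minimal, self-contained sketch of the same pack/pool pattern
on dummy data (`toy_lstm`, `toy_batch`, and `toy_lengths` are illustrative
names, not part of backup/models.py):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    toy_lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
    toy_batch = torch.randn(4, 5, 8)          # (batch, max_turns, dim), zero-padded
    toy_lengths = torch.tensor([5, 3, 4, 2])  # true number of turns per dialogue

    # enforce_sorted=False lets PyTorch sort and unsort internally.
    packed = pack_padded_sequence(
        toy_batch, toy_lengths, batch_first=True, enforce_sorted=False
    )
    outp, _ = toy_lstm(packed)
    # total_length pads back to the original time dimension.
    outp, _ = pad_packed_sequence(
        outp, batch_first=True, total_length=toy_batch.size(1)
    )
    hidden_rep = torch.max(outp, 1)[0]        # max-pool over time, as in forward()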