backup/ablation_models.py [58:83]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # Run the LSTM over a padded batch whose rows have different (possibly
        # zero) lengths: sort rows longest-first, pack the non-empty ones, run
        # the LSTM, unpack, and restore the caller's row order.
        # Layout is batch on dim 0 and time on dim 1 (batch_first=True below).
        dial_length = np.array(dial_length)
        idx_sort = np.argsort(-dial_length)
        # Take the lengths through the same permutation that reorders the
        # rows, so each length is guaranteed to belong to its row.
        inp_len_sorted = dial_length[idx_sort]
        idx_unsort = np.argsort(idx_sort)

        idx_sort = torch.from_numpy(idx_sort).to(dials.device)
        dials = dials.index_select(0, idx_sort)

        # Only rows with non-zero length are packed and fed to the LSTM; the
        # remaining rows stay all-zero in outp_l.
        nonzero_rows = np.nonzero(inp_len_sorted)[0]
        # pack_padded_sequence requires a tensor `lengths` to live on the CPU,
        # so the packing lengths are built from numpy and never moved to
        # dials.device (doing so raises when dials is on CUDA).
        pack_lengths = torch.from_numpy(inp_len_sorted[nonzero_rows])
        inp_len_sorted_nonzero_idx = torch.from_numpy(nonzero_rows).to(dials.device)
        inp_len_sorted = torch.from_numpy(inp_len_sorted).to(dials.device)
        non_zero_data = dials.index_select(0, inp_len_sorted_nonzero_idx)
        data_pack = pack_padded_sequence(non_zero_data, pack_lengths, batch_first=True)
        # NOTE(review): hidden_rep covers only the non-empty rows, in sorted
        # order; nothing in this span unsorts it - confirm downstream use.
        outp, hidden_rep = self.lstm(data_pack)
        # Pad back to the full time dimension of dials. Without total_length
        # the result is only as long as the longest sequence in the batch and
        # the assignment into outp_l below fails on a shape mismatch.
        outp, _ = pad_packed_sequence(
            outp, batch_first=True, total_length=dials.size(1)
        )
        outp = outp.contiguous()
        outp_l = torch.zeros(
            (dials.size(0), dials.size(1), outp.size(2)), device=outp.device
        )
        outp_l[inp_len_sorted_nonzero_idx] = outp
        # unsort
        idx_unsort = torch.from_numpy(idx_unsort).to(outp_l.device)
        outp_l = outp_l.index_select(0, idx_unsort)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



backup/models.py [183:208]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # Run the LSTM over a padded batch whose rows have different (possibly
        # zero) lengths: sort rows longest-first, pack the non-empty ones, run
        # the LSTM, unpack, and restore the caller's row order.
        # Layout is batch on dim 0 and time on dim 1 (batch_first=True below).
        dial_length = np.array(dial_length)
        idx_sort = np.argsort(-dial_length)
        # Take the lengths through the same permutation that reorders the
        # rows, so each length is guaranteed to belong to its row.
        inp_len_sorted = dial_length[idx_sort]
        idx_unsort = np.argsort(idx_sort)

        idx_sort = torch.from_numpy(idx_sort).to(dials.device)
        dials = dials.index_select(0, idx_sort)

        # Only rows with non-zero length are packed and fed to the LSTM; the
        # remaining rows stay all-zero in outp_l.
        nonzero_rows = np.nonzero(inp_len_sorted)[0]
        # pack_padded_sequence requires a tensor `lengths` to live on the CPU,
        # so the packing lengths are built from numpy and never moved to
        # dials.device (doing so raises when dials is on CUDA).
        pack_lengths = torch.from_numpy(inp_len_sorted[nonzero_rows])
        inp_len_sorted_nonzero_idx = torch.from_numpy(nonzero_rows).to(dials.device)
        inp_len_sorted = torch.from_numpy(inp_len_sorted).to(dials.device)
        non_zero_data = dials.index_select(0, inp_len_sorted_nonzero_idx)
        data_pack = pack_padded_sequence(non_zero_data, pack_lengths, batch_first=True)
        # NOTE(review): hidden_rep covers only the non-empty rows, in sorted
        # order; nothing in this span unsorts it - confirm downstream use.
        outp, hidden_rep = self.lstm(data_pack)
        # Pad back to the full time dimension of dials. Without total_length
        # the result is only as long as the longest sequence in the batch and
        # the assignment into outp_l below fails on a shape mismatch.
        outp, _ = pad_packed_sequence(
            outp, batch_first=True, total_length=dials.size(1)
        )
        outp = outp.contiguous()
        outp_l = torch.zeros(
            (dials.size(0), dials.size(1), outp.size(2)), device=outp.device
        )
        outp_l[inp_len_sorted_nonzero_idx] = outp
        # unsort
        idx_unsort = torch.from_numpy(idx_unsort).to(outp_l.device)
        outp_l = outp_l.index_select(0, idx_unsort)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



