def eval_clf_dis_accuracy()

in src/evaluation.py [0:0]


    def eval_clf_dis_accuracy(self):
        """
        Compute the classifier discriminator prediction accuracy.

        For every evaluation batch, the images are encoded once by the
        autoencoder. They are then decoded once per (attribute, category)
        pair, with attribute `j` flipped to category `value` via
        `flip_attributes(..., new_value=value)`. The classifier discriminator
        is run on the last decoder output. A sample counts as a hit when the
        arg-max over that attribute's output columns equals `value`.

        Returns:
            A list of `params.n_attr` accuracies, one per (attribute,
            category) pair, in the order given by `params.attr`. Each
            accuracy is the mean of the per-sample 0/1 hits.
        """
        data = self.data
        params = self.params
        self.ae.eval()
        self.clf_dis.eval()
        bs = params.batch_size

        all_preds = [[] for _ in range(params.n_attr)]
        for i in range(0, len(data), bs):
            # batch / encode / decode
            batch_x, batch_y = data.eval_batch(i, i + bs)
            enc_outputs = self.ae.encode(batch_x)
            # flip all attributes one by one
            k = 0  # slot in all_preds: one per (attribute, category) pair
            offset = 0  # first classifier output column of the current attribute
            for j, (_, n_cat) in enumerate(params.attr):
                for value in range(n_cat):
                    flipped = flip_attributes(batch_y, params, j, new_value=value)
                    dec_outputs = self.ae.decode(enc_outputs, flipped)
                    # classify: the classifier output presumably concatenates
                    # the columns of every attribute (params.n_attr in total,
                    # matching the assert below) -- TODO confirm against the
                    # classifier definition. Attribute j's columns therefore
                    # start at the cumulative category count `offset`, not at
                    # the attribute index `j`.
                    clf_dis_preds = self.clf_dis(dec_outputs[-1])[:, offset:offset + n_cat].max(1)[1].view(-1)
                    all_preds[k].extend((clf_dis_preds.data.cpu() == value).tolist())
                    k += 1
                offset += n_cat
            assert k == params.n_attr

        return [np.mean(x) for x in all_preds]