in src/evaluation.py [0:0]
def eval_clf_dis_accuracy(self):
    """
    Compute the classifier discriminator prediction accuracy.

    Each batch is encoded once, then decoded once per (attribute, value)
    pair using the labels returned by
    `flip_attributes(batch_y, params, j, new_value=value)`. The classifier
    discriminator is applied to the last decoder output, and a sample
    counts as correct when the arg-max over the columns belonging to
    attribute `j` equals `value`.

    Returns:
        A list of `params.n_attr` mean accuracies, one per
        (attribute, value) pair, ordered attribute by attribute and, within
        an attribute, by increasing value. If `data` is empty the per-pair
        lists stay empty and `np.mean` yields `nan` for each entry.
    """
    data = self.data
    params = self.params
    self.ae.eval()
    self.clf_dis.eval()
    bs = params.batch_size

    # one list of per-sample booleans for every (attribute, value) pair
    all_preds = [[] for _ in range(params.n_attr)]

    for i in range(0, len(data), bs):
        # batch / encode (decoding happens once per flipped attribute below)
        batch_x, batch_y = data.eval_batch(i, i + bs)
        enc_outputs = self.ae.encode(batch_x)

        # flip all attributes one by one
        k = 0  # running (attribute, value) index == running column index
        for j, (_, n_cat) in enumerate(params.attr):
            # First column of attribute j. This must be the cumulative
            # category count, not the attribute index `j`: slicing with
            # `j:j + n_cat` makes consecutive attributes overlap as soon as
            # any attribute has more than one category.
            # NOTE(review): assumes the classifier discriminator output uses
            # the same one-column-per-(attribute, value) layout that
            # `all_preds` / `params.n_attr` use -- confirm against the model.
            offset = k
            for value in range(n_cat):
                flipped = flip_attributes(batch_y, params, j, new_value=value)
                dec_outputs = self.ae.decode(enc_outputs, flipped)
                # classify: arg-max restricted to attribute j's columns
                scores = self.clf_dis(dec_outputs[-1])[:, offset:offset + n_cat]
                clf_dis_preds = scores.max(1)[1].view(-1)
                all_preds[k].extend((clf_dis_preds.data.cpu() == value).tolist())
                k += 1
        # every (attribute, value) pair must have been visited exactly once
        assert k == params.n_attr

    return [np.mean(x) for x in all_preds]