# reweighted.py
import argparse
import json
import logging

import torch

import dataloading
import models
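
# ---------------------------------------------------------------------------
# Illustrative stand-ins for two helpers used below but defined elsewhere in
# this module. The bodies, the model.predict call, and the assumption that
# datasets are dicts with "features" and "targets" tensors are sketches, not
# the actual implementations.
# ---------------------------------------------------------------------------

def compute_accuracy(model, data, regression=False):
    # MSE for regression tasks, 0-1 accuracy for classification (assumed
    # convention; it matches how the result is logged in main below).
    predictions = model.predict(data["features"])
    if regression:
        return (predictions - data["targets"]).pow(2).mean().item()
    return predictions.eq(data["targets"]).float().mean().item()


def get_weights(method, fi_weights, data):
    # Hypothetical dispatch on args.weight_method: use the accumulated
    # inverse-eta scores directly, or fall back to uniform weights.
    if method == "uniform":
        return torch.ones(len(fi_weights))
    return fi_weights
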
def main(args):
    # iwpc and synth are regression datasets; everything else is treated as
    # binary classification.
    regression = args.dataset in ("iwpc", "synth")
    data = dataloading.load_dataset(
        name=args.dataset, split="train", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
    test_data = dataloading.load_dataset(
        name=args.dataset, split="test", normalize=not args.no_norm,
        num_classes=2, root=args.data_folder, regression=regression)
    if args.pca_dims > 0:
        # Fit PCA on the training set and reuse the same mapping for the
        # test set.
        data, pca = dataloading.pca(data, num_dims=args.pca_dims)
        test_data, _ = dataloading.pca(test_data, mapping=pca)
    model = models.get_model(args.model)

    # Find the optimal parameters for the model:
    logging.info(f"Training {args.model} model.")
    model.train(data, l2=args.l2)

    # NOTE: for regression tasks, "accuracy" here actually holds the MSE.
    train_accuracy = compute_accuracy(model, data, regression=regression)
    test_accuracy = compute_accuracy(model, test_data, regression=regression)
    if regression:
        logging.info(f"MSE train {train_accuracy:.3f},"
                     f" test: {test_accuracy:.3f}.")
    else:
        logging.info(f"Accuracy train {train_accuracy:.3f},"
                     f" test: {test_accuracy:.3f}.")
    # Compute the Fisher information loss, eta, for each example in the
    # training set:
    logging.info("Computing unweighted etas on training set...")
    J = model.influence_jacobian(data)
    etas = models.compute_information_loss(
        J, target_attribute=args.attribute, constrained=args.constrained)
    logging.info(f"etas max: {etas.max().item():.4f},"
                 f" mean: {etas.mean().item():.4f},"
                 f" std: {etas.std().item():.4f}.")
    # Reweight using the Fisher information loss: examples with high eta
    # (high leakage) receive low weight in the next round of training.
    updated_fi = etas.reciprocal().detach()
    maxs = [etas.max().item()]
    means = [etas.mean().item()]
    stds = [etas.std().item()]
    train_accs = [train_accuracy]
    test_accs = [test_accuracy]
    all_weights = [torch.ones(len(updated_fi))]
    for i in range(args.iters):
        logging.info(f"Iter {i}: Training weighted model...")
        # Normalize the weights to sum to the number of examples, then clamp
        # to avoid degenerate (zero or exploding) weights.
        updated_fi *= (len(updated_fi) / updated_fi.sum())
        # TODO does it make sense to renormalize after clamping?
        updated_fi.clamp_(min=args.min_weight, max=args.max_weight)
        weights = get_weights(args.weight_method, updated_fi, data)
        model.train(data, l2=args.l2, weights=weights.detach())

        # Check predictions of the weighted model:
        train_accuracy = compute_accuracy(model, data, regression=regression)
        test_accuracy = compute_accuracy(model, test_data, regression=regression)
        if regression:
            logging.info(f"Weighted model MSE train {train_accuracy:.3f},"
                         f" test: {test_accuracy:.3f}.")
        else:
            logging.info(f"Weighted model accuracy train {train_accuracy:.3f},"
                         f" test: {test_accuracy:.3f}.")

        # Recompute the per-example information loss under the new weights
        # and fold it into the running weight estimate:
        J = model.influence_jacobian(data)
        weighted_etas = models.compute_information_loss(
            J, target_attribute=args.attribute, constrained=args.constrained)
        updated_fi /= weighted_etas

        # Track per-iteration statistics:
        maxs.append(weighted_etas.max().item())
        means.append(weighted_etas.mean().item())
        stds.append(weighted_etas.std().item())
        train_accs.append(train_accuracy)
        test_accs.append(test_accuracy)
        all_weights.append(weights)
        logging.info(f"Weighted etas max: {maxs[-1]:.4f},"
                     f" mean: {means[-1]:.4f},"
                     f" std: {stds[-1]:.4f}.")
    results = {
        "weights": weights.tolist(),
        "etas": etas.tolist(),
        "weighted_etas": weighted_etas.tolist(),
        "eta_maxes": maxs,
        "eta_means": means,
        "eta_stds": stds,
        "train_accs": train_accs,
        "test_accs": test_accs,
    }
    with open(args.results_file + ".json", "w") as fid:
        json.dump(results, fid)
    torch.save(torch.stack(all_weights), args.results_file + ".pth")
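

# A minimal, hypothetical entry point covering the flags main() reads from
# `args`. The flag names follow the attribute names used above; the types
# and defaults are assumptions, not the original CLI definition.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Iteratively reweight training examples by their "
                    "Fisher information loss.")
    parser.add_argument("--dataset", default="synth")
    parser.add_argument("--data_folder", default="data")
    parser.add_argument("--no_norm", action="store_true")
    parser.add_argument("--pca_dims", type=int, default=0)
    parser.add_argument("--model", default="linear")
    parser.add_argument("--l2", type=float, default=0.0)
    parser.add_argument("--attribute", type=int, default=0)
    parser.add_argument("--constrained", action="store_true")
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument("--min_weight", type=float, default=0.0)
    parser.add_argument("--max_weight", type=float, default=float("inf"))
    parser.add_argument("--weight_method", default="inverse")
    parser.add_argument("--results_file", default="results")
    logging.basicConfig(level=logging.INFO)
    main(parser.parse_args())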