def run_inversion()

in model_inversion.py [0:0]


def _select_target_attribute(dataset):
    """Return the ``(start, stop)`` feature-column range encoding the target attribute.

    The range is half-open: e.g. ``(4, 8)`` means feature columns 4 through 7
    hold the encoded values of the target attribute.

    Raises:
        NotImplementedError: if no target attribute is defined for ``dataset``.
    """
    if dataset == "uciadult":
        return (24, 25)  # [not married, married]
    if dataset == "iwpc":
        # An alternative IWPC target is (2, 7), the CYP2C9 genotype.
        return (11, 13)  # VKORC1 genotype
    raise NotImplementedError("Dataset not yet implemented.")


def _resolve_inverters(inverter):
    """Map an inverter choice to a list of ``(name, function)`` pairs.

    ``"all"`` selects every entry of the module-level ``INVERTERS`` except the
    last one; any other value selects that single inverter. Each name ``x`` is
    looked up as the module-level function ``x_inverter``.

    Raises:
        KeyError: if no function ``<name>_inverter`` exists in this module.
    """
    # NOTE(review): the last INVERTERS entry is deliberately left out of "all";
    # the reason is not visible here — confirm.
    names = INVERTERS[:-1] if inverter == "all" else [inverter]
    resolved = []
    for name in names:
        fn_name = f"{name}_inverter"
        try:
            resolved.append((name, globals()[fn_name]))
        except KeyError:
            raise KeyError(
                f"Unknown inverter {name!r}: no function named {fn_name!r}"
            ) from None
    return resolved


def run_inversion(args, data, weights):
    """Train a model on ``data`` and run model-inversion attacks against it.

    Args:
        args: run configuration; this function reads ``args.dataset``,
            ``args.model``, ``args.l2`` and ``args.inverter``.
        data: mapping with ``"features"`` (indexed as a 2-D array/tensor) and
            ``"targets"`` entries.
        weights: passed through to ``model.train`` and to each inverter.

    Returns:
        A dict with key ``"target"`` holding the true category of the target
        attribute for each example (as a list), plus one key per inverter
        name holding that inverter's predictions (as a list).

    Raises:
        NotImplementedError: if no target attribute is defined for
            ``args.dataset``.
        KeyError: if an inverter name has no matching ``<name>_inverter``
            function.
    """
    # Validate the configuration before training, so that an unsupported dataset
    # or an unknown inverter fails immediately instead of after a training run.
    target_attribute = _select_target_attribute(args.dataset)
    inverters = _resolve_inverters(args.inverter)
    regression = args.dataset in ("iwpc", "synth")

    # Train model:
    model = models.get_model(args.model)
    logging.info(f"Training model {args.model}")
    model.train(data, l2=args.l2, weights=weights)

    # Check predictions for sanity:
    train_predictions = model.predict(data["features"], regression=regression)
    if regression:
        mse = (train_predictions - data["targets"]).pow(2).mean()
        logging.info(f"Training MSE of regressor {mse.item():.3f}")
    else:
        accuracy = (train_predictions == data["targets"]).float().mean()
        logging.info(f"Training accuracy of classifier {accuracy.item():.3f}")

    target = features_to_category(data["features"][:, range(*target_attribute)])
    results = {"target": target.tolist()}

    for name, invert_fn in inverters:
        inverted = invert_fn(
            data, target_attribute, model=model, weights=weights, l2=args.l2)
        acc = compute_metrics(data, inverted, target_attribute)
        logging.info(f"{name} inverter Accuracy {acc:.4f}")
        results[name] = inverted.tolist()

    return results