def main()

in online_attacks/experiments/stochastic_toy.py [0:0]


def main():
    """Command-line entry point for the stochastic toy online-attack experiment.

    Parses CLI options (plus the ``online_params`` config group), seeds the
    RNGs, builds a ``ToyDatastream_Stochastic`` stream, runs
    ``run_experiment`` on it and, when ``--wandb`` is given, logs the result
    to Weights & Biases.

    Returns:
        Whatever ``run_experiment`` returns (logged as the competitive
        ratio), so the experiment can also be driven programmatically.
    """
    parser = ArgumentParser(description="Online Attacks")
    # Online params
    parser.add_config("online_params", OnlineParams)

    # Hparams
    parser.add_argument(
        "--K", type=int, default=1, metavar="K", help="Number of attacks to submit"
    )
    parser.add_argument(
        "--eps", type=float, default=1.0, metavar="E", help="Std for noise"
    )
    parser.add_argument(
        "--max_perms",
        type=int,
        default=120,
        metavar="P",
        help="Maximum number of perms of the data stream",
    )
    parser.add_argument(
        "--threshold", type=int, default=None, metavar="T", help="Manual Threshold"
    )
    parser.add_argument(
        "--seed", type=int, metavar="S", help="random seed (default: None)"
    )
    parser.add_argument(
        "--exhaust", action="store_true", default=False, help="Exhaust K"
    )
    parser.add_argument(
        "--knapsack",
        action="store_true",
        default=False,
        help="Use Knapsack Competitive Ratio",
    )
    # Bells
    parser.add_argument(
        "--wandb", action="store_true", default=False, help="Use wandb for logging"
    )
    parser.add_argument(
        "--namestr",
        type=str,
        default="Online-Attacks",
        help="additional info in output filename to describe experiments",
    )

    args = parser.parse_args()
    args.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): --seed has no default, so this may receive None; how
    # seed_everything handles None is defined elsewhere — confirm.
    seed_everything(args.seed)
    if args.threshold is not None:
        # A CLI threshold overrides whatever the online_params config set.
        args.online_params.threshold = args.threshold

    # Optional local credentials file. Always define the attribute so the
    # wandb branch below cannot hit an AttributeError when the file is
    # absent, and tolerate a file that lacks the "wandbapikey" entry.
    args.wandb_apikey = None
    if os.path.isfile("settings.json"):
        with open("settings.json") as f:
            data = json.load(f)
        args.wandb_apikey = data.get("wandbapikey")

    if args.wandb:
        # os.environ only accepts str values (assigning None raises
        # TypeError). Without a key here, wandb is left to find credentials
        # itself, e.g. an existing WANDB_API_KEY or a prior `wandb login`.
        if args.wandb_apikey is not None:
            os.environ["WANDB_API_KEY"] = args.wandb_apikey
        wandb.init(
            project="Online-Attacks",
            name="Online-Attack-{}-{}".format("toy", args.namestr),
        )
    train_loader = ToyDatastream_Stochastic(
        args.online_params.N, args.max_perms, args.eps
    )
    args.online_params.exhaust = args.exhaust
    args.online_params.K = args.K
    comp_ratio = run_experiment(args.online_params, train_loader, args.knapsack)
    if args.wandb:
        prefix = (
            "Knapsack Competitive Ratio " if args.knapsack else "Competitive Ratio "
        )
        model_name = prefix + args.online_params.online_type[0].value
        # Record the number of attacks submitted (--K) alongside the ratio.
        wandb.log({model_name: comp_ratio, "K": args.K})
    return comp_ratio