def main()

in online_attacks/experiments/toy.py [0:0]


def main():
    """Run the toy online-attack experiment for every K in ``1..--K``.

    Parses the command line (online-algorithm parameters registered via
    ``add_config`` plus the flags below), seeds RNGs with ``--seed``,
    optionally initialises Weights & Biases logging, then for each ``k`` in
    ``1..args.K`` calls ``run_experiment`` on a ``ToyDatastream`` and, when
    ``--wandb`` is set, logs the returned competitive ratio together with ``k``.

    A wandb API key is read from the ``wandbapikey`` entry of ``settings.json``
    in the current directory, if that file exists. When no key is found, the
    ``WANDB_API_KEY`` environment variable is left untouched so wandb can use
    whatever credentials it already has.
    """
    parser = ArgumentParser(description="Online Attacks")
    # Online params
    parser.add_config("online_params", OnlineParams)

    # Hparams
    parser.add_argument(
        "--K", type=int, default=1, metavar="K", help="Number of attacks to submit"
    )
    parser.add_argument(
        "--max_perms",
        type=int,
        default=120,
        metavar="P",
        help="Maximum number of perms of the data stream",
    )
    parser.add_argument(
        "--seed", type=int, metavar="S", help="random seed (default: None)"
    )
    parser.add_argument(
        "--exhaust", action="store_true", default=False, help="Exhaust K"
    )
    parser.add_argument(
        "--knapsack",
        action="store_true",
        default=False,
        help="Use Knapsack Competitive Ratio",
    )

    # Bells
    parser.add_argument(
        "--wandb", action="store_true", default=False, help="Use wandb for logging"
    )
    parser.add_argument(
        "--namestr",
        type=str,
        default="Online-Attacks",
        help="additional info in output filename to describe experiments",
    )

    args = parser.parse_args()
    args.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(args.seed)

    # Always define the attribute so later code never hits AttributeError
    # when settings.json is absent.
    args.wandb_apikey = None
    if os.path.isfile("settings.json"):
        with open("settings.json", encoding="utf-8") as f:
            data = json.load(f)
        args.wandb_apikey = data.get("wandbapikey")

    if args.wandb:
        # os.environ only accepts str values; assigning None raises TypeError.
        # Without a key, leave the environment alone so wandb uses its own
        # configured credentials.
        if args.wandb_apikey is not None:
            os.environ["WANDB_API_KEY"] = args.wandb_apikey
        wandb.init(
            project="Online-Attacks",
            name="Online-Attack-{}-{}".format("toy", args.namestr),
        )

    # Neither of these depends on k, so set them once instead of per iteration.
    args.online_params.exhaust = args.exhaust
    metric_prefix = (
        "Knapsack Competitive Ratio " if args.knapsack else "Competitive Ratio "
    )
    metric_name = metric_prefix + args.online_params.online_type.value

    train_loader = ToyDatastream(args.online_params.N, args.max_perms)
    for k in range(1, args.K + 1):
        # run_experiment reads the budget from online_params, so update it in place.
        args.online_params.K = k
        comp_ratio = run_experiment(args.online_params, train_loader, args.knapsack)
        if args.wandb:
            wandb.log({metric_name: comp_ratio, "K": k})