in online_attacks/experiments/toy.py [0:0]
def _build_parser():
    """Return the command-line parser for the toy online-attack experiment."""
    parser = ArgumentParser(description="Online Attacks")
    # Online params
    parser.add_config("online_params", OnlineParams)
    # Hparams
    parser.add_argument(
        "--K", type=int, default=1, metavar="K", help="Number of attacks to submit"
    )
    parser.add_argument(
        "--max_perms",
        type=int,
        default=120,
        metavar="P",
        help="Maximum number of perms of the data stream",
    )
    parser.add_argument(
        "--seed", type=int, metavar="S", help="random seed (default: None)"
    )
    parser.add_argument(
        "--exhaust", action="store_true", default=False, help="Exhaust K"
    )
    parser.add_argument(
        "--knapsack",
        action="store_true",
        default=False,
        help="Use Knapsack Competitive Ratio",
    )
    # Bells
    parser.add_argument(
        "--wandb", action="store_true", default=False, help="Use wandb for logging"
    )
    parser.add_argument(
        "--namestr",
        type=str,
        default="Online-Attacks",
        help="additional info in output filename to describe experiments",
    )
    return parser


def _load_wandb_apikey(path="settings.json"):
    """Return the ``"wandbapikey"`` entry of the JSON file at *path*, or None.

    None is returned when the file does not exist or does not contain the key.
    """
    if not os.path.isfile(path):
        return None
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    return data.get("wandbapikey")


def main():
    """Run the toy online-attack experiment for K = 1 .. ``--K``.

    Steps:
    - Parse the CLI arguments.
    - Pick a torch device.
    - Call ``seed_everything(args.seed)``.
    - Optionally initialise wandb.
    - Build a ``ToyDatastream`` of ``online_params.N`` items and
      ``--max_perms`` permutations.
    - For each k, call ``run_experiment`` and, when ``--wandb`` is set, log
      the returned competitive ratio together with k.
    """
    args = _build_parser().parse_args()
    args.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(args.seed)

    # Always define the attribute. Previously it only existed when
    # settings.json was present, which made ``--wandb`` crash with
    # AttributeError on machines without that file.
    args.wandb_apikey = _load_wandb_apikey()
    if args.wandb:
        # os.environ only accepts str values. Export the key only when one was
        # found; otherwise let wandb fall back to its own credential lookup
        # instead of raising TypeError on None.
        if args.wandb_apikey:
            os.environ["WANDB_API_KEY"] = args.wandb_apikey
        wandb.init(
            project="Online-Attacks",
            name="Online-Attack-{}-{}".format("toy", args.namestr),
        )

    train_loader = ToyDatastream(args.online_params.N, args.max_perms)
    # The label prefix depends only on a CLI flag; the online type's value is
    # still read per iteration, exactly as before.
    metric_prefix = (
        "Knapsack Competitive Ratio " if args.knapsack else "Competitive Ratio "
    )
    for k in range(1, args.K + 1):
        args.online_params.K = k
        args.online_params.exhaust = args.exhaust
        comp_ratio = run_experiment(args.online_params, train_loader, args.knapsack)
        if args.wandb:
            model_name = metric_prefix + args.online_params.online_type.value
            wandb.log({model_name: comp_ratio, "K": k})