in train.py
import json
import os
import sys
import time

import numpy as np
import torch

import models                     # project module providing ERM, GroupDRO, JTT
from datasets import get_loaders  # assumed location of the loader factory used below
from utils import Tee             # assumed location of the stream-mirroring helper


def run_experiment(args):
    start_time = time.time()

    # Seed torch and numpy so runs are reproducible for a given init_seed.
    torch.manual_seed(args["init_seed"])
    np.random.seed(args["init_seed"])

    loaders = get_loaders(
        args["data_path"], args["dataset"], args["batch_size"], args["method"])

    # Mirror stdout/stderr to per-run log files named after the two seeds.
    sys.stdout = Tee(os.path.join(
        args["output_dir"], 'seed_{}_{}.out'.format(
            args["hparams_seed"], args["init_seed"])), sys.stdout)
    sys.stderr = Tee(os.path.join(
        args["output_dir"], 'seed_{}_{}.err'.format(
            args["hparams_seed"], args["init_seed"])), sys.stderr)

    # Checkpoint paths: the latest state and the best state under the selector.
    checkpoint_file = os.path.join(
        args["output_dir"], 'seed_{}_{}.pt'.format(
            args["hparams_seed"], args["init_seed"]))
    best_checkpoint_file = os.path.join(
        args["output_dir"],
        "seed_{}_{}.best.pt".format(args["hparams_seed"], args["init_seed"]),
    )

    # Map the method name to its trainer class; the subsampling (suby, subg)
    # and reweighting (rwy, rwg) variants reuse the ERM trainer and rely on
    # the method-aware loaders for data balancing.
    model = {
        "erm": models.ERM,
        "suby": models.ERM,
        "subg": models.ERM,
        "rwy": models.ERM,
        "rwg": models.ERM,
        "dro": models.GroupDRO,
        "jtt": models.JTT,
    }[args["method"]](args, loaders["tr"])

    # Resume from the latest checkpoint if one exists.
    last_epoch = 0
    best_selec_val = float('-inf')
    if os.path.exists(checkpoint_file):
        model.load(checkpoint_file)
        last_epoch = model.last_epoch
        best_selec_val = model.best_selec_val

    for epoch in range(last_epoch, args["num_epochs"]):
        # For JTT, rebuild the loaders after the T first-stage epochs, passing
        # the per-example weights computed by the first-stage model.
        if epoch == args["T"] + 1 and args["method"] == "jtt":
            loaders = get_loaders(
                args["data_path"],
                args["dataset"],
                args["batch_size"],
                args["method"],
                model.weights.tolist())

        # One pass over the training loader.
        for i, x, y, g in loaders["tr"]:
            model.update(i, x, y, g, epoch)

        # Evaluate average and per-group accuracy on every split.
        result = {
            "args": args, "epoch": epoch, "time": time.time() - start_time}
        for loader_name, loader in loaders.items():
            avg_acc, group_accs = model.accuracy(loader)
            result["acc_" + loader_name] = group_accs
            result["avg_acc_" + loader_name] = avg_acc

        # Model selection on the validation split: worst-group or average accuracy.
        selec_value = {
            "min_acc_va": min(result["acc_va"]),
            "avg_acc_va": result["avg_acc_va"],
        }[args["selector"]]

        if selec_value >= best_selec_val:
            model.best_selec_val = selec_value
            best_selec_val = selec_value
            model.save(best_checkpoint_file)

        model.save(checkpoint_file)
        print(json.dumps(result))
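
A minimal, hypothetical usage sketch follows; it is not part of train.py. It only shows which argument keys run_experiment reads, with placeholder values (data_path, dataset, output_dir, epoch counts); the repository's real entry point presumably builds this dict from command-line arguments.

# Hypothetical usage sketch -- not part of train.py. All values below are
# placeholders chosen for illustration.
if __name__ == "__main__":
    example_args = {
        "data_path": "./data",        # placeholder data directory
        "dataset": "waterbirds",      # placeholder dataset name
        "batch_size": 32,
        "method": "erm",              # erm, suby, subg, rwy, rwg, dro, or jtt
        "output_dir": "./outputs",
        "hparams_seed": 0,
        "init_seed": 0,
        "num_epochs": 5,
        "T": 1,                       # first-stage epochs, only read for jtt
        "selector": "min_acc_va",     # or "avg_acc_va"
    }
    os.makedirs(example_args["output_dir"], exist_ok=True)
    run_experiment(example_args)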