in main.py [0:0]
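# NOTE (excerpt): this function relies on module-level names not shown here, presumably
# imported/defined near the top of main.py -- the parsed `args` namespace plus `random`,
# `np`, `torch`, `nn`, `os`, `pathlib`, `SummaryWriter`, `LabelSmoothing`, and the project
# modules `data`, `utils`, `trainers`, and `schedulers`.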

def main():
    # Seed.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        np.random.seed(args.seed)
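    # Seeding covers Python, NumPy, and PyTorch (CPU and all GPUs). cuDNN is left in its
    # default mode, so some GPU ops may still be nondeterministic unless
    # torch.backends.cudnn.deterministic is also set.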

    # If saving models or saving data, create a folder for storing these files.
    # args.save => saving models, tensorboard logs, etc.
    # args.save_data => just saving results.
    if args.save or args.save_data:
        i = 0
        while True:
            run_base_dir = pathlib.Path(
                f"{args.log_dir}/{args.name}+try={str(i)}"
            )
            if not run_base_dir.exists():
                os.makedirs(run_base_dir)
                args.name = args.name + f"+try={i}"
                break
            i += 1
        (run_base_dir / "settings.txt").write_text(str(args))
        args.run_base_dir = run_base_dir
        print(f"=> Saving data in {run_base_dir}")

    # Get dataloader.
    data_loader = getattr(data, args.set)()

    curr_acc1 = 0.0

    # Make a list of models, instead of a single model.
    # This is not for training subspaces, but rather for the ensemble & SWA baselines.
    models = [utils.get_model() for _ in range(args.num_models)]

    # When training the SWA baseline, turn off gradients for all but the first model.
    if args.trainswa:
        for i in range(1, args.num_models):
            for p in models[i].parameters():
                p.requires_grad = False
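    # For the SWA baseline only models[0] is actually optimized; the remaining entries act
    # as frozen slots that receive weight snapshots at args.swa_save_epochs (see the copy
    # logic inside the training loop below).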

    # Resume a model from a saved checkpoint.
    num_models_filled = 0
    num_models = -1
    if args.resume:
        for i, resume in enumerate(args.resume):
            if isinstance(resume, tuple):
                # A tuple can also be used to specify how many models to load
                # from this checkpoint.
                resume, num_models = resume
            if os.path.isfile(resume):
                print(f"=> Loading checkpoint '{resume}'")
                checkpoint = torch.load(resume, map_location="cpu")
                # Strip the first 7 characters of every key -- likely the "module."
                # prefix added by nn.DataParallel when the checkpoint was saved.
                pretrained_dicts = [
                    {k[7:]: v for k, v in c.items()}
                    for c in checkpoint["state_dicts"]
                ]
                n = 0
                for pretrained_dict in pretrained_dicts:
                    print(num_models_filled)
                    model_dict = models[num_models_filled].state_dict()
                    # Keep only keys that exist in the target model, so extra keys are
                    # ignored and missing keys retain their initialization.
                    pretrained_dict = {
                        k: v
                        for k, v in pretrained_dict.items()
                        if k in model_dict
                    }
                    model_dict.update(pretrained_dict)
                    models[num_models_filled].load_state_dict(model_dict)
                    num_models_filled += 1
                    n += 1
                    if num_models > 0 and n >= num_models:
                        break
                print(
                    f"=> Loaded checkpoint '{resume}' (epoch {checkpoint['epoch']})"
                )
            else:
                print(f"=> No checkpoint found at '{resume}'")

    # Put models on the GPU.
    models = [utils.set_gpu(m) for m in models]

    # Get training loss.
    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss()
    else:
        print("adding label smoothing!")
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    criterion = criterion.to(args.device)

    if args.save:
        writer = SummaryWriter(log_dir=run_base_dir)
    else:
        writer = None

    # Get the "trainer", which specifies how the model is trained.
    trainer = getattr(trainers, args.trainer or "default")
    print(f"=> Using trainer {trainer}")
    train, test = trainer.train, trainer.test

    # Call "init" on the trainer.
    trainer.init(models, writer, data_loader)
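    # Each trainer module is expected to expose init(models, writer, data_loader) plus the
    # train/test functions used below.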

    # Since we have a list of models, we also use a list of optimizers & schedulers.
    # When training subspaces, this list is of length 1.
    metrics = {}
    optimizers = [utils.get_optimizer(args, m) for m in models]
    lr_schedulers = [
        schedulers.get_policy(args.lr_policy or "cosine_lr")(o, args)
        for o in optimizers
        if o is not None
    ]
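    # Each lr scheduler is a callable invoked once per epoch in the training loop below;
    # any optimizer that comes back as None is simply skipped here.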

    # More logic for resuming a checkpoint, specifically concerned with the "pretrained" argument.
    # If args.pretrained, then we are not resuming: we start from epoch 0.
    # If not args.pretrained, we are resuming and have to set the epoch, optimizer state, etc. appropriately.
    init_epoch = 0
    num_models_filled = 0
    if args.resume and not args.pretrained:
        for i, resume in enumerate(args.resume):
            if os.path.isfile(resume):
                print(f"=> Loading checkpoint '{resume}'")
                checkpoint = torch.load(resume, map_location="cpu")
                init_epoch = checkpoint["epoch"]
                curr_acc1 = checkpoint["curr_acc1"]
                for opt in checkpoint["optimizers"]:
                    # With late_start >= 0, optimizer state is intentionally not restored.
                    if args.late_start >= 0:
                        continue
                    optimizers[num_models_filled].load_state_dict(opt)
                    num_models_filled += 1
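    # When args.pretrained is set, the weights loaded earlier are kept but init_epoch stays 0,
    # so training starts over from the first epoch rather than resuming.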

    best_acc1 = 0.0
    train_loss = 0.0

    # Save the initialization.
    if init_epoch == 0 and args.save:
        print("saving checkpoint")
        utils.save_cpt(init_epoch, 0, models, optimizers, best_acc1, curr_acc1)
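    # (The second argument to save_cpt appears to be the global iteration count: 0 at
    # initialization, and derived from the epoch and train-loader length inside the loop below.)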

    # If the start epoch == the end epoch, just do evaluation ("test").
    if init_epoch == args.epochs:
        curr_acc1, metrics = test(
            models, writer, criterion, data_loader, init_epoch,
        )
        if args.save or args.save_data:
            metrics["epoch"] = init_epoch
            utils.write_result_to_csv(
                name=args.name + f"+curr_epoch={init_epoch}",
                curr_acc1=curr_acc1,
                best_acc1=best_acc1,
                train_loss=train_loss,
                **metrics,
            )

    # Train from init_epoch -> args.epochs.
    for epoch in range(init_epoch, args.epochs):
        for lr_scheduler in lr_schedulers:
            lr_scheduler(epoch, None)
        train_loss = train(
            models, writer, data_loader, optimizers, criterion, epoch,
        )
        # Some trainers return (train_loss, optimizers) instead of just the loss; unpack if so.
        if isinstance(train_loss, tuple):
            train_loss, optimizers = train_loss

        # Evaluate every epoch if no test_freq is given, otherwise at that frequency,
        # and always on the final epoch.
        if (
            args.test_freq is None
            or (epoch % args.test_freq == 0)
            or epoch == args.epochs - 1
        ):
            curr_acc1, metrics = test(
                models, writer, criterion, data_loader, epoch,
            )
            if curr_acc1 > best_acc1:
                best_acc1 = curr_acc1

        metrics["epoch"] = epoch + 1

        # This is for the SWA baseline -- look up whether this is an epoch at which we save
        # a weight snapshot. If so, copy the current weights of models[0] into the
        # corresponding slot in the models list.
        if args.trainswa and (epoch + 1) in args.swa_save_epochs:
            j = args.swa_save_epochs.index(epoch + 1)
            for m1, m2 in zip(models[0].modules(), models[j].modules()):
                if isinstance(m1, nn.Conv2d):
                    m2.weight = nn.Parameter(m1.weight.clone().detach())
                    m2.weight.requires_grad = False
                elif isinstance(m1, nn.BatchNorm2d):
                    m2.weight = nn.Parameter(m1.weight.clone().detach())
                    m2.bias = nn.Parameter(m1.bias.clone().detach())
                    m2.weight.requires_grad = False
                    m2.bias.requires_grad = False
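        # Note: the snapshot copy above only transfers Conv2d weights and BatchNorm2d affine
        # parameters; parameters of any other module types are not copied by this loop.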

        # Save checkpoint.
        if (
            args.save
            and args.save_epochs is not None
            and (epoch + 1) in args.save_epochs
        ):
            it = (epoch + 1) * len(data_loader.train_loader)
            utils.save_cpt(
                epoch + 1, it, models, optimizers, best_acc1, curr_acc1
            )

    # Save results.
    if args.save or args.save_data:
        utils.write_result_to_csv(
            name=args.name,
            curr_acc1=curr_acc1,
            best_acc1=best_acc1,
            train_loss=train_loss,
            **metrics,
        )

    # Save final checkpoint.
    if args.save:
        it = args.epochs * len(data_loader.train_loader)
        utils.save_cpt(
            args.epochs, it, models, optimizers, best_acc1, curr_acc1
        )

    return curr_acc1, metrics