# data_augmentation/my_training.py
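#
# Module-level imports assumed by this function. The standard-library and
# PyTorch/torchvision imports below follow directly from the calls in the body;
# the project-local helpers (create_repo, save_checkpoint, adjust_learning_rate,
# train, validate, checkpointing, functions, aug_losses, modified_resnet18,
# UniformAugEachMin, UniformAugEachPos, AugModuleMin, MyUniformAug,
# AugAveragedModel) live elsewhere in this repo, so their import paths are
# not reproduced here.
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.models as models
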
def main_worker(gpu, ngpus_per_node, args, ckpt_path, repo):
global best_acc1
args.gpu = gpu
    if not args.distributed or args.rank % ngpus_per_node == 0:
model_dir = create_repo(args, repo)
if args.gpu is not None:
print("Use GPU: {} for training".format(args.gpu))
cur_device = torch.cuda.current_device()
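    # NOTE: assumes the launching code has already bound this process to its
    # GPU (e.g. via torch.cuda.set_device); args.gpu itself is only logged here.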
# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
net = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
if args.arch == "modified_resnet18":
net = modified_resnet18(modify=args.modify, num_classes=args.num_classes)
else:
net = models.__dict__[args.arch](num_classes=args.num_classes)
if args.augerino:
if args.inv_per_class:
assert args.disable_at_valid
augerino_classes = args.num_classes
else:
augerino_classes = 1
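        # augerino_classes sets how many rows of augmentation widths are
        # learned: one per class (per-class invariances) or a single shared row.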
        if args.transfos == ["tx", "ty", "scale"]:  # special case: these transforms are passed one by one
if args.min_val:
print("Using UniformAugEachMin")
augerino = UniformAugEachMin(
transfos=args.transfos,
min_values=args.min_values,
shutvals=args.shutdown_vals,
num_classes=augerino_classes,
)
else:
print("Using UniformAugEach")
augerino = UniformAugEachPos(
transfos=args.transfos,
shutvals=args.shutdown_vals,
num_classes=augerino_classes,
)
else:
if args.min_val:
augerino = AugModuleMin(
transfos=args.transfos,
min_values=args.min_values,
shutvals=args.shutdown_vals,
num_classes=augerino_classes,
)
else:
augerino = MyUniformAug(
transfos=args.transfos,
shutvals=args.shutdown_vals,
num_classes=augerino_classes,
)
        # one learnable width row per augmentation class:
        # shape (augerino_classes, len(args.startwidth))
        augerino.set_width(
            torch.FloatTensor(args.startwidth)[None, :].repeat(augerino_classes, 1)
        )
print(augerino.width)
        if args.fixed_augerino:
            # keep the augmentation widths frozen at their initial values
            augerino.width.requires_grad = False
model = AugAveragedModel(
net, augerino, disabled=False, ncopies=args.ncopies, onecopy=args.onecopy
)
else:
model = net
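    # When Augerino is enabled, AugAveragedModel wraps the network so that its
    # output is averaged over ncopies augmented copies of each input (onecopy
    # presumably restricts this to a single sampled copy).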
# save initial width
if args.augerino:
widths = [model.aug.width.clone().detach()]
else:
widths = []
model = model.cuda(device=cur_device)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
module=model, device_ids=[cur_device], output_device=cur_device
)
    # ckpt_path is None ensures we are not restarting after preemption
    if args.pretrained_noaugment and ckpt_path is None:
if args.distributed:
net = torch.nn.parallel.DistributedDataParallel(
module=net, device_ids=[cur_device], output_device=cur_device
)
        checkpointing.restart_from_checkpoint(
            args.noaugment_path, args, state_dict=net
        )  # WARNING: the learning rate is not adjusted accordingly
factor_lr = 1.0
else:
factor_lr = 0.1
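    # factor_lr scales the learning rate of the "aug" parameter group below:
    # full LR when starting from no-augment weights, 10x smaller otherwise.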
# define loss function (criterion) and optimizer
if args.augerino:
criterion = aug_losses.safe_unif_aug_loss_each
model_param = (
model.module.model.parameters()
if args.distributed
else model.model.parameters()
)
aug_param = (
model.module.aug.parameters()
if args.distributed
else model.aug.parameters()
)
params = [
{
"name": "model",
"params": model_param,
"momentum": args.momentum,
"weight_decay": args.weight_decay,
},
{
"name": "aug",
"params": aug_param,
"momentum": args.momentum,
"weight_decay": 0.0,
"lr": args.lr * factor_lr,
},
]
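        # Note: no weight decay on the augmentation widths; decay would pull
        # the learned augmentation ranges toward zero.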
else:
criterion = nn.CrossEntropyLoss().cuda()
params = [
{
"name": "model",
"params": model.parameters(),
"momentum": args.momentum,
"weight_decay": args.weight_decay,
}
]
optimizer = torch.optim.SGD(params, args.lr)
to_restore = {"epoch": 0, "best_acc1": 0.0, "all_acc1": [], "width": widths}
if ckpt_path is not None:
checkpointing.restart_from_checkpoint(
ckpt_path,
args,
run_variables=to_restore,
state_dict=model,
optimizer=optimizer,
)
args.start_epoch = to_restore["epoch"]
best_acc1 = to_restore["best_acc1"]
all_acc1 = to_restore["all_acc1"]
widths = to_restore["width"]
print("Starting from Epoch", args.start_epoch)
cudnn.benchmark = True
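    # let cuDNN benchmark and cache the fastest conv algorithms
    # (beneficial because the input size is fixed throughout training)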
# Data loading code
traindir = os.path.join(args.data, "train")
valdir = os.path.join(args.data, "val")
train_loader, val_loader, train_sampler = functions.return_loader_and_sampler(
args, traindir, valdir
)
    if args.evaluate:
        validate(val_loader, model, criterion, args)
        # NOTE: no early return here, so the training loop below still runs
        # after the evaluation pass
for epoch in range(args.start_epoch, args.epochs):
adjust_learning_rate(optimizer, epoch, args, factor_lr)
if args.distributed:
train_sampler.set_epoch(epoch)
# train for one epoch
train(train_loader, model, criterion, optimizer, epoch, args)
# evaluate on validation set
acc1, acc5, val_loss = validate(val_loader, model, criterion, args)
# remember best acc@1 and save checkpoint
is_best = acc1 > best_acc1
best_acc1 = max(acc1, best_acc1)
all_acc1.append(acc1)
if args.augerino:
            width = model.module.aug.width if args.distributed else model.aug.width
widths.append(width.clone().detach())
        if not args.distributed or args.rank % ngpus_per_node == 0:
save_checkpoint(
{
"epoch": epoch + 1,
"arch": args.arch,
"state_dict": model.state_dict(),
"best_acc1": best_acc1,
"acc1": acc1,
"acc5": acc5,
"all_acc1": all_acc1,
"optimizer": optimizer.state_dict(),
"val_loss": val_loss,
"width": widths,
},
is_best,
model_dir,
)
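
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the original file): main_worker
# follows the torchvision ImageNet-example signature fn(gpu, ...), so a
# launcher along these lines is assumed. `args`, `ckpt_path`, and `repo`
# come from the caller.
#
#     import torch.multiprocessing as mp
#
#     ngpus_per_node = torch.cuda.device_count()
#     if args.distributed:
#         # mp.spawn calls fn(i, *args), so each worker receives its gpu index
#         mp.spawn(
#             main_worker,
#             nprocs=ngpus_per_node,
#             args=(ngpus_per_node, args, ckpt_path, repo),
#         )
#     else:
#         main_worker(args.gpu, ngpus_per_node, args, ckpt_path, repo)
# ---------------------------------------------------------------------------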