in experiments/sgd/train_net.py [0:0]
def main():
    """Train a single architecture with SGD, or only evaluate it if ``args.epochs == 0``.

    The architecture named by ``args.arch`` is resolved in two ways:
      1. as the name of a genotype defined in the ``genotypes`` module
         (DARTS-style or ViT), or
      2. failing that, as an architecture identifier inside the DeepNets-1M
         dataset loaded via ``DeepNets1M``.

    Side effects: trains the model on ``args.device``, optionally saves a
    checkpoint to ``args.save/checkpoint.pt`` each epoch, and evaluates on the
    validation queue after every epoch.
    """
    args = init_config(mode='train_net')
    is_imagenet = args.dataset == 'imagenet'

    # load_train_anyway=True: the train split is needed even in eval-only runs
    # (args.epochs == 0) because the scheduler/trainer are still constructed.
    train_queue, valid_queue, num_classes = image_loader(dataset=args.dataset,
                                                         data_dir=args.data_dir,
                                                         test=True,
                                                         load_train_anyway=True,
                                                         batch_size=args.batch_size,
                                                         test_batch_size=args.test_batch_size,
                                                         num_workers=args.num_workers,
                                                         cutout=args.cutout,
                                                         cutout_length=args.cutout_length,
                                                         seed=args.seed,
                                                         noise=args.noise,
                                                         n_shots=args.n_shots)

    assert args.arch is not None, 'architecture genotype/index must be specified'

    try:
        # Case 1: args.arch names a genotype in the `genotypes` module.
        # NOTE(review): eval() on a CLI argument — trusted input is assumed here.
        genotype = eval('genotypes.%s' % args.arch)
        net_args = {'C': args.init_channels,
                    'genotype': genotype,
                    'n_cells': args.layers,
                    'C_mult': int(genotype != ViT) + 1,  # assume either ViT or DARTS-style architecture
                    'preproc': genotype != ViT,
                    'stem_type': 1}  # assume that the ImageNet-style stem is used by default
    except Exception:
        # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit are
        # not silently swallowed. Any failure to resolve the genotype falls
        # back to Case 2: load the architecture from DeepNets-1M.
        deepnets = DeepNets1M(split=args.split,
                              nets_dir=args.data_dir,
                              large_images=is_imagenet,
                              arch=args.arch)
        assert len(deepnets) == 1, 'one architecture must be chosen to train'
        graph = deepnets[0]
        net_args, idx = graph.net_args, graph.net_idx
        if 'norm' in net_args and net_args['norm'] == 'bn':
            # Track batch-norm running statistics when training standalone.
            net_args['norm'] = 'bn-track'

    if isinstance(net_args['genotype'], str):
        # A torchvision model name (e.g. 'resnet50'): instantiate it and adapt
        # it to the dataset's input resolution.
        # NOTE(review): eval() of a string built from stored net_args — trusted input assumed.
        model = adjust_net(eval('torchvision.models.%s(pretrained=%d)' % (net_args['genotype'], args.pretrained)), is_imagenet)
    else:
        model = Network(num_classes=num_classes,
                        is_imagenet_input=is_imagenet,
                        auxiliary=args.auxiliary,
                        **net_args)

    # Load pretrained weights when a checkpoint is given; torchvision ResNets
    # always go through pretrained_model to adjust their classifier head.
    if args.ckpt is not None or isinstance(model, torchvision.models.ResNet):
        model = pretrained_model(model, args.ckpt, num_classes, args.debug, GHN)

    model = model.train().to(args.device)
    print('\nTraining arch={} with {} parameters'.format(args.arch, capacity(model)[1]))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.wd
    )
    if is_imagenet:
        # Multiply lr by 0.97 every epoch (ImageNet-style exponential decay).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.97)
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    trainer = Trainer(optimizer,
                      num_classes,
                      is_imagenet,
                      n_batches=len(train_queue),
                      grad_clip=args.grad_clip,
                      auxiliary=args.auxiliary,
                      auxiliary_weight=args.auxiliary_weight,
                      device=args.device,
                      log_interval=args.log_interval,
                      amp=args.amp)

    for epoch in range(max(1, args.epochs)):  # if args.epochs=0, then just evaluate the model
        if args.epochs > 0:
            print('\nepoch={:03d}/{:03d}, lr={:e}'.format(epoch + 1, args.epochs, scheduler.get_last_lr()[0]))
            # Linearly ramp up drop-path probability over training.
            model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
            trainer.reset()
            model.train()
            for images, targets in train_queue:
                trainer.update(model, images, targets)
                trainer.log()

            if args.save:
                checkpoint_path = os.path.join(args.save, 'checkpoint.pt')
                torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, checkpoint_path)
                print('\nsaved the checkpoint to {}'.format(checkpoint_path))

        # Evaluate after every epoch (and once when args.epochs == 0).
        infer(model.eval(), valid_queue, verbose=True)
        scheduler.step()