in experiments/train_imagenet_jsd.py [0:0]
import logging

import numpy as np
import torch
from hydra.utils import instantiate

# train_net, test_net, and test_corrupt_net are project-local helpers,
# assumed to be imported elsewhere in this file.

log = logging.getLogger(__name__)


def train(cfg, is_leader):
    # Seed NumPy and PyTorch so runs are reproducible.
    np.random.seed(cfg.rng_seed)
    torch.manual_seed(cfg.rng_seed)
    log.info(cfg.pretty())

    # Build the model on the GPU assigned to the current process.
    cur_device = torch.cuda.current_device()
    model = instantiate(cfg.model).cuda(device=cur_device)
    if cfg.num_gpus > 1:
        # One process per GPU: wrap the model for distributed training.
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[cur_device],
            output_device=cur_device,
        )
    optimizer = instantiate(cfg.optim, model.parameters())

    # Load the training set only if there is training to do and no
    # pre-trained weights were supplied.
    if cfg.optim.max_epoch > 0 and cfg.train.weights is None:
        print("Loading training set...")
        train_dataset = instantiate(cfg.train)
    else:
        print("Skipping loading the training dataset: 0 epochs of training "
              "to perform or pre-trained weights provided.")
        train_dataset = None

    print("Loading test set...")
    test_dataset = instantiate(cfg.test)
    lr_policy = instantiate(cfg.optim.lr_policy)
print("Training...")
train_net(model=model,
optimizer=optimizer,
train_dataset=train_dataset,
batch_size=cfg.train.batch_size,
max_epoch=cfg.optim.max_epoch,
loader_params=cfg.data_loader,
lr_policy=lr_policy,
save_period=cfg.train.checkpoint_period,
weights=cfg.train.weights,
num_gpus=cfg.num_gpus,
is_leader=is_leader,
jsd_num=cfg.train.params.jsd_num,
jsd_alpha=cfg.train.jsd_alpha
)
print("Testing...")
err = test_net(model=model,
test_dataset=test_dataset,
batch_size=cfg.test.batch_size,
loader_params=cfg.data_loader,
num_gpus=cfg.num_gpus)
test_corrupt_net(model=model,
corrupt_cfg=cfg.corrupt,
batch_size=cfg.corrupt.batch_size,
loader_params=cfg.data_loader,
aug_string=cfg.corrupt.aug_string,
clean_err=err,
mCE_denom=cfg.corrupt.mCE_baseline_file,
num_gpus=cfg.num_gpus,
log_name='train_imagenet.log')
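
# The jsd_num/jsd_alpha arguments above suggest an AugMix-style
# Jensen-Shannon consistency objective inside train_net. A minimal sketch
# of such a term follows; the helper name jsd_consistency_loss and its
# exact reduction are illustrative assumptions, not this project's API.
import torch.nn.functional as F

def jsd_consistency_loss(logits_list, jsd_alpha):
    # logits_list: jsd_num tensors of shape (batch, classes), one per
    # augmented view of the same images.
    probs = [F.softmax(logits, dim=1) for logits in logits_list]
    # Log of the mixture distribution; the clamp guards against log(0).
    log_mixture = torch.stack(probs).mean(dim=0).clamp(1e-7, 1.0).log()
    # JSD = mean over views of KL(view || mixture). With log-prob input,
    # F.kl_div(input, target) computes KL(target || input).
    jsd = sum(F.kl_div(log_mixture, p, reduction='batchmean')
              for p in probs) / len(probs)
    return jsd_alpha * jsd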
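
# For context, an entry point along these lines would typically drive
# train(). The config path/name, the @hydra.main signature (which varies
# across Hydra versions, as does cfg.pretty()), and the single-process
# launch are all assumptions for illustration.
import hydra

@hydra.main(config_path="conf", config_name="train_imagenet_jsd")
def main(cfg):
    # Single-process sketch; a multi-GPU run would launch one process per
    # GPU and set is_leader only on rank 0.
    train(cfg, is_leader=True)

if __name__ == "__main__":
    main()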