in utils/common.py [0:0]
import torch

# BaseModel, ResNet18, ResNet34, and ResNet50 are assumed to be defined or
# imported elsewhere in this module.


def build_model(args, num_classes, device, gpu_id, return_features=False, is_training=True):
    # Model: some OOD metrics need extra output heads on top of the class logits.
    num_outputs = num_classes
    if any(x in args.ood_metric for x in ['bkg_c', 'bin_disc']):
        num_outputs += 1  # one extra output: background class / binary discriminator
    elif any(x in args.ood_metric for x in ['mc_disc']):
        num_outputs += 2  # two extra outputs: multi-class discriminator
    # args.model names the architecture class: 'ResNet18', 'ResNet34', or 'ResNet50'.
    model: BaseModel = eval(args.model)(
        num_classes=num_classes, num_outputs=num_outputs, return_features=return_features
    ).to(device)
    if args.ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[gpu_id], broadcast_buffers=False,
            find_unused_parameters=True
        )
    # else: single-process training uses the model as-is
    # (torch.nn.DataParallel is intentionally not applied).

    if is_training:
        # Optimizer:
        if args.opt == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
        elif args.opt == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd,
                                        momentum=args.momentum, nesterov=True)
        else:
            raise ValueError(f'Unsupported optimizer: {args.opt}')
        # Learning-rate schedule:
        if args.decay == 'cos':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
        elif args.decay == 'multisteps':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.decay_epochs, gamma=0.1)
        else:
            raise ValueError(f'Unsupported LR decay schedule: {args.decay}')
        return model, optimizer, scheduler, num_outputs
    else:
        return model, num_outputs
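
# Hypothetical usage sketch (the argparse values below are assumptions for
# illustration, not project defaults); e.g. from a training script:
#
#   args = argparse.Namespace(
#       model='ResNet18', ood_metric='bkg_c', ddp=False,
#       opt='sgd', lr=0.1, wd=5e-4, momentum=0.9,
#       decay='cos', epochs=100, decay_epochs=[60, 80],
#   )
#   device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#   model, optimizer, scheduler, num_outputs = build_model(
#       args, num_classes=10, device=device, gpu_id=0,
#   )
#   # num_outputs == 11: 10 class logits + 1 extra output for the 'bkg_c' metric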