def build_model()

in utils/common.py

import torch

# Note: BaseModel and the ResNet18/34/50 model classes are assumed to be
# defined or imported elsewhere in utils/common.py.


def build_model(args, num_classes, device, gpu_id, return_features=False, is_training=True):
    # model: output width is the number of class logits plus any extra outputs required by the OOD metric
    num_outputs = num_classes
    if any(x in args.ood_metric for x in ['bkg_c', 'bin_disc']):
        num_outputs += 1
    elif any(x in args.ood_metric for x in ['mc_disc']):
        num_outputs += 2
    # args.model names the architecture class: 'ResNet18', 'ResNet34', or 'ResNet50'
    model: BaseModel = eval(args.model)(num_classes=num_classes, num_outputs=num_outputs, return_features=return_features).to(device)
    if args.ddp:
        # Wrap for multi-GPU training with DistributedDataParallel.
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[gpu_id], broadcast_buffers=False,
            find_unused_parameters=True
            )
    else:
        # Single-GPU path; DataParallel is left disabled.
        # model = torch.nn.DataParallel(model)
        pass
    # print('Model Done.')

    if is_training:
        # optimizer:
        if args.opt == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
        elif args.opt == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=args.momentum, nesterov=True)
        else:
            raise ValueError(f'Unsupported optimizer: {args.opt}')
        # lr schedule:
        if args.decay == 'cos':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
        elif args.decay == 'multisteps':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.decay_epochs, gamma=0.1)
        else:
            raise ValueError(f'Unsupported lr decay schedule: {args.decay}')
        # print('Optimizer Done.')

        return model, optimizer, scheduler, num_outputs
    else:
        return model, num_outputs
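
For reference, a minimal call-site sketch. It assumes an argparse-style args namespace carrying the fields the function reads (model, ood_metric, ddp, opt, lr, wd, momentum, decay, decay_epochs, epochs); the values below are illustrative placeholders, not the repository's defaults.

from types import SimpleNamespace

args = SimpleNamespace(
    model='ResNet18', ood_metric='bkg_c', ddp=False,
    opt='sgd', lr=0.1, wd=5e-4, momentum=0.9,
    decay='cos', epochs=100, decay_epochs=[50, 75],
)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Training path: returns the (optionally DDP-wrapped) model plus optimizer and scheduler.
model, optimizer, scheduler, num_outputs = build_model(
    args, num_classes=10, device=device, gpu_id=0, is_training=True)

# Evaluation path: optimizer and scheduler construction is skipped.
eval_model, num_outputs = build_model(
    args, num_classes=10, device=device, gpu_id=0, is_training=False)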