in models/model_factory.py [0:0]
def _build_supernet(args, n_classes, bn_param):
    """Construct an ``AttentiveNasDynamicModel`` from ``args.supernet_config``.

    Shared by both architecture branches of ``create_model`` so the supernet
    is configured identically in each.
    """
    return AttentiveNasDynamicModel(
        args.supernet_config,
        n_classes=n_classes,
        bn_param=bn_param,
    )


def create_model(args, arch=None):
    """Build a model for the architecture named by ``arch`` (or ``args.arch``).

    Args:
        args: Configuration object. Attributes read: ``n_classes`` (default
            1000, coerced with ``int``), ``bn_momentum`` (default 0.1),
            ``bn_eps`` (default 1e-5), ``supernet_config``, and ``arch`` when
            the ``arch`` argument is ``None``. The static architecture also
            reads ``pareto_models.supernet_checkpoint_path`` and
            ``active_subnet.{resolution, width, depth, kernel_size,
            expand_ratio}``.
        arch: Architecture name; overrides ``args.arch`` when not ``None``.
            Supported values are ``'attentive_nas_dynamic_model'`` and
            ``'attentive_nas_static_model'``.

    Returns:
        For ``'attentive_nas_dynamic_model'``, the ``AttentiveNasDynamicModel``
        supernet itself. For ``'attentive_nas_static_model'``, the subnet
        returned by ``get_active_subnet()`` after the supernet checkpoint is
        loaded and the configured subnet is activated.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name.

    NOTE(review): ``args.dropout`` / ``args.drop_connect`` are not forwarded
    to the model by this factory.
    """
    n_classes = int(getattr(args, 'n_classes', 1000))
    bn_momentum = getattr(args, 'bn_momentum', 0.1)
    bn_eps = getattr(args, 'bn_eps', 1e-5)
    bn_param = (bn_momentum, bn_eps)

    # An explicit ``arch`` argument takes precedence over the config value;
    # ``args.arch`` is only read when no override is given.
    if arch is None:
        arch = args.arch

    if arch == 'attentive_nas_dynamic_model':
        return _build_supernet(args, n_classes, bn_param)

    if arch == 'attentive_nas_static_model':
        supernet = _build_supernet(args, n_classes, bn_param)
        # Load pretrained supernet weights first, so the static subnet
        # sampled below inherits them.
        supernet.load_weights_from_pretrained_models(
            args.pareto_models.supernet_checkpoint_path)
        subnet_cfg = args.active_subnet
        supernet.set_active_subnet(
            resolution=subnet_cfg.resolution,
            width=subnet_cfg.width,
            depth=subnet_cfg.depth,
            kernel_size=subnet_cfg.kernel_size,
            expand_ratio=subnet_cfg.expand_ratio,
        )
        model = supernet.get_active_subnet()
        # Apply the configured BN momentum/eps to the extracted subnet.
        model.set_bn_param(momentum=bn_momentum, eps=bn_eps)
        # Drop the supernet reference; only the subnet is returned.
        del supernet
        return model

    raise ValueError('Unsupported arch: {!r}'.format(arch))