in distributed_training/src_dir/util.py [0:0]
def torch_model(model_name,
                num_classes=0,
                pretrained=True,
                local_rank=0,
                model_parallel=False):
    """Build the model registered as ``model_name`` in ``models``.

    Args:
        model_name: Key looked up in ``models.__dict__``; the value found
            there is called to construct the model. ``"inception_v3"`` is
            rejected up front.
        num_classes: When greater than zero, ``model.fc`` is replaced by a
            ``nn.Sequential`` holding one ``nn.Linear`` layer named
            ``fc_output`` with ``num_classes`` outputs. The model must
            therefore expose an ``fc`` attribute with ``in_features``.
        pretrained: If true, the constructor is called with
            ``pretrained=True``; otherwise it is called with no arguments.
        local_rank: Only consulted when both ``pretrained`` and
            ``model_parallel`` are true.
        model_parallel: If true (and ``pretrained`` is true), rank 0 builds
            the model once before ``dis_util.smp_barrier()`` is called, and
            then every rank builds it again.

    Returns:
        The constructed model, with its ``fc`` attribute replaced when
        ``num_classes > 0``.

    Raises:
        RuntimeError: If ``model_name`` is ``"inception_v3"``.
    """
    if model_name == "inception_v3":
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")

    def _construct(**ctor_kwargs):
        # Look the constructor up at call time so that a bad name fails at
        # the same point in the sequence as a direct lookup would.
        return models.__dict__[model_name](**ctor_kwargs)

    if not pretrained:
        print("=> creating model '{}'".format(model_name))
        model = _construct()
    else:
        print("=> using pre-trained model '{}'".format(model_name))
        if model_parallel:
            if local_rank == 0:
                # This instance is overwritten below; presumably it exists
                # so rank 0 fetches the weights before the other ranks try
                # to load them -- TODO confirm against dis_util.
                model = _construct(pretrained=True)
            dis_util.smp_barrier()
        model = _construct(pretrained=True)

    if num_classes > 0:
        # Swap the final layer for a fresh head sized to the task; add more
        # entries to `head_layers` if a deeper classifier is required.
        head_layers = OrderedDict()
        head_layers['fc_output'] = nn.Linear(model.fc.in_features,
                                             num_classes)
        model.fc = nn.Sequential(head_layers)

    return model