def torch_model()

in distributed_training/src_dir/util.py


import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

# dis_util is assumed to be the local distributed-training helper module in
# src_dir that wraps the model-parallel barrier used below.
import dis_util


def torch_model(model_name,
                num_classes=0,
                pretrained=True,
                local_rank=0,
                model_parallel=False):
    #     model_names = sorted(name for name in models.__dict__
    #                          if name.islower() and not name.startswith("__")
    #                          and callable(models.__dict__[name]))

    if model_name == "inception_v3":
        raise RuntimeError(
            "Currently, inception_v3 is not supported by this example.")

    # create model
    if pretrained:
        print("=> using pre-trained model '{}'".format(model_name))
        if model_parallel:
            # With model parallelism, let local rank 0 download the pretrained
            # weights first; the remaining ranks wait at the barrier and then
            # load the cached weights below.
            if local_rank == 0:
                model = models.__dict__[model_name](pretrained=True)
            dis_util.smp_barrier()
        model = models.__dict__[model_name](pretrained=True)
    else:
        print("=> creating model '{}'".format(model_name))
        model = models.__dict__[model_name]()

    if num_classes > 0:
        # Replace the final fully connected layer with a new head sized for
        # `num_classes` outputs (assumes a ResNet-style model that exposes
        # an `fc` attribute).
        n_inputs = model.fc.in_features

        # add more layers as required
        classifier = nn.Sequential(
            OrderedDict([('fc_output', nn.Linear(n_inputs, num_classes))]))

        model.fc = classifier

    return model
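
A minimal usage sketch (an illustrative assumption, not part of util.py):
"resnet50" stands in for any torchvision model name, and the default
non-model-parallel settings are used, so only the pretrained download and
the classifier-head replacement paths are exercised.

model = torch_model("resnet50", num_classes=2, pretrained=True)
print(model.fc)  # Sequential containing the new 'fc_output' Linear layer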