def load_model_inference()

in inference/utils.py


def load_model_inference(config, device="cuda"):
    """It loads the generator network to do inference with and over-rides the configuration file.

    Arguments
    ---------
        config: dict
            Dictionary with configuration parameters.
        device: str, optional
            Device to place the generator on (default: "cuda").
    Returns
    -------
        generator: torch.nn.Module
            Generator network.
        config: dict
            Configuration dictionary, overwritten with values from the model checkpoint if it exists.

    """
    if config["model_backbone"] == "biggan":
        # Select checkpoint name according to best FID in checkpoint
        best_fid = 1e5
        best_name_final = ""
        for name_best in ["best0", "best1"]:
            try:
                root = "/".join([config["weights_root"], config["experiment_name"]])
                state_dict_loaded = torch.load(
                    "%s/%s.pth"
                    % (root, biggan_utils.join_strings("_", ["state_dict", name_best]))
                )
                print(
                    "For name best ",
                    name_best,
                    " we have an FID: ",
                    state_dict_loaded["best_FID"],
                )
                if state_dict_loaded["best_FID"] < best_fid:
                    best_fid = state_dict_loaded["best_FID"]
                    best_name_final = name_best
            except FileNotFoundError:
                print("Checkpoint with name ", name_best, " not in folder.")
        config["load_weights"] = best_name_final
        print("Final name selected is ", best_name_final)

        # Prepare state dict, which holds things like epoch # and itr #
        state_dict = {
            "itr": 0,
            "epoch": 0,
            "save_num": 0,
            "save_best_num": 0,
            "best_IS": 0,
            "best_FID": 999999,
            "config": config,
        }
        # Override some parameters with values from the trained model's experiment config
        biggan_utils.load_weights(
            None,
            None,
            state_dict,
            config["weights_root"],
            config["experiment_name"],
            config["load_weights"],
            None,
            strict=False,
            load_optim=False,
            eval=True,
        )
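        # After this call, state_dict["config"] holds the configuration that was
        # saved alongside the selected checkpoint.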

        # Ignore items which we might want to overwrite from the command line
        for item in state_dict["config"]:
            if item not in [
                "base_root",
                "data_root",
                "load_weights",
                "batch_size",
                "num_workers",
                "weights_root",
                "logs_root",
                "samples_root",
                "eval_reference_set",
                "eval_instance_set",
                "which_dataset",
                "seed",
                "eval_prdc",
                "use_balanced_sampler",
                "custom_distrib",
                "longtail_temperature",
                "longtail_gen",
                "num_inception_images",
                "sample_num_npz",
                "load_in_mem",
                "split",
                "z_var",
                "kmeans_subsampled",
                "filter_hd",
                "n_subsampled_data",
                "feature_augmentation",
            ]:
                if item == "experiment_name" and config["experiment_name"] != "":
                    continue  # Allows overriding the experiment name at test time
                config[item] = state_dict["config"][item]
        # No data augmentation during testing
        config["feature_augmentation"] = False
        config["hflips"] = False
        config["DA"] = False

        experiment_name = (
            config["experiment_name"]
            if config["experiment_name"]
            else biggan_utils.name_from_config(config)
        )
        print("Experiment name is %s" % experiment_name)

        # Next, build the model
        generator = BigGANModel.Generator(**config).to(device)

        # Load weights
        print("Loading weights...")

        # Here is where we deal with the EMA: load either the EMA weights or the normal weights
        biggan_utils.load_weights(
            generator if not (config["use_ema"]) else None,
            None,
            state_dict,
            config["weights_root"],
            experiment_name,
            config["load_weights"],
            generator if config["ema"] and config["use_ema"] else None,
            strict=False,
            load_optim=False,
        )

        if config["G_eval_mode"]:
            print("Putting G in eval mode..")
            generator.eval()
        else:
            print("G is in %s mode..." % ("training" if generator.training else "eval"))

    elif config["model_backbone"] == "stylegan2":
        # StyleGAN2 saves the entire network + weights in a pickle. Load it here.
        network_pkl = os.path.join(
            config["base_root"], config["experiment_name"], "best-network-snapshot.pkl"
        )
        print('Loading networks from "%s"...' % network_pkl)
        with dnnlib.util.open_url(network_pkl) as f:
            generator = legacy.load_network_pkl(f)["G_ema"].to(device)  # type: ignore
    else:
        raise ValueError("Unknown model backbone: %s" % config["model_backbone"])
    return generator, config
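
Usage sketch (not part of inference/utils.py): a minimal, hedged example of how this function might be called. The config values below are placeholders, the paths and experiment name are assumptions, and for the BigGAN backbone most remaining hyperparameters are filled in from the checkpoint's saved config by the override loop above.

# Hypothetical usage sketch; paths and experiment name are placeholders.
import torch

from inference.utils import load_model_inference

config = {
    "model_backbone": "biggan",        # or "stylegan2"
    "weights_root": "weights",         # placeholder path
    "base_root": "experiments",        # placeholder path (used by the stylegan2 branch)
    "experiment_name": "my_experiment",
    "G_eval_mode": True,
    "use_ema": True,
    "ema": True,
    # ... any further keys your setup needs that are excluded from the override loop
}

generator, config = load_model_inference(config, device="cuda")

# The returned generator is ready for sampling; for example, report its size.
n_params = sum(p.numel() for p in generator.parameters())
print("Loaded %s generator with %d parameters." % (config["model_backbone"], n_params))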