in inference/utils.py [0:0]
# Configuration keys whose caller-supplied value (e.g. from the command line)
# is kept instead of being replaced by the value stored in the checkpoint.
_CALLER_CONTROLLED_KEYS = frozenset(
    (
        "base_root",
        "data_root",
        "load_weights",
        "batch_size",
        "num_workers",
        "weights_root",
        "logs_root",
        "samples_root",
        "eval_reference_set",
        "eval_instance_set",
        "which_dataset",
        "seed",
        "eval_prdc",
        "use_balanced_sampler",
        "custom_distrib",
        "longtail_temperature",
        "longtail_gen",
        "num_inception_images",
        "sample_num_npz",
        "load_in_mem",
        "split",
        "z_var",
        "kmeans_subsampled",
        "filter_hd",
        "n_subsampled_data",
        "feature_augmentation",
    )
)


def _select_best_checkpoint(config):
    """Return the checkpoint suffix ("best0" or "best1") with the lowest stored FID.

    Arguments
    ---------
    config: dict
        Dictionary with configuration parameters. Reads ``weights_root`` and
        ``experiment_name`` to locate the checkpoint files.
    Returns
    -------
    best_name: str
        Suffix of the checkpoint with the lowest ``best_FID``, or an empty
        string when no candidate checkpoint could be read.
    """
    root = "/".join([config["weights_root"], config["experiment_name"]])
    best_fid = 1e5
    best_name_final = ""
    for name_best in ["best0", "best1"]:
        checkpoint_path = "%s/%s.pth" % (
            root,
            biggan_utils.join_strings("_", ["state_dict", name_best]),
        )
        try:
            # Only the stored FID is read here, so keep everything on the CPU
            # rather than restoring tensors to the device they were saved from.
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            state_dict_loaded = torch.load(checkpoint_path, map_location="cpu")
            candidate_fid = state_dict_loaded["best_FID"]
        except FileNotFoundError:
            print("Checkpoint with name ", name_best, " not in folder.")
            continue
        except Exception as err:
            # Best effort: an unreadable candidate is skipped, but the actual
            # cause is reported instead of being mislabelled as a missing file.
            print("Checkpoint with name ", name_best, " could not be read: ", err)
            continue
        print("For name best ", name_best, " we have an FID: ", candidate_fid)
        if candidate_fid < best_fid:
            best_fid = candidate_fid
            best_name_final = name_best
    return best_name_final


def load_model_inference(config, device="cuda"):
    """Load the generator network for inference and override the configuration.

    For the "biggan" backbone, the checkpoint ("best0" or "best1") with the
    lowest stored FID is selected, the configuration is updated in place from
    the checkpoint and the generator is built from it. For the "stylegan2"
    backbone, the ``G_ema`` network is read from
    ``<base_root>/<experiment_name>/best-network-snapshot.pkl``.

    Arguments
    ---------
    config: dict
        Dictionary with configuration parameters. It is modified in place.
    device: str, optional
        Device name
    Returns
    -------
    generator: torch.nn.module
        Generator network.
    config: dict
        Overwritten configuration dictionary from the model checkpoint if it exists.
    Raises
    ------
    ValueError
        If ``config["model_backbone"]`` is neither "biggan" nor "stylegan2".
    """
    backbone = config["model_backbone"]
    if backbone == "biggan":
        # Select checkpoint name according to best FID in checkpoint.
        # An empty name means that no candidate could be read; it is passed on
        # to load_weights unchanged (presumably selecting the unsuffixed
        # checkpoint -- confirm against biggan_utils.load_weights).
        best_name_final = _select_best_checkpoint(config)
        config["load_weights"] = best_name_final
        print("Final name selected is ", best_name_final)

        # Prepare state dict, which holds things like epoch # and itr #
        state_dict = {
            "itr": 0,
            "epoch": 0,
            "save_num": 0,
            "save_best_num": 0,
            "best_IS": 0,
            "best_FID": 999999,
            "config": config,
        }
        # First pass without any network: used to obtain the configuration the
        # model was trained with (presumably load_weights fills state_dict in
        # place -- confirm against biggan_utils.load_weights).
        biggan_utils.load_weights(
            None,
            None,
            state_dict,
            config["weights_root"],
            config["experiment_name"],
            config["load_weights"],
            None,
            strict=False,
            load_optim=False,
            eval=True,
        )
        # Take every parameter from the trained model's config, except the ones
        # which we might want to overwrite from the command line.
        for item in state_dict["config"]:
            if item in _CALLER_CONTROLLED_KEYS:
                continue
            if item == "experiment_name" and config["experiment_name"] != "":
                continue  # Allows overriding the experiment name at test time
            config[item] = state_dict["config"][item]

        # No data augmentation during testing
        config["feature_augmentation"] = False
        config["hflips"] = False
        config["DA"] = False

        experiment_name = (
            config["experiment_name"]
            if config["experiment_name"]
            else biggan_utils.name_from_config(config)
        )
        print("Experiment name is %s" % experiment_name)

        # Next, build the model
        generator = BigGANModel.Generator(**config).to(device)

        # Load weights
        print("Loading weights...")
        # Here is where we deal with the ema--load ema weights or load normal weights.
        # NOTE(review): when use_ema is set but ema is not, the generator is
        # passed in neither slot -- confirm that this combination cannot occur.
        biggan_utils.load_weights(
            generator if not (config["use_ema"]) else None,
            None,
            state_dict,
            config["weights_root"],
            experiment_name,
            config["load_weights"],
            generator if config["ema"] and config["use_ema"] else None,
            strict=False,
            load_optim=False,
        )

        if config["G_eval_mode"]:
            print("Putting G in eval mode..")
            generator.eval()
        else:
            print(
                "G is in %s mode..." % ("training" if generator.training else "eval")
            )
    elif backbone == "stylegan2":
        # StyleGAN2 saves the entire network + weights in a pickle. Load it here.
        # NOTE: this unpickles the snapshot; only load files from trusted sources.
        network_pkl = os.path.join(
            config["base_root"], config["experiment_name"], "best-network-snapshot.pkl"
        )
        print('Loading networks from "%s"...' % network_pkl)
        with dnnlib.util.open_url(network_pkl) as f:
            generator = legacy.load_network_pkl(f)["G_ema"].to(device)  # type: ignore
    else:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(
            'Unsupported model_backbone %r; expected "biggan" or "stylegan2".'
            % (backbone,)
        )
    return generator, config