in stylegan2_ada_pytorch/legacy.py [0:0]
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow StyleGAN2 generator into a PyTorch ``Generator``.

    The TF network's ``static_kwargs`` are translated into constructor kwargs
    for ``training.networks.Generator``. That generator is then instantiated,
    and the TF parameters (gathered by ``_collect_tf_params``) are copied into
    it via ``_populate_module_params``.

    Args:
        tf_G: Unpickled TensorFlow generator object. This function reads only
            ``tf_G.version`` and ``tf_G.static_kwargs`` directly; everything
            else is read by ``_collect_tf_params``.

    Returns:
        A ``networks.Generator`` put in eval mode with ``requires_grad_(False)``.

    Raises:
        ValueError: If ``tf_G.version`` is below 4, or if ``static_kwargs``
            contains a key this converter does not recognise.
    """
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs. Every name looked up through kwarg() is recorded, so any
    # TF kwarg that is never looked up can be rejected as unknown below.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        """Return ``tf_kwargs[tf_name]`` (or ``default``), mapping ``None`` to ``none``."""
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            # An explicit None in the TF kwargs becomes 1 (presumably meaning
            # "do not track w_avg" -- confirm against networks.MappingNetwork).
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            # NOTE(review): the factor 2 presumably reconciles the TF fmap_base
            # convention with PyTorch's channel_base -- confirm in networks.py.
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # These TF kwargs have no constructor counterpart. They are queried only so
    # that they count as known and do not trip the unknown-kwarg check.
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    # Sorted so the reported kwarg is deterministic; set iteration order is not.
    unknown_kwargs = sorted(set(tf_kwargs.keys()) - known_kwargs)
    if unknown_kwargs:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    # Re-key "ToRGB_lod<k>/..." params as "<r>x<r>/ToRGB/..." with
    # r = img_resolution // 2**k. Their presence switches synthesis to the
    # "orig" architecture. Iterate over a snapshot because we insert into
    # tf_params inside the loop.
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            # The synthesis config lives under ``synthesis_kwargs``. The former
            # ``kwargs.synthesis.kwargs`` path named a key that is never created
            # above, so the lookup failed instead of setting the architecture.
            kwargs.synthesis_kwargs.architecture = "orig"

    # Convert params.
    from training import networks

    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # The call below passes (regex, getter) pairs. A getter receives the regex's
    # capture groups (layer index i or block resolution r) and returns the
    # matching TF array.
    # - .transpose() on dense / affine weights swaps the two matrix axes.
    # - .transpose(3, 2, 0, 1) on conv kernels assumes a TF [kh, kw, in, out]
    #   layout going to torch [out, in, kh, kw] -- TODO confirm.
    # - "+ 1" on mod_bias presumably folds a constant +1 that the TF graph adds
    #   to the style vector into the PyTorch bias -- confirm against the TF code.
    # - [::-1, ::-1] flips Conv0_up / Skip kernels spatially, presumably because
    #   the TF ops are transposed convolutions.
    # - The noise buffer index is derived from log2(r): conv0 uses 2*log2(r)-5,
    #   conv1 uses 2*log2(r)-4, and the 4x4 block uses noise0.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r"mapping\.w_avg",
        lambda: tf_params["dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params["mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params["mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",
        lambda: tf_params["synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",
        lambda: tf_params["synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",
        lambda: tf_params["synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",
        lambda: tf_params["synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",
        lambda: tf_params["synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",
        lambda: tf_params["synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",
        lambda: tf_params["synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        r"synthesis\.b(\d+)\.conv0\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0],
        r"synthesis\.b(\d+)\.conv0\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0],
        r"synthesis\.b(\d+)\.conv1\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        # None getter: presumably leaves the resample_filter buffers as
        # initialised by the constructor -- confirm in _populate_module_params.
        r".*\.resample_filter",
        None,
    )
    return G