in stylegan2_ada_pytorch/legacy.py [0:0]
def convert_tf_discriminator(tf_D):
    """Convert a TensorFlow StyleGAN2(-ADA) discriminator into a PyTorch one.

    Translates the TF network's ``static_kwargs`` into keyword arguments for
    ``training.networks.Discriminator``, builds that module, and copies the
    TF parameters (gathered by ``_collect_tf_params``) into it via
    ``_populate_module_params``.

    Args:
        tf_D: Unpickled TensorFlow network object. This function reads its
            ``version`` and ``static_kwargs`` attributes directly and passes
            it to ``_collect_tf_params``.

    Returns:
        A ``networks.Discriminator`` switched to eval mode, with
        ``requires_grad`` disabled on its parameters.

    Raises:
        ValueError: If ``tf_D.version`` is below 4, or if ``tf_D.static_kwargs``
            contains a kwarg name this converter does not recognize.
    """
    if tf_D.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs. Every TF kwarg name read through kwarg() is recorded in
    # known_kwargs so unrecognized names can be rejected below rather than
    # silently ignored.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs (TF name -> PyTorch Discriminator argument).
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg("label_size", 0),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        architecture=kwarg("architecture", "resnet"),
        # NOTE(review): the x2 presumably reconciles differing channel-count
        # conventions between TF fmap_base and PyTorch channel_base — confirm.
        channel_base=kwarg("fmap_base", 16384) * 2,
        channel_max=kwarg("fmap_max", 512),
        num_fp16_res=kwarg("num_fp16_res", 0),
        conv_clamp=kwarg("conv_clamp", None),
        cmap_dim=kwarg("mapping_fmaps", None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg("nonlinearity", "lrelu"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            freeze_layers=kwarg("freeze_layers", 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 0),
            embed_features=kwarg("mapping_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg("mbstd_group_size", None),
            mbstd_num_channels=kwarg("mbstd_num_features", 1),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs. 'structure' is accepted but its value unused.
    kwarg("structure")
    # Sort so the reported name is deterministic: iteration order of a set of
    # strings depends on hash randomization (PYTHONHASHSEED).
    unknown_kwargs = sorted(set(tf_kwargs.keys()) - known_kwargs, key=str)
    if unknown_kwargs:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params. Input layers named 'FromRGB_lod{N}/...' are additionally
    # keyed as '{r}x{r}/FromRGB/...' with r = img_resolution // 2**N, and their
    # presence switches the architecture to 'orig'. Iterate over a snapshot
    # because new keys are inserted into tf_params inside the loop.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"FromRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/FromRGB/{match.group(2)}"] = value
            kwargs.architecture = "orig"

    # Convert params.
    from training import networks

    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # Each regex over PyTorch parameter names is paired with a callable that
    # receives the regex groups and returns the matching TF value.
    # NOTE(review): transpose(3, 2, 0, 1) presumably maps TF conv kernels
    # [kh, kw, in, out] to PyTorch [out, in, kh, kw], and .transpose() maps
    # dense [in, out] to [out, in] — confirm against _collect_tf_params.
    # In block b{r}, conv0 reads 'Conv0' and conv1 reads 'Conv1_down'.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        D,
        r"b(\d+)\.fromrgb\.weight",
        lambda r: tf_params[f"{r}x{r}/FromRGB/weight"].transpose(3, 2, 0, 1),
        r"b(\d+)\.fromrgb\.bias",
        lambda r: tf_params[f"{r}x{r}/FromRGB/bias"],
        r"b(\d+)\.conv(\d+)\.weight",
        lambda r, i: tf_params[
            f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'
        ].transpose(3, 2, 0, 1),
        r"b(\d+)\.conv(\d+)\.bias",
        lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r"b(\d+)\.skip\.weight",
        lambda r: tf_params[f"{r}x{r}/Skip/weight"].transpose(3, 2, 0, 1),
        r"mapping\.embed\.weight",
        lambda: tf_params["LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params["LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"Mapping{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"Mapping{i}/bias"],
        r"b4\.conv\.weight",
        lambda: tf_params["4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"b4\.conv\.bias",
        lambda: tf_params["4x4/Conv/bias"],
        r"b4\.fc\.weight",
        lambda: tf_params["4x4/Dense0/weight"].transpose(),
        r"b4\.fc\.bias",
        lambda: tf_params["4x4/Dense0/bias"],
        r"b4\.out\.weight",
        lambda: tf_params["Output/weight"].transpose(),
        r"b4\.out\.bias",
        lambda: tf_params["Output/bias"],
        # NOTE(review): None presumably tells _populate_module_params to
        # leave resample_filter buffers as constructed — confirm.
        r".*\.resample_filter",
        None,
    )
    return D