def convert_tf_discriminator()

in stylegan2_ada_pytorch/legacy.py [0:0]


def convert_tf_discriminator(tf_D):
    """Convert a legacy TensorFlow discriminator into a PyTorch Discriminator.

    Translates the TF network's ``static_kwargs`` into constructor kwargs for
    ``training.networks.Discriminator``, builds that module in eval mode with
    gradients disabled, and copies the TF parameters into it via
    ``_populate_module_params``.

    Args:
        tf_D: The unpickled TensorFlow discriminator network object. It must
            expose ``version`` and ``static_kwargs`` and be accepted by
            ``_collect_tf_params``.

    Returns:
        The populated ``networks.Discriminator`` module.

    Raises:
        ValueError: If ``tf_D.version < 4``, or if ``static_kwargs`` contains
            a kwarg this converter does not recognize (the alphabetically
            first unknown name is reported).
    """
    if tf_D.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs. Every name looked up through kwarg() is recorded so that
    # leftover (unrecognized) TF kwargs can be detected after conversion.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg("label_size", 0),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        architecture=kwarg("architecture", "resnet"),
        channel_base=kwarg("fmap_base", 16384) * 2,
        channel_max=kwarg("fmap_max", 512),
        num_fp16_res=kwarg("num_fp16_res", 0),
        conv_clamp=kwarg("conv_clamp", None),
        cmap_dim=kwarg("mapping_fmaps", None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg("nonlinearity", "lrelu"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            freeze_layers=kwarg("freeze_layers", 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 0),
            embed_features=kwarg("mapping_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg("mbstd_group_size", None),
            mbstd_num_channels=kwarg("mbstd_num_features", 1),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs. "structure" is marked as known even though its
    # value is not used. Sorting makes the reported name deterministic; plain
    # set iteration order can change between runs under hash randomization.
    kwarg("structure")
    unknown_kwargs = sorted(set(tf_kwargs.keys()) - known_kwargs)
    if unknown_kwargs:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params. Any "FromRGB_lod<k>/..." entries are re-keyed by
    # resolution to "<r>x<r>/FromRGB/..." (r = img_resolution / 2**k) so they
    # match the names used below, and their presence switches the PyTorch
    # architecture to "orig".
    tf_params = _collect_tf_params(tf_D)
    fromrgb_lod = re.compile(r"FromRGB_lod(\d+)/(.*)")
    # Iterate over a snapshot because new keys are inserted during the loop.
    for name, value in list(tf_params.items()):
        match = fromrgb_lod.fullmatch(name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/FromRGB/{match.group(2)}"] = value
            kwargs.architecture = "orig"

    # Convert params.
    from training import networks

    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # Block conv index i selects the TF layer name suffix: conv0 -> "Conv0",
    # conv1 -> "Conv1_down".
    conv_suffix = ("", "_down")
    # Conv weights use transpose(3, 2, 0, 1) and dense weights a plain
    # transpose(); presumably TF HWIO -> PyTorch OIHW and (in, out) ->
    # (out, in) respectively -- NOTE(review): confirm against TF layout.
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        D,
        r"b(\d+)\.fromrgb\.weight",
        lambda r: tf_params[f"{r}x{r}/FromRGB/weight"].transpose(3, 2, 0, 1),
        r"b(\d+)\.fromrgb\.bias",
        lambda r: tf_params[f"{r}x{r}/FromRGB/bias"],
        r"b(\d+)\.conv(\d+)\.weight",
        lambda r, i: tf_params[
            f"{r}x{r}/Conv{i}{conv_suffix[int(i)]}/weight"
        ].transpose(3, 2, 0, 1),
        r"b(\d+)\.conv(\d+)\.bias",
        lambda r, i: tf_params[f"{r}x{r}/Conv{i}{conv_suffix[int(i)]}/bias"],
        r"b(\d+)\.skip\.weight",
        lambda r: tf_params[f"{r}x{r}/Skip/weight"].transpose(3, 2, 0, 1),
        r"mapping\.embed\.weight",
        lambda: tf_params["LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params["LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"Mapping{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"Mapping{i}/bias"],
        r"b4\.conv\.weight",
        lambda: tf_params["4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"b4\.conv\.bias",
        lambda: tf_params["4x4/Conv/bias"],
        r"b4\.fc\.weight",
        lambda: tf_params["4x4/Dense0/weight"].transpose(),
        r"b4\.fc\.bias",
        lambda: tf_params["4x4/Dense0/bias"],
        r"b4\.out\.weight",
        lambda: tf_params["Output/weight"].transpose(),
        r"b4\.out\.bias",
        lambda: tf_params["Output/bias"],
        # NOTE(review): a None value presumably tells _populate_module_params
        # to leave matching resample_filter buffers untouched -- confirm.
        r".*\.resample_filter",
        None,
    )
    return D