def convert_from_v1(hub_dict, resolution=128)

in BigGAN_PyTorch/TFHub/converter.py


import parse  # third-party "parse" package (pip install parse)
import torch


def convert_from_v1(hub_dict, resolution=128):
    """Convert a v1 TFHub-port BigGAN state dict to this repo's naming and layout."""
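    # Map the hub's spectral-norm parameter names onto this repo's:
    # u0 is the power-iteration vector; weight_bar is the hub's unnormalized weight.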
    weightname_dict = {"weight_u": "u0", "weight_bar": "weight", "bias": "bias"}
    convnum_dict = {"conv0": "conv1", "conv1": "conv2", "conv_sc": "conv_sc"}
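    # Index of the generator block holding the attention layer at each resolution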
    attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
    hub2me = {
        "linear.weight": "shared.weight",  # This is actually the shared weight
        # Linear stuff
        "G_linear.module.weight_bar": "linear.weight",
        "G_linear.module.bias": "linear.bias",
        "G_linear.module.weight_u": "linear.u0",
        # Output layer stuff
        "ScaledCrossReplicaBN.weight": "output_layer.0.gain",
        "ScaledCrossReplicaBN.bias": "output_layer.0.bias",
        "ScaledCrossReplicaBN.running_mean": "output_layer.0.stored_mean",
        "ScaledCrossReplicaBN.running_var": "output_layer.0.stored_var",
        "colorize.module.weight_bar": "output_layer.2.weight",
        "colorize.module.bias": "output_layer.2.bias",
        "colorize.module.weight_u": "output_layer.2.u0",
        # Attention stuff
        "attention.gamma": "blocks.%d.1.gamma" % attention_blocknum,
        "attention.theta.module.weight_u": "blocks.%d.1.theta.u0" % attention_blocknum,
        "attention.theta.module.weight_bar": "blocks.%d.1.theta.weight"
        % attention_blocknum,
        "attention.phi.module.weight_u": "blocks.%d.1.phi.u0" % attention_blocknum,
        "attention.phi.module.weight_bar": "blocks.%d.1.phi.weight"
        % attention_blocknum,
        "attention.g.module.weight_u": "blocks.%d.1.g.u0" % attention_blocknum,
        "attention.g.module.weight_bar": "blocks.%d.1.g.weight" % attention_blocknum,
        "attention.o_conv.module.weight_u": "blocks.%d.1.o.u0" % attention_blocknum,
        "attention.o_conv.module.weight_bar": "blocks.%d.1.o.weight"
        % attention_blocknum,
    }

    # Loop over the hub dict and add the per-GBlock entries to the hub2me map
    for name in hub_dict.keys():
        if "GBlock" in name:
            if "HyperBN" not in name:  # it's a conv
                out = parse.parse("GBlock.{:d}.{}.module.{}", name)
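                # e.g. "GBlock.0.conv0.module.weight_bar" -> (0, "conv0", "weight_bar")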
                blocknum, convnum, weightname = out
                if weightname not in weightname_dict:
                    continue  # skip weight_v (and anything else not in the dict)
                out_name = "blocks.%d.0.%s.%s" % (
                    blocknum,
                    convnum_dict[convnum],
                    weightname_dict[weightname],
                )  # Increment conv number by 1
            else:  # it's a HyperBN (conditional batchnorm) param, not a conv
                BNnum = 2 if "HyperBN_1" in name else 1
                if "embed" in name:
                    out = parse.parse("GBlock.{:d}.{}.module.{}", name)
                    blocknum, gamma_or_beta, weightname = out
                    if weightname not in weightname_dict:  # Ignore weight_v
                        continue
                    out_name = "blocks.%d.0.bn%d.%s.%s" % (
                        blocknum,
                        BNnum,
                        "gain" if "gamma" in gamma_or_beta else "bias",
                        weightname_dict[weightname],
                    )
                else:
                    out = parse.parse("GBlock.{:d}.{}.bn.{}", name)
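                    # matches running-stat keys like "GBlock.{n}.HyperBN.bn.running_mean"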
                    blocknum, dummy, mean_or_var = out
                    if "num_batches_tracked" in mean_or_var:
                        continue
                    out_name = "blocks.%d.0.bn%d.%s" % (
                        blocknum,
                        BNnum,
                        "stored_mean" if "mean" in mean_or_var else "stored_var",
                    )
            hub2me[name] = out_name

    # Invert the hub2me map
    me2hub = {hub2me[item]: item for item in hub2me}
    new_dict = {}
    dimz_dict = {128: 20, 256: 20, 512: 16}
    for item in me2hub:
        # Swap the input-dim ordering on batchnorm gain/bias weights to account
        # for this repo's different ordering when concatenating ys and zs: the
        # shared-embedding part (the last 128 dims in the hub layout) is moved
        # to the front, followed by the z chunk.
        if (
            ("bn" in item and "weight" in item)
            and ("gain" in item or "bias" in item)
            and ("output_layer" not in item)
        ):
            new_dict[item] = torch.cat(
                [
                    hub_dict[me2hub[item]][:, -128:],
                    hub_dict[me2hub[item]][:, : dimz_dict[resolution]],
                ],
                1,
            )
        # Reshape the first linear layer's weight, bias, and u0: the hub stores
        # the 4x4x(16*ch) output in HWC order, so permute it to CHW.
        elif item == "linear.weight":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .contiguous()
                .view(4, 4, 96 * 16, -1)
                .permute(2, 0, 1, 3)
                .contiguous()
                .view(-1, dimz_dict[resolution])
            )
        elif item == "linear.bias":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .view(4, 4, 96 * 16)
                .permute(2, 0, 1)
                .contiguous()
                .view(-1)
            )
        elif item == "linear.u0":
            new_dict[item] = (
                hub_dict[me2hub[item]]
                .view(4, 4, 96 * 16)
                .permute(2, 0, 1)
                .contiguous()
                .view(1, -1)
            )
        elif me2hub[item] == "linear.weight":
            # This is the shared class-embedding weight, NOT the first linear
            # layer (the hub key just happens to be named "linear.weight").
            # Transpose it so it can be used as an embedding matrix.
            new_dict[item] = hub_dict[me2hub[item]].t()
        elif "weight_u" in me2hub[item]:  # Unsqueeze u0s
            new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
        else:
            new_dict[item] = hub_dict[me2hub[item]]
    return new_dict
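
A minimal usage sketch (the checkpoint path and the generator G are assumptions
for illustration, not part of this file):

hub_dict = torch.load("biggan_v1_128.pth", map_location="cpu")  # hypothetical path
state_dict = convert_from_v1(hub_dict, resolution=128)
# The converted dict can then be loaded into this repo's Generator, e.g.:
# G.load_state_dict(state_dict)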