in BigGAN_PyTorch/TFHub/converter.py [0:0]
def convert_from_v1(hub_dict, resolution=128):
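    """Convert a generator state dict from the v1 TFHub naming scheme to the
    parameter names and tensor layouts expected by this repo's Generator.

    `hub_dict` is the v1-style state dict; `resolution` (128, 256, or 512)
    selects the resolution-specific block indices and z-chunk sizes.
    """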
weightname_dict = {"weight_u": "u0", "weight_bar": "weight", "bias": "bias"}
convnum_dict = {"conv0": "conv1", "conv1": "conv2", "conv_sc": "conv_sc"}
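    # Index (within G's blocks) of the block that carries the self-attention layer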
attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
hub2me = {
"linear.weight": "shared.weight", # This is actually the shared weight
# Linear stuff
"G_linear.module.weight_bar": "linear.weight",
"G_linear.module.bias": "linear.bias",
"G_linear.module.weight_u": "linear.u0",
# output layer stuff
"ScaledCrossReplicaBN.weight": "output_layer.0.gain",
"ScaledCrossReplicaBN.bias": "output_layer.0.bias",
"ScaledCrossReplicaBN.running_mean": "output_layer.0.stored_mean",
"ScaledCrossReplicaBN.running_var": "output_layer.0.stored_var",
"colorize.module.weight_bar": "output_layer.2.weight",
"colorize.module.bias": "output_layer.2.bias",
"colorize.module.weight_u": "output_layer.2.u0",
# Attention stuff
"attention.gamma": "blocks.%d.1.gamma" % attention_blocknum,
"attention.theta.module.weight_u": "blocks.%d.1.theta.u0" % attention_blocknum,
"attention.theta.module.weight_bar": "blocks.%d.1.theta.weight"
% attention_blocknum,
"attention.phi.module.weight_u": "blocks.%d.1.phi.u0" % attention_blocknum,
"attention.phi.module.weight_bar": "blocks.%d.1.phi.weight"
% attention_blocknum,
"attention.g.module.weight_u": "blocks.%d.1.g.u0" % attention_blocknum,
"attention.g.module.weight_bar": "blocks.%d.1.g.weight" % attention_blocknum,
"attention.o_conv.module.weight_u": "blocks.%d.1.o.u0" % attention_blocknum,
"attention.o_conv.module.weight_bar": "blocks.%d.1.o.weight"
% attention_blocknum,
}
    # Loop over the hub dict and extend the hub2me map with an entry per GBlock parameter
for name in hub_dict.keys():
if "GBlock" in name:
if "HyperBN" not in name: # it's a conv
out = parse.parse("GBlock.{:d}.{}.module.{}", name)
blocknum, convnum, weightname = out
if weightname not in weightname_dict:
                    continue  # Skip the spectral-norm weight_v buffer
out_name = "blocks.%d.0.%s.%s" % (
blocknum,
convnum_dict[convnum],
weightname_dict[weightname],
) # Increment conv number by 1
            else:  # HyperBN (conditional batchnorm), not a conv
BNnum = 2 if "HyperBN_1" in name else 1
if "embed" in name:
out = parse.parse("GBlock.{:d}.{}.module.{}", name)
blocknum, gamma_or_beta, weightname = out
if weightname not in weightname_dict: # Ignore weight_v
continue
out_name = "blocks.%d.0.bn%d.%s.%s" % (
blocknum,
BNnum,
"gain" if "gamma" in gamma_or_beta else "bias",
weightname_dict[weightname],
)
else:
out = parse.parse("GBlock.{:d}.{}.bn.{}", name)
blocknum, dummy, mean_or_var = out
if "num_batches_tracked" in mean_or_var:
continue
out_name = "blocks.%d.0.bn%d.%s" % (
blocknum,
BNnum,
"stored_mean" if "mean" in mean_or_var else "stored_var",
)
hub2me[name] = out_name
# Invert the hub2me map
me2hub = {hub2me[item]: item for item in hub2me}
new_dict = {}
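    # Per-block z chunk size at each resolution (the hierarchical z is split into equal chunks)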
dimz_dict = {128: 20, 256: 20, 512: 16}
for item in me2hub:
        # Swap the input-dim ordering of the conditional batchnorm gain/bias
        # weights to account for the different order in which ys and zs are
        # concatenated in this repo.
if (
("bn" in item and "weight" in item)
and ("gain" in item or "bias" in item)
and ("output_layer" not in item)
):
new_dict[item] = torch.cat(
[
hub_dict[me2hub[item]][:, -128:],
hub_dict[me2hub[item]][:, : dimz_dict[resolution]],
],
1,
)
# Reshape the first linear weight, bias, and u0
elif item == "linear.weight":
new_dict[item] = (
hub_dict[me2hub[item]]
.contiguous()
.view(4, 4, 96 * 16, -1)
.permute(2, 0, 1, 3)
.contiguous()
.view(-1, dimz_dict[resolution])
)
elif item == "linear.bias":
new_dict[item] = (
hub_dict[me2hub[item]]
.view(4, 4, 96 * 16)
.permute(2, 0, 1)
.contiguous()
.view(-1)
)
elif item == "linear.u0":
new_dict[item] = (
hub_dict[me2hub[item]]
.view(4, 4, 96 * 16)
.permute(2, 0, 1)
.contiguous()
.view(1, -1)
)
elif (
me2hub[item] == "linear.weight"
): # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
# Transpose shared weight so that it's an embedding
new_dict[item] = hub_dict[me2hub[item]].t()
elif "weight_u" in me2hub[item]: # Unsqueeze u0s
new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
else:
new_dict[item] = hub_dict[me2hub[item]]
return new_dict
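

# A minimal usage sketch (hypothetical filenames; assumes the v1 generator
# weights were already saved as a PyTorch state dict, e.g. via the TFHub port):
#
#   hub_dict = torch.load("biggan_v1_G_128.pth", map_location="cpu")
#   converted = convert_from_v1(hub_dict, resolution=128)
#   torch.save(converted, "G_128_converted.pth")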