in scripts/convert_maskgit_vqgan.py [0:0]
def rename_flax_dict(params):
    """Rename flattened Flax VQGAN parameter keys to a ``down``/``mid``/``up`` layout.

    Only the keys are rewritten; the values are moved untouched. ``params`` is
    mutated in place and the same object is returned.

    NOTE(review): the target names (``conv_in``, ``nin_shortcut``,
    ``upsample_conv``, ...) presumably match the destination PyTorch module
    attributes -- confirm against the model definition.

    Args:
        params: Mutable mapping from parameter path to value. A key is either
            a sequence of string path components, e.g.
            ``("encoder", "Conv_0", "kernel")``, or an already dot-joined
            string such as ``"encoder.Conv_0.kernel"``.

    Returns:
        The same ``params`` object, keyed by dot-joined, renamed paths.

    Raises:
        KeyError: If ``decoder.Conv_1`` .. ``decoder.Conv_4`` (``kernel`` and
            ``bias``) or ``quantizer.codebook`` is missing, or if a residual
            block index falls outside the ten blocks covered by the
            level/position table.
    """
    # Sub-layer renames applied inside the two "mid" residual blocks.
    mid_renames = (
        ("Conv_0", "conv1"),
        ("Conv_1", "conv2"),
        ("GroupNorm_0", "norm1"),
        ("GroupNorm_1", "norm2"),
    )
    # Down/up blocks additionally rename their third conv; the mid passes
    # deliberately leave any "Conv_2" alone.
    block_renames = mid_renames + (("Conv_2", "nin_shortcut"),)

    # Flax block number -> (level, position in level), two blocks per level.
    # Kept as a dict so an unexpected block number still raises KeyError.
    block_map = {i: divmod(i, 2) for i in range(10)}

    def move(old_key, new_key):
        """Re-key one entry, keeping its value."""
        params[new_key] = params.pop(old_key)

    def rename_block(key, flax_name, target, renames):
        """Swap the block name for ``target`` and rename the sub-layers."""
        new_key = key.replace(flax_name, target)
        for old, new in renames:
            new_key = new_key.replace(old, new)
        move(key, new_key)

    def resblock_keys(prefix):
        """Snapshot the keys still carrying a Flax ResBlock name under ``prefix``."""
        marker = f"{prefix}.ResBlock"
        return [key for key in params if marker in key]

    # Flatten component-sequence keys to dotted strings. A key that is already
    # a string is left as is: joining it would split it into characters.
    for key in list(params):
        if not isinstance(key, str):
            move(key, ".".join(key))

    # Encoder: the last two residual blocks (10 and 11) become mid.0 / mid.1.
    for key in resblock_keys("encoder"):
        if "ResBlock_10" in key:
            rename_block(key, "ResBlock_10", "mid.0", mid_renames)
        elif "ResBlock_11" in key:
            rename_block(key, "ResBlock_11", "mid.1", mid_renames)

    # Encoder: the remaining blocks map onto down.<level>.block.<position>.
    for key in resblock_keys("encoder"):
        name = key.split(".")[1]
        _, idx = name.split("_")
        level, position = block_map[int(idx)]
        rename_block(key, name, f"down.{level}.block.{position}", block_renames)

    # Decoder: the first two residual blocks (0 and 1) become mid.0 / mid.1.
    for key in resblock_keys("decoder"):
        if "ResBlock_0" in key:
            rename_block(key, "ResBlock_0", "mid.0", mid_renames)
        elif "ResBlock_1." in key:
            # The trailing dot keeps ResBlock_10 / ResBlock_11 from matching.
            rename_block(key, "ResBlock_1", "mid.1", mid_renames)

    # Decoder: the remaining blocks start at index 2 and their levels are
    # numbered in reverse (Flax level 0 -> up.4, ..., level 4 -> up.0).
    for key in resblock_keys("decoder"):
        name = key.split(".")[1]
        _, idx = name.split("_")
        level, position = block_map[int(idx) - 2]
        rename_block(key, name, f"up.{4 - level}.block.{position}", block_renames)

    # Decoder Conv_1..Conv_4 become the upsample convs of up.4..up.1.
    for i in range(1, 5):
        for leaf in ("kernel", "bias"):
            move(
                f"decoder.Conv_{i}.{leaf}",
                f"decoder.up.{5 - i}.upsample_conv.{leaf}",
            )

    # What is left are the input/output convs and the final norm. Residual
    # block and upsample keys were lowercased above, so they no longer match.
    # decoder.Conv_1..4 are already gone, so a remaining Conv_1 is presumably
    # the encoder's output conv -- TODO confirm against the Flax model.
    for key in list(params):
        if "Conv_" in key:
            new_key = (
                key.replace("Conv_0", "conv_in")
                .replace("Conv_5", "conv_out")
                .replace("Conv_1", "conv_out")
            )
            move(key, new_key)
        elif "GroupNorm" in key:
            move(key, key.replace("GroupNorm_0", "norm_out"))

    move("quantizer.codebook", "quantize.embedding.embedding")

    return params