in scripts/convert_maskgit_vqgan.py [0:0]
def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Copy the arrays of a flat Flax state mapping into a PyTorch model.

    Each key of ``flax_state`` is a dot-separated path. Before it is looked up
    in ``pt_model.state_dict()`` the path is renamed as follows:

    * a trailing ``kernel`` becomes ``weight`` unless the unchanged key already
      exists in the state dict; 4-D kernels (conv layers) are permuted with
      axes ``(3, 2, 0, 1)`` and every other kernel (linear layers) is
      transposed with ``.T``;
    * otherwise a trailing ``scale`` or ``embedding`` becomes ``weight``;
    * any ``in_proj.weight`` in the resulting name is rewritten to
      ``in_proj_weight``.

    Keys that still have no counterpart are reported as unused; state-dict
    entries that received no Flax array are reported as newly initialized.

    Args:
        pt_model: Model whose ``state_dict()`` and ``load_state_dict()`` are used.
        flax_state: Mapping from dotted parameter names to arrays.

    Returns:
        ``pt_model`` itself, after ``load_state_dict`` has been called on it.

    Raises:
        ImportError: If ``torch`` cannot be imported (an error is logged first).
        ValueError: If a renamed Flax array does not have the shape of the
            state-dict entry it maps to.
    """
    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    pt_model_dict = pt_model.state_dict()
    # Every PyTorch key starts out as "missing" and is struck off once a Flax array lands on it.
    missing_keys = set(pt_model_dict.keys())
    unexpected_keys = []

    for source_key, array in flax_state.items():
        *parents, leaf = source_key.split(".")

        # Translate the Flax parameter name (and layout) to its PyTorch counterpart.
        if leaf == "kernel":
            is_conv_kernel = array.ndim == 4
            # A key that already exists verbatim in the state dict is left untouched.
            if source_key not in pt_model_dict:
                leaf = "weight"
                if is_conv_kernel:
                    # conv layer
                    array = jnp.transpose(array, (3, 2, 0, 1))
                else:
                    # linear layer
                    array = array.T
        elif leaf in ("scale", "embedding"):
            leaf = "weight"

        name_parts = (*parents, leaf)
        target_key = ".".join(name_parts).replace("in_proj.weight", "in_proj_weight")

        if target_key not in pt_model_dict:
            # The PyTorch model has no slot for this weight.
            unexpected_keys.append(target_key)
            continue

        expected_shape = pt_model_dict[target_key].shape
        if array.shape != expected_shape:
            raise ValueError(
                f"Flax checkpoint seems to be incorrect. Weight {name_parts} was expected "
                f"to be of shape {expected_shape}, but is {array.shape}."
            )

        # torch.from_numpy needs a real NumPy array, so convert anything else first.
        if not isinstance(array, np.ndarray):
            array = np.asarray(array)
        pt_model_dict[target_key] = torch.from_numpy(array)
        missing_keys.remove(target_key)

    pt_model.load_state_dict(pt_model_dict)

    model_name = pt_model.__class__.__name__
    # The report below prints the missing keys as a list rather than a set.
    missing_keys = list(missing_keys)

    if unexpected_keys:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {model_name}: {unexpected_keys}."
        )
    else:
        logger.warning(f"All Flax model weights were used when initializing {model_name}.\n")

    if missing_keys:
        logger.warning(
            f"Some weights of {model_name} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        logger.warning(f"All the weights of {model_name} were initialized from the Flax model.")

    return pt_model