in models/swin_transformer_3d.py [0:0]
def inflate_weights(self, logger):
"""Inflate the swin2d parameters to swin3d.
The differences between swin3d and swin2d mainly lie in an extra
axis. To utilize the pretrained parameters in 2d model,
the weight of swin2d models should be inflated to fit in the shapes of
the 3d counterpart.
Args:
logger (logging.Logger): The logger used to print
debugging infomation.
"""
checkpoint = torch.load(self.pretrained, map_location=torch.device("cpu"))
if "classy_state_dict" in checkpoint:
# checkpoints trained in omnivore
state_dict = checkpoint["classy_state_dict"][self.pretrained_model_key][
"model"
]["trunk"]
else:
# checkpoints trained outside omnivore
state_dict = checkpoint["model"]
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if "relative_position_index" in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
for k in attn_mask_keys:
del state_dict[k]
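# Inflate the patch embedding kernel: add a temporal axis, repeat it
# patch_size[0] times, and divide by patch_size[0] so activations keep
# the same scale.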
if state_dict["patch_embed.proj.weight"].ndim == 4:
state_dict["patch_embed.proj.weight"] = state_dict[
"patch_embed.proj.weight"
].unsqueeze(2)
state_dict["patch_embed.proj.weight"] = (
state_dict["patch_embed.proj.weight"].repeat(1, 1, self.patch_size[0], 1, 1)
/ self.patch_size[0]
)
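# The depth patch embedding, if present, only needs the extra temporal
# axis; it is not repeated along time.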
if (
"depth_patch_embed.proj.weight" in state_dict
and state_dict["depth_patch_embed.proj.weight"].ndim == 4
):
state_dict["depth_patch_embed.proj.weight"] = state_dict[
"depth_patch_embed.proj.weight"
].unsqueeze(2)
# bicubic interpolate relative_position_bias_table if not match
relative_position_bias_table_keys = [
k for k in state_dict.keys() if "relative_position_bias_table" in k
]
for k in relative_position_bias_table_keys:
relative_position_bias_table_pretrained = state_dict[k]
relative_position_bias_table_current = self.state_dict()[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
L2 = (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
wd = self.window_size[0]
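# L2 is recomputed as the spatial-only table length of the current model,
# (2*Wh-1) * (2*Ww-1); the temporal extent wd is handled separately by
# tiling the resized 2d table 2*wd-1 times below.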
if nH1 != nH2:
logger.warning(f"Error in loading {k}, passing")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(
1, nH1, S1, S1
),
size=(2 * self.window_size[1] - 1, 2 * self.window_size[2] - 1),
mode="bicubic",
)
relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.view(
nH2, L2
).permute(
1, 0
)
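# Tile the (resized) 2d bias table along the temporal relative-position
# axis (2*wd-1 offsets) to obtain the 3d table.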
state_dict[k] = relative_position_bias_table_pretrained.repeat(
2 * wd - 1, 1
)
msg = self.load_state_dict(state_dict, strict=False)
logger.info(msg)
logger.info(f"=> loaded successfully '{self.pretrained}'")
del checkpoint
torch.cuda.empty_cache()
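# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal check of
# why the inflated kernel is divided by patch_size[0]. Assuming a toy 2d
# kernel and a temporally constant input clip, repeating the kernel T times
# along a new temporal axis and dividing by T reproduces the 2d response.
# All names below (w2d, w3d, T, clip, ...) are hypothetical and for
# illustration only.
#
# import torch
# import torch.nn.functional as F
#
# w2d = torch.randn(8, 3, 4, 4)                       # (C_out, C_in, kH, kW)
# T = 2                                               # temporal patch size
# w3d = w2d.unsqueeze(2).repeat(1, 1, T, 1, 1) / T    # (C_out, C_in, T, kH, kW)
#
# x = torch.randn(1, 3, 32, 32)                       # a single 2d frame
# clip = x.unsqueeze(2).repeat(1, 1, T, 1, 1)         # temporally constant clip
#
# out2d = F.conv2d(x, w2d, stride=4)
# out3d = F.conv3d(clip, w3d, stride=(T, 4, 4))
# assert torch.allclose(out2d, out3d.squeeze(2), atol=1e-4)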