in models/swin_transformer_3d.py [0:0]
def load_and_interpolate_3d_weights(self, logger):
    checkpoint = torch.load(self.pretrained, map_location=torch.device("cpu"))
    assert self.pretrained3d is not None and self.pretrained2d is False
    if "classy_state_dict" in checkpoint:
        # checkpoints trained in omnivore
        state_dict = checkpoint["classy_state_dict"][self.pretrained_model_key][
            "model"
        ]["trunk"]
    else:
        # checkpoints trained outside omnivore
        state_dict = checkpoint["model"]

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [
        k for k in state_dict.keys() if "relative_position_index" in k
    ]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    # bicubic-interpolate relative_position_bias_table if its shape does not match
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "relative_position_bias_table" in k
    ]
    pretrained_window_size = self.pretrained3d
    T1 = 2 * pretrained_window_size[0] - 1
    S11 = 2 * pretrained_window_size[1] - 1
    S12 = 2 * pretrained_window_size[2] - 1
    assert (
        pretrained_window_size[0] == self.window_size[0]
    ), "Interpolating along time is not supported"
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = self.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        L2 = (
            (2 * self.window_size[0] - 1)
            * (2 * self.window_size[1] - 1)
            * (2 * self.window_size[2] - 1)
        )
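        # L2 is recomputed as the expected table length for the current 3D
        # window: the product of the number of relative offsets
        # (2 * window - 1) along each of the three axes.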
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing")
        else:
            if L1 != L2:
                pretrained_bias = relative_position_bias_table_pretrained.view(
                    T1, S11, S12, nH1
                )
                pretrained_bias = pretrained_bias.permute(0, 3, 1, 2)
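                # F.interpolate with mode="bicubic" expects a 4D (N, C, H, W)
                # tensor: the time axis acts as the batch dim and the heads as
                # channels, so only the two spatial axes are resized.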
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    pretrained_bias,
                    size=(
                        2 * self.window_size[1] - 1,
                        2 * self.window_size[2] - 1,
                    ),
                    mode="bicubic",
                )
                relative_position_bias_table_pretrained_resized = (
                    relative_position_bias_table_pretrained_resized.permute(0, 2, 3, 1)
                )
                relative_position_bias_table_pretrained = (
                    relative_position_bias_table_pretrained_resized.reshape(L2, nH2)
                )
        state_dict[k] = relative_position_bias_table_pretrained
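
    # strict=False is needed because relative_position_index and attn_mask
    # were dropped above and are re-initialized by the model, so those keys
    # are allowed to be missing from the checkpoint.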
    msg = self.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    logger.info(f"=> loaded successfully '{self.pretrained}'")
    del checkpoint
    torch.cuda.empty_cache()
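
# Usage sketch (hypothetical, not from the original file): assuming the trunk
# exposes `pretrained` (checkpoint path), `pretrained2d`, `pretrained3d`
# (window size of the pretrained 3D checkpoint) and `window_size` attributes
# as used above, and that the enclosing class is the SwinTransformer3D trunk
# defined in this file:
#
#   model = SwinTransformer3D(
#       pretrained="checkpoint.pth",
#       pretrained2d=False,
#       pretrained3d=(8, 7, 7),      # temporal window must match window_size[0]
#       window_size=(8, 14, 14),     # only spatial dims may differ
#   )
#   model.load_and_interpolate_3d_weights(logging.getLogger(__name__))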