def load_and_interpolate_3d_weights()

in models/swin_transformer_3d.py


    def load_and_interpolate_3d_weights(self, logger):
        """Load a 3D-pretrained checkpoint and bicubic-interpolate its
        relative position bias tables to the current spatial window size."""
        # validate the configuration before paying for the checkpoint load
        assert self.pretrained3d is not None and self.pretrained2d is False
        checkpoint = torch.load(self.pretrained, map_location=torch.device("cpu"))

        if "classy_state_dict" in checkpoint:
            # checkpoints trained in omnivore
            state_dict = checkpoint["classy_state_dict"][self.pretrained_model_key][
                "model"
            ]["trunk"]
        else:
            # checkpoints trained outside omnivore
            state_dict = checkpoint["model"]

        # delete relative_position_index since we always re-init it
        relative_position_index_keys = [
            k for k in state_dict.keys() if "relative_position_index" in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]

        # delete attn_mask since we always re-init it
        attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
        for k in attn_mask_keys:
            del state_dict[k]

        # bicubic-interpolate relative_position_bias_table entries whose
        # shape does not match the current model
        relative_position_bias_table_keys = [
            k for k in state_dict.keys() if "relative_position_bias_table" in k
        ]
        pretrained_window_size = self.pretrained3d
        # a window of extent W has 2 * W - 1 relative offsets per axis
        T1 = 2 * pretrained_window_size[0] - 1
        S11 = 2 * pretrained_window_size[1] - 1
        S12 = 2 * pretrained_window_size[2] - 1
        assert (
            pretrained_window_size[0] == self.window_size[0]
        ), "Interpolating along time not supported"

        for k in relative_position_bias_table_keys:
            relative_position_bias_table_pretrained = state_dict[k]
            relative_position_bias_table_current = self.state_dict()[k]
            L1, nH1 = relative_position_bias_table_pretrained.size()
            _, nH2 = relative_position_bias_table_current.size()
            # expected table length for the current window size
            L2 = (
                (2 * self.window_size[0] - 1)
                * (2 * self.window_size[1] - 1)
                * (2 * self.window_size[2] - 1)
            )
            if nH1 != nH2:
                logger.warning(f"Error in loading {k}, skipping")
            elif L1 != L2:
                # view the flat (L1, nH1) table as (T1, S11, S12, nH1), then
                # move to (T1, nH1, S11, S12) so torch.nn.functional.interpolate
                # treats time as the batch dim and heads as channels
                pretrained_bias = relative_position_bias_table_pretrained.view(
                    T1, S11, S12, nH1
                ).permute(0, 3, 1, 2)
                pretrained_bias = torch.nn.functional.interpolate(
                    pretrained_bias,
                    size=(
                        2 * self.window_size[1] - 1,
                        2 * self.window_size[2] - 1,
                    ),
                    mode="bicubic",
                )
                # back to (T1, H', W', nH) and flatten to (L2, nH2)
                relative_position_bias_table_pretrained = pretrained_bias.permute(
                    0, 2, 3, 1
                ).reshape(L2, nH2)

            # write the (possibly resized) table back into the loaded state dict
            state_dict[k] = relative_position_bias_table_pretrained
        msg = self.load_state_dict(state_dict, strict=False)
        logger.info(msg)
        logger.info(f"=> loaded successfully '{self.pretrained}'")
        del checkpoint
        torch.cuda.empty_cache()
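
For reference, a minimal self-contained sketch of the resizing step above, with made-up window sizes and head count; it reproduces the same shape gymnastics (flat table to (T, H, W, nH), bicubic over H/W, back to a flat table) outside the class:

    import torch
    import torch.nn.functional as F

    # pretrained window (T, H, W) = (8, 7, 7); current window (8, 14, 14)
    T1, S11, S12, nH = 2 * 8 - 1, 2 * 7 - 1, 2 * 7 - 1, 4
    new_h, new_w = 2 * 14 - 1, 2 * 14 - 1

    table = torch.randn(T1 * S11 * S12, nH)  # flat pretrained bias table

    # (L, nH) -> (T, H, W, nH) -> (T, nH, H, W): time as batch, heads as channels
    grid = table.view(T1, S11, S12, nH).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic")
    resized = grid.permute(0, 2, 3, 1).reshape(T1 * new_h * new_w, nH)

    print(table.shape, "->", resized.shape)  # (2535, 4) -> (10935, 4)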
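And a hedged usage sketch; the SwinTransformer3D constructor arguments shown here are assumptions, but the method itself only reads the attributes used above (pretrained, pretrained2d, pretrained3d, pretrained_model_key, window_size):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # hypothetical config: same temporal window as pretraining (required by
    # the assert above), larger spatial window so the tables get resized
    model = SwinTransformer3D(
        window_size=(8, 14, 14),
        pretrained="checkpoints/swin3d.pth",  # hypothetical path
        pretrained2d=False,
        pretrained3d=(8, 7, 7),  # window size used at pretraining time
    )
    model.load_and_interpolate_3d_weights(logger)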