def forward_intermediate_features()

in models/swin_transformer_3d.py [0:0]


    def forward_intermediate_features(self, stage_outputs, out_feat_keys):
        """
        Select intermediate features by name.

        Inputs
        - stage_outputs: indexable sequence of features without self.norm()
                         applied to them
        - out_feat_keys: list of feature names (str)
                         specified as "stage<int>" for feature with norm
                         (passed through self._apply_norm)
                         or "interim<int>" for feature without norm.
                         The integer is parsed with int(), so a negative
                         value indexes from the end as with Python sequences.

        Returns
        - list with one feature per key, in the order of out_feat_keys

        Raises
        - ValueError: if a key has neither prefix, or its suffix is not an int
        - IndexError: if the parsed index is out of range for stage_outputs
        """
        out_features = []
        for key in out_feat_keys:
            if key.startswith("stage"):
                prefix, apply_norm = "stage", True
            elif key.startswith("interim"):
                prefix, apply_norm = "interim", False
            else:
                raise ValueError(f"Invalid key {key}")
            # Strip only the leading prefix; str.replace would also remove
            # later occurrences and silently accept keys like "stage1stage".
            suffix = key[len(prefix):]
            try:
                idx = int(suffix)
            except ValueError as err:
                raise ValueError(f"Invalid key {key}") from err
            try:
                feat = stage_outputs[idx]
            except IndexError as err:
                raise IndexError(
                    f"Key {key} refers to index {idx}, but only "
                    f"{len(stage_outputs)} stage outputs are available"
                ) from err
            if apply_norm:
                feat = self._apply_norm(feat)
            out_features.append(feat)
        return out_features