in models/swin_transformer_3d.py [0:0]
def forward_intermediate_features(self, stage_outputs, out_feat_keys):
    """
    Pick the requested intermediate features out of ``stage_outputs``.

    Inputs
    - stage_outputs: indexable collection of per-stage features that have
      not been normalized yet
    - out_feat_keys: iterable of feature names (str), each written as
      "stage<int>" to get the feature passed through ``self._apply_norm``
      or "interim<int>" to get the feature exactly as stored

    Returns
    - list of features, one per entry of ``out_feat_keys``, in that order

    Raises
    - ValueError: if a key starts with neither "stage" nor "interim"
    """

    def _resolve(name):
        # Work out which prefix the key carries and whether it asks for norm.
        if name.startswith("stage"):
            prefix, normalize = "stage", True
        elif name.startswith("interim"):
            prefix, normalize = "interim", False
        else:
            raise ValueError(f"Invalid key {name}")

        # What is left once the prefix text is removed is the stage index.
        position = int(name.replace(prefix, ""))
        selected = stage_outputs[position]
        return self._apply_norm(selected) if normalize else selected

    # Keys are resolved strictly in order, so an invalid key fails only
    # after the keys that precede it have been processed.
    return [_resolve(name) for name in out_feat_keys]