in models/swin_transformer_3d.py [0:0]
def init_weights(self, pretrained=None):
    """Initialize the weights in backbone, optionally loading a checkpoint.

    Every ``nn.Linear`` gets truncated-normal weights (std 0.02) and a zero
    bias (when it has one); every ``nn.LayerNorm`` gets weight 1 and bias 0.
    When a checkpoint is configured, this default init is applied first and
    the configured loader is called afterwards.

    Args:
        pretrained (str | list, optional): Path(s) to pre-trained weights.
            When truthy it overrides ``self.pretrained``; otherwise the value
            already stored on ``self.pretrained`` is used. Defaults to None.

    Raises:
        ValueError: If a checkpoint is configured but neither
            ``self.pretrained2d`` nor ``self.pretrained3d`` is set.
        TypeError: If the resolved ``self.pretrained`` is not a str, a list,
            or None.
    """

    def _init_weights(m):
        # Per-module initializer, applied recursively through ``self.apply``.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # A falsy argument (None, "", []) keeps whatever self.pretrained holds.
    if pretrained:
        self.pretrained = pretrained

    if isinstance(self.pretrained, (str, list)):
        # Default init first; the loader below then runs on top of it
        # (presumably overwriting the parameters the checkpoint provides).
        self.apply(_init_weights)
        logging.info(f"load model from: {self.pretrained}")
        # NOTE: the ``logging`` module itself is passed as the logger object.
        if self.pretrained2d:
            # Inflate 2D model into 3D model.
            logging.info(f"Inflating with {self.pretrained_model_key}")
            self.inflate_weights(logging)
        elif self.pretrained3d:
            logging.info(f"Loading 3D model with {self.pretrained_model_key}")
            self.load_and_interpolate_3d_weights(logging)
        else:
            raise ValueError(
                "Use VISSL loading for this. This code "
                "is only for Swin inflation."
            )
    elif self.pretrained is None:
        self.apply(_init_weights)
    else:
        raise TypeError("pretrained must be a str, list or None")