in scripts/train_imagenet.py [0:0]
def init_weights(model):
    """Initialize all weights of *model* in place from the global config.

    Reads the module-level ``conf`` dict (section ``conf["network"]``) for:
    ``weight_init`` (an ``nn.init`` scheme name without the trailing ``_``),
    ``weight_gain_multiplier``, ``activation``, and ``activation_param``.

    Per module type:
      - ``nn.Conv2d``: weights via the configured scheme — xavier_*/orthogonal
        schemes receive a gain, kaiming_* schemes receive the activation's
        negative-slope parameter; bias (if present) zeroed.
      - ``nn.BatchNorm2d`` / ``ABN``: weight set to 1, bias to 0.
      - ``nn.Linear``: ``xavier_uniform_`` with gain 0.1; bias (if present)
        zeroed.
    """
    # Hoisted out of the loop: the config is invariant across modules, and an
    # invalid scheme name now fails fast at entry instead of mid-iteration.
    net_conf = conf["network"]
    scheme = net_conf["weight_init"]
    activation = net_conf["activation"]
    init_fn = getattr(nn.init, scheme + "_")

    for _, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            if scheme.startswith("xavier") or scheme == "orthogonal":
                gain = net_conf["weight_gain_multiplier"]
                if activation in ("relu", "elu"):
                    # torch has no dedicated ELU gain; relu's is used as an
                    # approximation (matches the original behavior).
                    gain *= nn.init.calculate_gain("relu")
                elif activation == "leaky_relu":
                    gain *= nn.init.calculate_gain(
                        "leaky_relu", net_conf["activation_param"]
                    )
                init_fn(m.weight, gain)
            elif scheme.startswith("kaiming"):
                # kaiming_* take the negative slope `a`; 0 for relu/elu.
                if activation in ("relu", "elu"):
                    init_fn(m.weight, 0)
                else:
                    init_fn(m.weight, net_conf["activation_param"])
            # Conv2d always defines .bias (None when constructed with bias=False).
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d) or isinstance(m, ABN):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight, 0.1)
            # Fix: guard against Linear(bias=False), where m.bias is None and
            # the original unconditional constant_ call would crash.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)