def init_weights()

in scripts/train_imagenet.py [0:0]


def init_weights(model):
    """Initialize `model` in-place according to the global `conf` dictionary.

    Conv2d weights use the initializer named in conf["network"]["weight_init"],
    BatchNorm2d / ABN layers are reset to unit scale and zero shift, and Linear
    layers use Xavier-uniform with a small gain.
    """
    global conf
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            # Resolve the initializer by name, e.g. "xavier_normal" -> nn.init.xavier_normal_
            init_fn = getattr(nn.init, conf["network"]["weight_init"] + "_")
            if (
                conf["network"]["weight_init"].startswith("xavier")
                or conf["network"]["weight_init"] == "orthogonal"
            ):
                # Gain-based initializers: scale the configured multiplier by the
                # recommended gain for the chosen activation.
                gain = conf["network"]["weight_gain_multiplier"]
                if conf["network"]["activation"] in ("relu", "elu"):
                    gain *= nn.init.calculate_gain("relu")
                elif conf["network"]["activation"] == "leaky_relu":
                    gain *= nn.init.calculate_gain(
                        "leaky_relu", conf["network"]["activation_param"]
                    )
                init_fn(m.weight, gain)
            elif conf["network"]["weight_init"].startswith("kaiming"):
                # Kaiming initializers take the negative slope `a` rather than a gain.
                if conf["network"]["activation"] in ("relu", "elu"):
                    init_fn(m.weight, 0)
                else:
                    init_fn(m.weight, conf["network"]["activation_param"])

            if hasattr(m, "bias") and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, (nn.BatchNorm2d, ABN)):
            # Batch-norm (and InPlace-ABN) layers: unit scale, zero shift.
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Linear):
            # Classifier head: Xavier-uniform weights with a small gain, zero bias.
            nn.init.xavier_uniform_(m.weight, 0.1)
            nn.init.constant_(m.bias, 0.0)
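
For reference, the sketch below shows how this function might be exercised: a `conf` dictionary containing only the keys read above, and a call on a standard torchvision model. The concrete values, the `resnet18` model, and the `inplace_abn` import path are illustrative assumptions, not the repository's actual configuration.

# Minimal usage sketch (assumptions): intended to run in the same module as
# init_weights() above, since the function reads the module-level `conf` global.
import torch.nn as nn
import torchvision.models as models
from inplace_abn import ABN  # import path may differ between InPlace-ABN versions

conf = {
    "network": {
        "weight_init": "kaiming_normal",   # resolved to nn.init.kaiming_normal_
        "weight_gain_multiplier": 1.0,     # extra gain for xavier/orthogonal init
        "activation": "leaky_relu",        # "relu", "elu" or "leaky_relu"
        "activation_param": 0.01,          # negative slope (or ELU alpha)
    }
}

model = models.resnet18(num_classes=1000)  # any network built from Conv2d / BatchNorm2d / Linear
init_weights(model)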