def __init__()

in ssl/real-dataset/models/mlp_head.py


    def __init__(self, in_channels, mlp_hidden_size, projection_size, options=None):
        super().__init__()
        if options is None:
            options = dict(
                normalization="bn",
                has_bias=True,
                has_bn_affine=False,
                has_relu=True,
                additional_bn_at_input=False,
                custom_nz=None,
            )

        # Only the "grad_act_zero" custom gradient mode (or none) is supported.
        assert options["custom_nz"] in (None, "grad_act_zero")

        # The normalization layer follows the hidden linear layer when there is
        # one; otherwise it acts on the input features directly.
        bn_size = in_channels if mlp_hidden_size is None else mlp_hidden_size
        norm_layer = self._create_normalization(bn_size, options)

        layers = []

        # Optional non-affine BatchNorm applied to the raw input features.
        if options["additional_bn_at_input"]:
            layers.append(nn.BatchNorm1d(in_channels, affine=False))

        if mlp_hidden_size is not None:
            # Hidden block: Linear -> (normalization) -> (ReLU).
            layers.append(nn.Linear(in_channels, mlp_hidden_size, bias=options["has_bias"]))
            if norm_layer is not None:
                layers.append(norm_layer)
            if options["has_relu"]:
                layers.append(nn.ReLU(inplace=True))
        elif norm_layer is not None:
            # No hidden layer: normalize the input features directly.
            layers.append(norm_layer)

        # Final projection layer.
        layers.append(nn.Linear(bn_size, projection_size, bias=options["has_bias"]))
        self.layers = nn.ModuleList(layers)

        # Per-layer state for the custom gradient handling selected by
        # custom_nz: cached weight gradients, activation masks, and products.
        self.gradW = [None for _ in self.layers]
        self.masks = [None for _ in self.layers]
        self.prods = [list() for _ in self.layers]
        self.custom_nz = options["custom_nz"]
        self.compute_adj_grad = True
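
The constructor only assembles an nn.ModuleList; for the default options the resulting stack can be spelled out directly. Below is a minimal equivalent sketch, assuming that _create_normalization (defined elsewhere in mlp_head.py and not shown here) returns a non-affine nn.BatchNorm1d for normalization="bn" with has_bn_affine=False; the sizes 512/4096/256 are illustrative, not taken from the repository.

    import torch
    import torch.nn as nn

    in_channels, mlp_hidden_size, projection_size = 512, 4096, 256

    # Equivalent stack built by __init__ under the default options:
    # Linear(bias=True) -> BatchNorm1d(affine=False) -> ReLU -> Linear(bias=True)
    layers = nn.ModuleList([
        nn.Linear(in_channels, mlp_hidden_size, bias=True),
        nn.BatchNorm1d(mlp_hidden_size, affine=False),  # assumed _create_normalization output
        nn.ReLU(inplace=True),
        nn.Linear(mlp_hidden_size, projection_size, bias=True),
    ])

    # MLPHead stores a ModuleList rather than an nn.Sequential, so a forward
    # pass has to iterate over the layers explicitly:
    x = torch.randn(8, in_channels)
    for layer in layers:
        x = layer(x)
    print(x.shape)  # torch.Size([8, 256])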