in ssl/real-dataset/models/mlp_head.py [0:0]
def __init__(self, in_channels, mlp_hidden_size, projection_size, options=None):
    """Build the MLP projection head as an ``nn.ModuleList`` of layers.

    Layer order (bracketed layers are only present when enabled):

        [BatchNorm1d(in_channels, affine=False)]  # options["additional_bn_at_input"]
        [Linear(in_channels, mlp_hidden_size)]    # mlp_hidden_size is not None
        [normalization layer]                     # whatever _create_normalization returns, if not None
        [ReLU(inplace=True)]                      # hidden layer present and options["has_relu"]
        Linear(bn_size, projection_size)

    where ``bn_size`` is ``mlp_hidden_size`` when given, else ``in_channels``.

    Args:
        in_channels: Number of input features.
        mlp_hidden_size: Width of the hidden Linear layer, or ``None`` to
            omit the hidden Linear (and its ReLU).
        projection_size: Number of output features of the final Linear.
        options: Optional dict of configuration flags. Keys missing from a
            supplied dict fall back to these defaults:
            - "normalization" ("bn"): read via ``options`` by
              ``self._create_normalization``.
            - "has_bias" (True): bias flag for both Linear layers.
            - "has_bn_affine" (False): not read in this method; presumably
              consumed by ``_create_normalization`` -- TODO confirm.
            - "has_relu" (True): add a ReLU after the hidden Linear.
            - "additional_bn_at_input" (False): prepend a non-affine
              BatchNorm1d over the input features.
            - "custom_nz" (None): must be None or "grad_act_zero"; stored
              on ``self.custom_nz``.

    Raises:
        ValueError: if ``options["custom_nz"]`` is neither None nor
            "grad_act_zero".
    """
    super(MLPHead, self).__init__()

    defaults = dict(
        normalization="bn",
        has_bias=True,
        has_bn_affine=False,
        has_relu=True,
        additional_bn_at_input=False,
        custom_nz=None,
    )
    # Work on a copy: fills in any missing keys and never mutates the caller's dict.
    merged = dict(defaults)
    merged.update(options or {})
    options = merged

    custom_nz = options["custom_nz"]
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if custom_nz is not None and custom_nz != "grad_act_zero":
        raise ValueError(
            'options["custom_nz"] must be None or "grad_act_zero", got {!r}'.format(custom_nz)
        )

    # Feature width seen by the normalization layer and the final projection.
    bn_size = in_channels if mlp_hidden_size is None else mlp_hidden_size
    norm = self._create_normalization(bn_size, options)

    layers = []
    if options["additional_bn_at_input"]:
        layers.append(nn.BatchNorm1d(in_channels, affine=False))
    if mlp_hidden_size is not None:
        layers.append(nn.Linear(in_channels, mlp_hidden_size, bias=options["has_bias"]))
    if norm is not None:
        layers.append(norm)
    # The ReLU only ever follows the hidden Linear; without one the head has no ReLU.
    if mlp_hidden_size is not None and options["has_relu"]:
        layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(bn_size, projection_size, bias=options["has_bias"]))

    self.layers = nn.ModuleList(layers)
    # Per-layer bookkeeping slots, one entry per layer; presumably filled in
    # by other methods of this class (not visible here) -- TODO confirm.
    self.gradW = [None for _ in self.layers]
    self.masks = [None for _ in self.layers]
    # A distinct list object per layer (not `[[]] * n`, which would alias them).
    self.prods = [list() for _ in self.layers]
    self.custom_nz = custom_nz
    self.compute_adj_grad = True