in code/colored_mnist/main.py [0:0]
def __init__(self):
    """Assemble the three-layer perceptron kept in ``self._main``.

    The first linear layer accepts ``14 * 14`` inputs when
    ``flags.grayscale_model`` is truthy and ``2 * 14 * 14`` otherwise.
    Both hidden layers are ``flags.hidden_dim`` wide and the final layer
    produces a single output. Every linear layer gets Xavier-uniform
    weights and zero biases.
    """
    super(MLP, self).__init__()

    pixels = 14 * 14
    in_features = pixels if flags.grayscale_model else 2 * pixels

    # All layers are created before any of them is re-initialised; keep
    # this order so the statement sequence (and therefore, presumably, the
    # sequence of random draws under a fixed seed) stays the same.
    layers = (
        nn.Linear(in_features, flags.hidden_dim),
        nn.Linear(flags.hidden_dim, flags.hidden_dim),
        nn.Linear(flags.hidden_dim, 1),
    )
    for layer in layers:
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)

    first, second, third = layers
    self._main = nn.Sequential(
        first,
        nn.ReLU(inplace=True),
        second,
        nn.ReLU(inplace=True),
        third,
    )