in code/colored_mnist/main.py [0:0]
def forward(self, input):
    """Flatten *input* to one feature vector per example and run ``self._main``.

    The reshape depends on the module-level ``flags.grayscale_model``:

    * false: both channels are kept, giving ``2 * 14 * 14`` features per example.
    * true: the two channels are summed element-wise, giving ``14 * 14``
      features per example.

    NOTE(review): *input* presumably holds ``input.shape[0] * 2 * 14 * 14``
    elements, e.g. an ``(N, 2, 14, 14)`` tensor. Also, ``.view`` needs a
    compatible memory layout. Confirm both against the caller.

    Returns:
        Whatever ``self._main`` produces for the flattened batch.
    """
    pixels_per_channel = 14 * 14
    if not flags.grayscale_model:
        # Colour mode: lay the two channels side by side as one flat vector.
        return self._main(input.view(input.shape[0], 2 * pixels_per_channel))
    # Grayscale mode: split out the two channels, then add them together.
    # The model then sees a single intensity plane per example.
    per_channel = input.view(input.shape[0], 2, pixels_per_channel)
    return self._main(per_channel.sum(dim=1))