in decoder.py [0:0]
def __init__(self, n_inputs, n_outputs, factor=6, bn='before', *,
             n_layers=4, base_width=128):
    """Build a fully connected MLP and store it in ``self.net``.

    The network consists of ``n_layers`` hidden stages of width
    ``factor * base_width``, followed by a final
    ``Linear(n_hidden, n_outputs)``. Each hidden stage is a ``Linear``
    layer plus a ``ReLU``, optionally combined with ``BatchNorm1d``
    according to ``bn``.

    Args:
        n_inputs: Input features of the first ``Linear`` layer
            (presumably inputs are ``(batch, n_inputs)`` - confirm at
            call sites).
        n_outputs: Output features of the last ``Linear`` layer.
        factor: Width multiplier; the hidden width is
            ``factor * base_width``.
        bn: Placement of ``BatchNorm1d`` inside each hidden stage:
            ``'before'`` -> Linear, BatchNorm1d, ReLU;
            ``'after'``  -> Linear, ReLU, BatchNorm1d;
            any other value -> Linear, ReLU (no normalization).
        n_layers: Number of hidden stages (keyword-only). The default of 4
            reproduces the previously hard-coded depth.
        base_width: Hidden width per unit of ``factor`` (keyword-only).
            The default of 128 reproduces the previously hard-coded width.

    Raises:
        ValueError: If ``n_layers`` is negative.
    """
    # NOTE(review): this code is labelled as coming from decoder.py, yet it
    # calls super(Encoder, self). That only works if the enclosing class is
    # Encoder or has Encoder in its MRO (otherwise super() raises TypeError).
    # Confirm the intended class before changing it.
    super(Encoder, self).__init__()
    if n_layers < 0:
        raise ValueError(f"n_layers must be >= 0, got {n_layers}")
    n_hidden = factor * base_width

    layers = []
    in_features = n_inputs
    for _ in range(n_layers):
        # Modules are appended in exactly the same order as the former
        # hand-written Sequential literals. This keeps the Sequential
        # indices (and hence state_dict keys such as "net.1.weight")
        # and the parameter-creation order unchanged.
        layers.append(torch.nn.Linear(in_features, n_hidden))
        if bn == 'before':
            layers.append(torch.nn.BatchNorm1d(n_hidden))
            layers.append(torch.nn.ReLU())
        elif bn == 'after':
            layers.append(torch.nn.ReLU())
            layers.append(torch.nn.BatchNorm1d(n_hidden))
        else:
            # Any unrecognized bn value means "no normalization",
            # matching the original else-branch.
            layers.append(torch.nn.ReLU())
        in_features = n_hidden
    # With n_layers == 0 this degenerates to a single
    # Linear(n_inputs, n_outputs).
    layers.append(torch.nn.Linear(in_features, n_outputs))
    self.net = torch.nn.Sequential(*layers)