in gala/model.py [0:0]
def __init__(self, num_inputs, recurrent=False, hidden_size=64):
    """Build the actor trunk, critic trunk and scalar critic output layer.

    Args:
        num_inputs: Width of the input fed to the first linear layers
            when ``recurrent`` is false.
        recurrent: Forwarded to the parent constructor. When true, the
            first linear layers take ``hidden_size`` inputs instead of
            ``num_inputs`` (presumably because the parent's recurrent
            stage emits ``hidden_size`` features -- confirm in the base
            class).
        hidden_size: Width of both hidden layers in each trunk.
    """
    super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size)

    # Input width of the first Linear layer in both trunks.
    in_features = hidden_size if recurrent else num_inputs

    def _fill_zero(tensor):
        return nn.init.constant_(tensor, 0)

    def _initialised(layer):
        # Delegates to the project's ``init`` helper with orthogonal_,
        # a constant-zero filler and sqrt(2); presumably these are the
        # weight initialiser, bias initialiser and gain -- verify against
        # ``init``'s definition.
        return init(layer, nn.init.orthogonal_, _fill_zero, np.sqrt(2))

    def _tanh_trunk():
        # Two Linear layers, each followed by Tanh.
        return nn.Sequential(
            _initialised(nn.Linear(in_features, hidden_size)),
            nn.Tanh(),
            _initialised(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
        )

    # Creation order is kept (actor, critic, critic_linear) so random
    # initialisation and parameter registration order stay the same.
    self.actor = _tanh_trunk()
    self.critic = _tanh_trunk()
    self.critic_linear = _initialised(nn.Linear(hidden_size, 1))

    self.train()