in qlearn/commun/local_mnf_layer.py [0:0]
def forward(self, x, same_noise=False):
    """Compute the layer output for the batch ``x``.

    Training mode samples the output pre-activations with the local
    reparameterization trick:

        mean = (x * z) @ weight_mu + bias_mu
        var  = (x * x) @ weight_std**2 + bias_std**2
        out  = mean + sqrt(var) * eps

    Here ``z`` comes from ``self.sample_z`` and ``eps`` is standard normal
    noise with one row per batch element.

    Eval mode draws a single ``z`` from ``self.sample_z(1, kl=False)``. It
    scales the rows of ``weight_mu`` with it and returns the deterministic
    mean; no additive output noise is used.

    Args:
        x: Input batch; ``x.size(0)`` is the batch size. Presumably shaped
            (batch, in_features) -- TODO confirm against callers.
        same_noise: When True and the batch has more than one row, the
            stored ``self.epsilon_linear`` is broadcast to every row instead
            of drawing fresh noise. The flag is also forwarded to
            ``sample_z``. A batch of exactly one row always uses
            ``self.epsilon_linear``.

    Returns:
        Output tensor whose last dimension is ``self.out_features``
        (presumably shaped (batch, out_features)).
    """
    batch_size = x.size(0)

    if not self.training:
        # Deterministic path. z is reshaped to a column so it scales each
        # input row of weight_mu (assumes z has one entry per input
        # feature -- confirm with sample_z).
        z = self.sample_z(1, kl=False)
        weight_mu = z.view(-1, 1) * self.weight_mu
        return torch.matmul(x, weight_mu) + self.bias_mu

    z = self.sample_z(batch_size, kl=False, same_noise=same_noise)
    # NOTE(review): despite its name, threshold_var caps the standard
    # deviations here, not the variances -- confirm this is intended.
    weight_std = torch.clamp(torch.exp(self.weight_logstd), 0, self.threshold_var)
    bias_std = torch.clamp(torch.exp(0.5 * self.bias_logvar), 0, self.threshold_var)

    out_mu = torch.matmul(x * z, self.weight_mu) + self.bias_mu
    # Var[sum_i x_i * w_ij + b_j] = sum_i x_i^2 * sigma_ij^2 + sigma_bj^2.
    # The bias must contribute its variance, not its standard deviation.
    out_var = torch.matmul(x * x, weight_std * weight_std) + bias_std * bias_std

    if batch_size == 1:
        epsilon_linear = self.epsilon_linear
    else:
        if same_noise:
            # Share one noise vector across every row of the batch.
            epsilon_linear = self.epsilon_linear.expand(batch_size, self.out_features)
        else:
            # Independent noise per row. This branch also handles an empty
            # batch (batch_size == 0), which used to leave epsilon_linear
            # unbound.
            epsilon_linear = Variable(torch.randn(batch_size, self.out_features))
        if self.use_cuda:
            # No-op when the tensor already lives on the GPU.
            epsilon_linear = epsilon_linear.cuda()

    return out_mu + torch.sqrt(out_var) * epsilon_linear