qlearn/commun/mnf_layer.py
# Module-level imports assumed by this excerpt.
import torch
import torch.nn.functional as F

def forward(self, input, kl=True):
    if self.training:
        # Sample the multiplicative noise z; with kl=True the flow also
        # returns the summed log-determinants needed for the KL term.
        if kl:
            z, logdets = self.sample_z(kl=True)
        else:
            z = self.sample_z(kl=False)
        # Clamp the posterior stds at threshold_var for numerical stability.
        weight_std = torch.clamp(torch.exp(self.weight_logstd), 0, self.threshold_var)
        bias_std = torch.clamp(torch.exp(0.5 * self.bias_logvar), 0, self.threshold_var)
        # Multiplicative noise: z scales the input dimension of the weight means.
        weight_mu = z.view(1, -1) * self.weight_mu
        # Reparameterized samples of the weights and biases.
        weight = weight_mu + weight_std * self.epsilon_weight
        bias = self.bias_mu + bias_std * self.epsilon_bias
        out = F.linear(input, weight, bias)
        if not kl:
            return out
        # KL of the conditional Gaussian posteriors against standard-normal priors.
        kldiv_weight = 0.5 * (-2 * self.weight_logstd + torch.exp(2 * self.weight_logstd)
                              + weight_mu * weight_mu - 1).sum()
        kldiv_bias = 0.5 * (-self.bias_logvar + torch.exp(self.bias_logvar)
                            + self.bias_mu * self.bias_mu - 1).sum()
        # log q(z): base-distribution entropy term minus the flow log-determinants.
        logq = -0.5 * self.qzero_logvar.sum()
        logq -= logdets
        # Auxiliary inverse model r(z | W): its mean and negative log-variance
        # are data-dependent functions of the sampled weights.
        cw = torch.tanh(torch.matmul(self.rzero_c, weight.t()))  # F.tanh is deprecated
        mu_tilde = torch.mean(self.rzero_b1.ger(cw), dim=1)
        neg_log_var_tilde = torch.mean(self.rzero_b2.ger(cw), dim=1)
        z, logr = self.flow_r(z)
        z_mu_square = (z - mu_tilde) * (z - mu_tilde)
        # Log-density of z under r (up to constants), added to the flow's log-det.
        logr += 0.5 * (-torch.exp(neg_log_var_tilde) * z_mu_square
                       + neg_log_var_tilde).sum()
        kldiv = kldiv_weight + kldiv_bias + logq - logr
        return out, kldiv
    else:
        # Evaluation: use the posterior means, with weights scaled by a sample of z.
        assert not kl
        z = self.sample_z(kl=False)
        weight_mu = z.view(1, -1) * self.weight_mu
        out = F.linear(input, weight_mu, self.bias_mu)
        return out
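
# The sample_z helper is referenced above but falls outside this excerpt. As
# context, here is a minimal sketch of what it plausibly does, following the
# conventions visible in forward() (flows return a (z, log_det) pair; the base
# posterior is parameterized by qzero_logvar). The names `qzero_mu` and
# `flow_q` are assumptions chosen to mirror `qzero_logvar` and `flow_r`, not
# names confirmed by the repo.
def sample_z(self, kl=True):
    # Reparameterized draw from the base distribution q0(z).
    qzero_std = torch.exp(0.5 * self.qzero_logvar)
    z = self.qzero_mu + qzero_std * torch.randn_like(qzero_std)
    # Push z through the forward normalizing flow, accumulating log-dets.
    z, logdets = self.flow_q(z)
    if kl:
        return z, logdets
    return z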
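
# Downstream, the training-mode contract is that forward() returns
# (out, kldiv) when kl=True, so callers must accumulate the per-layer KL
# contributions into the variational objective. A minimal, hypothetical
# training-step sketch follows; the `model.layers` structure, `beta` weight,
# and MSE loss are illustrative, not taken from this repo.
def training_step(model, optimizer, x, y, beta=1.0):
    # Assumes every layer in model.layers is an MNF layer whose forward
    # returns (output, kl) in training mode; activations are omitted.
    model.train()
    out, kldiv_total = x, 0.0
    for layer in model.layers:
        out, kldiv = layer(out, kl=True)
        kldiv_total = kldiv_total + kldiv
    # Variational objective: task loss plus a (possibly annealed) KL penalty.
    loss = F.mse_loss(out, y) + beta * kldiv_total
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()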