def forward()

in qlearn/commun/mnf_layer.py

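Training-time forward pass of what looks like a multiplicative normalizing flows (MNF) linear layer in the spirit of Louizos & Welling (2017). A flow-transformed multiplicative sample z rescales the weight means, weights and biases are drawn via the reparameterization trick, and with kl=True the method additionally returns a single-sample estimate of the KL term: the closed-form Gaussian weight/bias KL plus the log q(z) - log r(z|W) correction from the auxiliary posterior. In evaluation mode the pass is deterministic, using the posterior means. The snippet assumes module-level imports of torch and torch.nn.functional as F.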

    def forward(self, input, kl=True):
        if self.training:
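            # Sample the multiplicative noise z; with kl=True the flow also returns its log-determinant terms for the KL estimate.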
            if kl:
                z, logdets = self.sample_z(kl=True)
            else:
                z = self.sample_z(kl=False)
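            # Clamp the posterior standard deviations to [0, threshold_var] for numerical stability.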
            weight_std = torch.clamp(torch.exp(self.weight_logstd), 0, self.threshold_var)
            bias_std = torch.clamp(torch.exp(0.5 * self.bias_logvar), 0, self.threshold_var)
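            # Multiplicative reparameterisation: z rescales the weight means, then the stored noise buffers epsilon_weight/epsilon_bias perturb weights and biases (reparameterization trick).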
            weight_mu = z.view(1, -1) * self.weight_mu
            weight = weight_mu + weight_std * self.epsilon_weight
            bias = self.bias_mu + bias_std * self.epsilon_bias
            out = F.linear(input, weight, bias)
            if not kl:
                return out
            else:
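                # Closed-form KL between the diagonal Gaussian posteriors over weights and biases and a standard-normal prior.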
                kldiv_weight = 0.5 * (- 2 * self.weight_logstd + torch.exp(2 * self.weight_logstd)
                                      + weight_mu * weight_mu - 1).sum()
                kldiv_bias = 0.5 * (- self.bias_logvar + torch.exp(self.bias_logvar)
                                    + self.bias_mu * self.bias_mu - 1).sum()
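                # log q(z) up to additive constants: base-Gaussian entropy term minus the flow log-determinants.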
                logq = - 0.5 * self.qzero_logvar.sum()
                logq -= logdets

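                # Auxiliary posterior r(z|W): infer a diagonal Gaussian over z from the sampled weights via a tanh feature.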
                cw = torch.tanh(torch.matmul(self.rzero_c, weight.t()))

                mu_tilde = torch.mean(torch.outer(self.rzero_b1, cw), dim=1)
                neg_log_var_tilde = torch.mean(torch.outer(self.rzero_b2, cw), dim=1)

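                # Push z through the auxiliary flow and add the log-density of the result under r(z|W) (up to constants).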
                z, logr = self.flow_r(z)

                z_mu_square = (z - mu_tilde) * (z - mu_tilde)
                logr += 0.5 * (- torch.exp(neg_log_var_tilde) * z_mu_square
                               + neg_log_var_tilde).sum()

                kldiv = kldiv_weight + kldiv_bias + logq - logr
                return out, kldiv
        else:
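            # Evaluation path: deterministic pass with the posterior mean weights and biases (z still modulates the weight means).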
            assert not kl
            z = self.sample_z(kl=False)
            weight_mu = z.view(1, -1) * self.weight_mu
            out = F.linear(input, weight_mu, self.bias_mu)
            return out
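
The closed-form terms kldiv_weight and kldiv_bias above are the standard KL divergence between a diagonal Gaussian posterior and a standard-normal prior, KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2), summed over parameters. A minimal, self-contained sanity check of the weight expression against torch.distributions (the tensor names below are placeholders, not the layer's attributes):

    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(3, 4)      # stand-in for the scaled weight means (weight_mu)
    logstd = torch.randn(3, 4)  # stand-in for weight_logstd

    # Same closed form as kldiv_weight in forward() above.
    closed_form = 0.5 * (-2 * logstd + torch.exp(2 * logstd) + mu * mu - 1).sum()

    # Reference value computed by torch.distributions.
    reference = kl_divergence(Normal(mu, logstd.exp()), Normal(0.0, 1.0)).sum()

    assert torch.allclose(closed_form, reference, atol=1e-5)

The bias term is the same formula under a log-variance parameterisation (logstd = 0.5 * bias_logvar). In training code built on this layer, the returned kldiv would typically be summed across layers and added to the task loss with an appropriate coefficient to form the variational objective.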