in qlearn/commun/norm_flows.py [0:0]
def forward(self, z, kl=True):
    """Apply every flow in ``self.flow_list`` to ``z``, in list order.

    Args:
        z: input to the first flow. Whether ``z.dim() == 1`` decides how the
            log-determinant accumulator is initialised (see below).
        kl: when truthy, each flow is called as ``flow(z, kl=True)`` and must
            return a ``(z, logdet)`` pair; the per-flow ``logdet`` values are
            summed. When falsy, each flow is called as ``flow(z, kl=False)``
            and must return only the transformed ``z``.

    Returns:
        ``(z, logdets)`` when ``kl`` is truthy, otherwise just ``z``.
    """
    if not kl:
        # No log-determinant bookkeeping requested: simply chain the flows.
        for layer in self.flow_list:
            z = layer(z, kl=False)
        return z

    # For a 1-D input the accumulator starts as the Python scalar 0; otherwise
    # it is a zero tensor shaped like z[:, 0] (one entry per row for 2-D z).
    total_logdet = 0 if z.dim() == 1 else Variable(torch.zeros_like(z[:, 0]))
    for layer in self.flow_list:
        z, step_logdet = layer(z, kl=True)
        total_logdet += step_logdet
    return z, total_logdet