in ml3/mbrl_utils.py [0:0]
def forward(self, net_output, target, *, reduction="mean"):
    """Return the Gaussian negative log-likelihood of ``target``.

    ``net_output`` stacks the predicted mean and variance along its first
    axis: ``net_output[0]`` is the mean and ``net_output[1]`` the variance.
    The per-element loss is
    ``0.5 * log(var) + 0.5 * (mean - target) ** 2 / var``,
    i.e. the Gaussian NLL with the constant ``0.5 * log(2 * pi)`` omitted.

    Args:
        net_output: 3-D tensor whose first dimension has size 2
            (mean, variance). The variance is used as-is, so it must be
            strictly positive; no clamping is applied (a zero variance
            yields ``inf``/``nan``).
        target: tensor that broadcasts against ``net_output[0]``.
        reduction: ``"mean"`` (default; the previously hard-coded
            behaviour), ``"sum"``, or ``"none"`` to return the per-element
            losses unreduced.

    Returns:
        A 0-d tensor for ``"mean"`` / ``"sum"``; for ``"none"``, a tensor
        with the broadcast shape of the mean and ``target``.

    Raises:
        ValueError: if ``net_output`` is not 3-D with a leading dimension
            of size 2, or if ``reduction`` is not one of the accepted values.
    """
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if net_output.dim() != 3 or net_output.size(0) != 2:
        raise ValueError(
            "net_output must be 3-D with a leading dimension of size 2 "
            f"(mean, variance); got shape {tuple(net_output.shape)}"
        )
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'none'; got {reduction!r}"
        )

    mean, var = net_output[0], net_output[1]
    loss = 0.5 * torch.log(var) + 0.5 * (mean - target) ** 2 / var

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss