def forward()

in ml3/mbrl_utils.py [0:0]


    def forward(self, net_output, target, *, reduction="mean"):
        """Compute the Gaussian negative log-likelihood of ``target``.

        The per-element loss is
        ``0.5 * log(var) + 0.5 * (mean - target) ** 2 / var``. This is the
        Gaussian NLL with the constant term ``0.5 * log(2 * pi)`` dropped.

        Args:
            net_output: 3-D tensor whose first axis has length 2.
                ``net_output[0]`` is the predicted mean and ``net_output[1]``
                is the predicted variance.
            target: tensor combined elementwise (with broadcasting) with the
                mean and variance.
            reduction: ``"mean"`` (default; the previously hard-coded
                behaviour), ``"sum"``, or ``"none"`` to return the unreduced
                per-element loss.

        Returns:
            A scalar tensor for ``"mean"`` / ``"sum"``; otherwise the
            elementwise loss tensor.

        Raises:
            ValueError: if ``net_output`` is not 3-D with a leading axis of
                size 2, or if ``reduction`` is not a supported value.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if net_output.dim() != 3 or net_output.size(0) != 2:
            raise ValueError(
                "net_output must be 3-D with a leading axis of size 2 "
                f"(mean, var); got shape {tuple(net_output.shape)}"
            )
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(
                f"reduction must be 'mean', 'sum' or 'none'; got {reduction!r}"
            )

        mean = net_output[0]
        var = net_output[1]
        # NOTE(review): no clamping is applied, so this presumably expects
        # var > 0. Zero or negative variance yields inf/nan. TODO confirm
        # that the producing network guarantees positivity.
        ret = 0.5 * torch.log(var) + 0.5 * (mean - target) ** 2 / var

        if reduction == "mean":
            return torch.mean(ret)
        if reduction == "sum":
            return torch.sum(ret)
        return ret