def explained_variance()

in ppo_ewma/torch_util.py [0:0]


def explained_variance(ypred: th.Tensor, y: th.Tensor, comm: MPI.Comm = None) -> float:
    """
    Measure how much of the variance of ``y`` is accounted for by ``ypred``.

    The score is ``1 - Var[y - ypred] / Var[y]``:

        ev = 1  =>  ypred matches y exactly
        ev = 0  =>  no better than predicting zero (or any other constant)
        ev < 0  =>  worse than predicting zero

    Args:
        ypred: predictions; must have exactly the same shape as ``y``.
        y: targets.
        comm: optional MPI communicator. When given, both variances are taken
            from ``mpi_moments`` (presumably aggregated across ranks -- see
            that helper); otherwise they are computed locally via
            ``Tensor.var()`` and converted to Python floats.

    Returns:
        The explained-variance score, or NaN when ``Var[y]`` is exactly zero.
    """
    # Equal shapes rule out silent broadcasting in the subtraction below.
    assert ypred.shape == y.shape
    residual = y - ypred
    if comm is not None:
        # Keep the call order (targets first, then residuals) fixed: if
        # mpi_moments performs collective communication, every rank must
        # issue the calls in the same sequence.
        _mean_y, var_y = mpi_moments(comm, y)
        _mean_residual, var_residual = mpi_moments(comm, residual)
    else:
        var_y = float(y.var())
        var_residual = float(residual.var())
    # A constant target leaves nothing to explain, so the ratio is undefined.
    return float("nan") if var_y == 0 else 1.0 - var_residual / var_y