def jit_loss_fn()

in utils_data.py [0:0]


def jit_loss_fn(statistic_fn, norm=None, lambda_l1=0):
    """Build a jitted loss comparing ``statistic_fn(Dprime)`` to target statistics.

    Args:
        statistic_fn: Callable mapping a candidate ``Dprime`` to an array of
            statistics. It is called inside the ``jit``-compiled loss, so it
            must be traceable.
        norm: How the residual ``statistic_fn(Dprime) - target_statistics`` is
            reduced to a scalar:
              * ``"L2"``     -> 2-norm of the residual,
              * ``"Linfty"`` -> infinity norm (max absolute residual),
              * ``"LogExp"`` -> log-sum-exp of the residuals (a smooth max),
              * anything else, including the default ``None`` -> 5-norm.
        lambda_l1: Weight of the ``np.linalg.norm(Dprime, 1)`` penalty added
            to the loss. The default 0 disables its contribution.

    Returns:
        A jitted function ``compute_loss_fn(Dprime, target_statistics)`` that
        returns the scalar loss.
    """
    if norm == "L2":
        ord_norm = 2
    elif norm == "Linfty":
        ord_norm = np.inf
    else:
        # Any unrecognised name (including None) silently falls back to the 5-norm.
        ord_norm = 5

    @jit
    def compute_loss_fn(Dprime, target_statistics):
        # NOTE(review): assumes statistic_fn returns a 1-D vector; a 2-D result
        # would be treated as a matrix by np.linalg.norm. Confirm with callers.
        residual = statistic_fn(Dprime) - target_statistics

        if norm == "LogExp":
            # log(sum(exp(r))) evaluated as m + log(sum(exp(r - m))), m = max(r).
            # This avoids exp() overflowing to inf for large residuals. The value
            # and gradient are mathematically unchanged. A non-finite max is
            # replaced by 0 so that +/-inf inputs still give the same +/-inf
            # result as the naive formula instead of nan.
            shift = np.max(residual)
            shift = np.where(np.isfinite(shift), shift, 0.0)
            fit = shift + np.log(np.exp(residual - shift).sum())
        else:
            fit = np.linalg.norm(residual, ord=ord_norm)

        # NOTE(review): when Dprime is 2-D, ord=1 is the induced matrix 1-norm
        # (max absolute column sum), not the entrywise L1 sum. Confirm that this
        # is the intended regulariser.
        return fit + lambda_l1 * np.linalg.norm(Dprime, 1)

    return compute_loss_fn