in utils_data.py [0:0]
def jit_loss_fn(statistic_fn, norm=None, lambda_l1=0):
    """Build a jit-compiled loss comparing ``statistic_fn(Dprime)`` to targets.

    Args:
        statistic_fn: Callable mapping ``Dprime`` to an array of statistics
            that is subtracted elementwise from ``target_statistics``.
        norm: How the residual ``statistic_fn(Dprime) - target_statistics``
            is reduced to a scalar:
              * ``"L2"``     -> ``np.linalg.norm(residual, ord=2)``
              * ``"Linfty"`` -> ``np.linalg.norm(residual, ord=np.inf)``
              * ``"LogExp"`` -> ``log(sum(exp(residual)))`` of the signed
                residual (no absolute value is taken)
              * anything else, including ``None`` ->
                ``np.linalg.norm(residual, ord=5)``
        lambda_l1: Weight of the penalty ``np.linalg.norm(Dprime, 1)`` that is
            added to the loss in every mode.

    Returns:
        A ``jit``-wrapped function ``compute_loss_fn(Dprime, target_statistics)``
        that returns the loss.
    """
    # NOTE(review): presumably ``np`` is ``jax.numpy`` and ``jit`` is
    # ``jax.jit``; the imports are not visible here — confirm.

    # ``norm`` is resolved in plain Python, outside the traced function, so
    # only the selected branch ends up in the compiled computation.
    if norm == "L2":
        ord_norm = 2
    elif norm == "Linfty":
        ord_norm = np.inf
    else:
        # Fallback for None and any unrecognized value. It is also reached for
        # "LogExp", whose branch below never uses ``ord_norm``.
        ord_norm = 5

    @jit
    def compute_loss_fn(Dprime, target_statistics):
        residual = statistic_fn(Dprime) - target_statistics
        if norm == "LogExp":
            # Stable log-sum-exp: log(sum(exp(r))) == m + log(sum(exp(r - m))).
            # Shifting by the max keeps exp() from overflowing to inf for large
            # residuals. A non-finite max (all -inf, or some +inf) is replaced
            # by 0 so those edge cases still give -inf / +inf, as the naive
            # formula does, instead of nan.
            shift = np.max(residual)
            shift = np.where(np.isfinite(shift), shift, 0.0)
            data_loss = shift + np.log(np.exp(residual - shift).sum())
        else:
            data_loss = np.linalg.norm(residual, ord=ord_norm)
        # NOTE(review): for a 2-D ``Dprime``, ``norm(Dprime, 1)`` is the induced
        # matrix 1-norm (max absolute column sum), not the elementwise L1 sum
        # that the name ``lambda_l1`` suggests. For 1-D input the two agree.
        # Confirm which is intended.
        return data_loss + lambda_l1 * np.linalg.norm(Dprime, 1)

    return compute_loss_fn