in code/experiment_synthetic/main.py [0:0]
def errors(w, w_hat):
    """Return the mean squared error of ``w_hat`` vs ``w``, split by support.

    Entries where ``w`` is non-zero form the "causal" group. Entries where
    ``w`` is exactly zero form the "non-causal" group. A separate MSE is
    computed over each group.

    Args:
        w: reference tensor, flattened before use. Presumably the ground-truth
            weights; confirm against callers.
        w_hat: tensor with the same number of elements as ``w``, also
            flattened. Presumably the estimated weights.

    Returns:
        ``(error_causal, error_noncausal)`` as Python floats. A group with no
        entries yields ``0.0``.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposes, also work.
    w = w.reshape(-1)
    w_hat = w_hat.reshape(-1)
    causal = w != 0

    def _mse(mask):
        # MSE over the masked entries; 0.0 when the mask selects nothing.
        if not bool(mask.any()):
            return 0.0
        return (w[mask] - w_hat[mask]).pow(2).mean().item()

    return _mse(causal), _mse(~causal)