def check_correctness()

in fast_grad/helpers.py [0:0]


def check_correctness(full, names, approximations, model, X, y):
	"""Print how closely each gradient approximation matches a reference gradient.

	For every callable in ``approximations`` this prints:
	  * the norm of the difference between the reference vector
	    (``parameters_to_vector(full(model, X, y))``) and the mean over dim 0
	    of the approximation's output (flattened with ``batch_grads_to_vec``);
	  * for every pair of approximations, the norm of the difference between
	    their flattened outputs.

	Args:
		full: callable ``(model, X, y)`` whose result is flattened with
			``parameters_to_vector`` and used as the reference.
		names: labels for ``approximations``, matched by position.
		approximations: callables ``(model, X, y)`` whose results are
			flattened with ``batch_grads_to_vec``. Presumably these return
			per-sample gradients stacked along dim 0 — TODO confirm against
			batch_grads_to_vec.
		model, X, y: forwarded unchanged to ``full`` and every approximation.

	Returns:
		None; results are only printed.

	Raises:
		ValueError: if fewer names than approximations are given (checked
			up front so no gradients are computed for a call that would fail).
	"""
	if len(names) < len(approximations):
		raise ValueError(
			"need a name for every approximation: got %d names for %d approximations"
			% (len(names), len(approximations))
		)

	print()
	print("  Checking correctness")
	print("  ---")

	true_value = parameters_to_vector(full(model, X, y))

	approx_values = []
	for name, approximation in zip(names, approximations):
		approx_value = batch_grads_to_vec(approximation(model, X, y))
		approx_values.append(approx_value)
		# Compare the mean over dim 0 with the reference vector (assumes `full`
		# averages over the same samples — NOTE(review): confirm).
		mean_value = torch.mean(approx_value, dim=0)
		print("  - Diff. to full batch for (%5s):        %f" % (name, torch.norm(true_value - mean_value)))

	# Every unordered pair once, printed as (later, earlier) in list order.
	for i in range(len(approx_values)):
		for j in range(i):
			diff = torch.norm(approx_values[i] - approx_values[j])
			print("  - Difference between (%5s) and (%5s): %f" % (names[i], names[j], diff))