in curve_utils.py [0:0]
def get_stats(model, num_points=2):
    """Compare the Conv2d weights of a model's curve points.

    For every ``nn.Conv2d`` module in ``model``, ``get_weight(m, i)`` is called
    for each point index ``i`` in ``range(num_points)``. The following
    quantities are accumulated over all such modules:

    * the squared L2 norm of each point's weights,
    * the dot product between every pair of points ``i < j``,
    * the squared L2 distance between every pair of points ``i < j``.

    Args:
        model: Module whose ``modules()`` are scanned for ``nn.Conv2d`` layers.
        num_points: Number of point indices passed to ``get_weight``. Defaults
            to 2, the value that was previously hard-coded.

    Returns:
        A ``(cossim, l2)`` tuple of Python floats.

        * ``cossim`` is the sum over pairs ``i < j`` of the *squared* cosine
          similarity ``<v_i, v_j>**2 / (|v_i|**2 * |v_j|**2)``.
        * ``l2`` is the square root of the summed squared pairwise distances.

        With the default two points, these are cos^2 of the angle between the
        two weight vectors and their Euclidean distance.

    Raises:
        ValueError: If ``num_points < 2`` (there is no pair to compare), or if
            ``model`` contains no ``nn.Conv2d`` module (there is nothing to
            accumulate, and the statistics would be undefined).
    """
    if num_points < 2:
        raise ValueError(
            f"num_points must be >= 2 to form a pair, got {num_points}"
        )

    pairs = [(i, j) for i in range(num_points) for j in range(i + 1, num_points)]
    norms = [0.0] * num_points              # squared norm per point
    numerators = {pair: 0.0 for pair in pairs}  # dot product per pair
    difs = {pair: 0.0 for pair in pairs}        # squared distance per pair
    found_conv = False

    for m in model.modules():
        if not isinstance(m, nn.Conv2d):
            continue
        found_conv = True
        # Fetch each point's weights once per module instead of once per pair.
        weights = [get_weight(m, i) for i in range(num_points)]
        for i, vi in enumerate(weights):
            norms[i] += vi.pow(2).sum()
        for i, j in pairs:
            vi, vj = weights[i], weights[j]
            numerators[(i, j)] += (vi * vj).sum()
            difs[(i, j)] += (vi - vj).pow(2).sum()

    # Without any Conv2d layer, the accumulators would still be plain floats
    # (so `.pow`/`.item` would fail) and the norms would be zero.
    if not found_conv:
        raise ValueError(
            "model contains no nn.Conv2d modules; cannot compute stats"
        )

    # Squared cosine similarity: norms already hold squared norms, so the
    # numerator is squared to match.
    cossim = sum(
        numerators[(i, j)].pow(2) / (norms[i] * norms[j]) for i, j in pairs
    )
    l2 = sum(difs[pair] for pair in pairs)
    return cossim.item(), l2.pow(0.5).item()