in trainers/train_simplexes.py [0:0]
def get_stats(model, n=None, weight_fn=None):
    """Summarize how similar / far apart the simplex vertices' Conv2d weights are.

    For every ``nn.Conv2d`` found in ``model.modules()``, the weight of each of
    the ``n`` vertices is obtained with ``weight_fn(module, i)``. Dot products,
    squared norms and squared differences are summed over all such layers, so
    each vertex is effectively treated as one long concatenated weight vector.

    Args:
        model: module whose ``modules()`` are scanned for ``nn.Conv2d`` layers.
        n: number of vertices to compare; defaults to ``args.n``.
        weight_fn: callable ``(module, i) -> tensor`` returning vertex ``i``'s
            weight for a layer; defaults to ``get_weight``.

    Returns:
        ``(cossim, l2)`` as Python floats. ``cossim`` is the mean, over all
        ``n * (n - 1) / 2`` vertex pairs, of the *squared* cosine similarity
        ``<vi, vj>^2 / (|vi|^2 |vj|^2)``. ``l2`` is the square root of the mean,
        over the same pairs, of the squared L2 distance ``|vi - vj|^2``.
        ``(0.0, 0.0)`` is returned when there are fewer than two vertices or
        the model contains no ``nn.Conv2d`` layer, rather than failing.
    """
    import torch  # local import: only needed for no_grad below

    if n is None:
        n = args.n
    if weight_fn is None:
        weight_fn = get_weight

    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    if not pairs:
        # Fewer than two vertices: there is nothing to compare.
        return 0.0, 0.0
    nc2 = float(len(pairs))  # == n * (n - 1) / 2

    norms = [0.0] * n
    numerators = dict.fromkeys(pairs, 0.0)
    difs = dict.fromkeys(pairs, 0.0)
    found_conv = False

    # The result is reduced to Python floats, so no autograd graph is needed.
    with torch.no_grad():
        for m in model.modules():
            if not isinstance(m, nn.Conv2d):
                continue
            found_conv = True
            # Fetch each vertex's weight once per layer instead of once per pair.
            weights = [weight_fn(m, i) for i in range(n)]
            for i, vi in enumerate(weights):
                norms[i] += vi.pow(2).sum()
            for i, j in pairs:
                vi, vj = weights[i], weights[j]
                numerators[(i, j)] += (vi * vj).sum()
                difs[(i, j)] += (vi - vj).pow(2).sum()

        if not found_conv:
            # Accumulators are still plain floats; cosine would be 0/0.
            return 0.0, 0.0

        cossim = 0.0
        l2 = 0.0
        for i, j in pairs:
            # Squared cosine similarity of the flattened vertex weights.
            cossim += (1.0 / nc2) * (
                numerators[(i, j)].pow(2) / (norms[i] * norms[j])
            )
            l2 += (1.0 / nc2) * difs[(i, j)]

    return cossim.item(), l2.pow(0.5).item()