def get_stats()

in trainers/train_simplexes.py [0:0]


def get_stats(model):
    """Summarise how similar the ``args.n`` weight sets of ``model`` are.

    For every ``nn.Conv2d`` module in ``model.modules()``, the weight for
    each index ``i`` in ``range(args.n)`` is fetched with
    ``get_weight(m, i)``. Squared norms, pairwise dot products and pairwise
    squared differences are accumulated across all of those modules.

    Args:
        model: Object exposing ``modules()``; only ``nn.Conv2d`` instances
            contribute to the statistics.

    Returns:
        A ``(cossim, l2)`` tuple of Python floats:

        * ``cossim`` -- mean over all pairs ``i < j`` of
          ``(v_i . v_j)^2 / (|v_i|^2 * |v_j|^2)``, i.e. the *squared*
          cosine similarity of the accumulated weights.
        * ``l2`` -- square root of the mean over all pairs of the
          accumulated squared difference ``|v_i - v_j|^2``.

        ``(0.0, 0.0)`` is returned when there is nothing to compare: fewer
        than two weight sets (``args.n < 2``) or no ``nn.Conv2d`` module.
    """
    n = args.n
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    if not pairs:
        # No pair exists, so both averages are empty; previously this path
        # crashed calling ``.pow`` on a plain Python int.
        return 0.0, 0.0

    sq_norms = [0.0] * n
    dots = dict.fromkeys(pairs, 0.0)
    sq_difs = dict.fromkeys(pairs, 0.0)
    found_conv = False

    for m in model.modules():
        if not isinstance(m, nn.Conv2d):
            continue
        found_conv = True
        # Fetch each weight once per module instead of once per pair.
        # The results only leave this function as Python floats, so
        # autograd tracking is not needed.
        weights = [get_weight(m, i).detach() for i in range(n)]
        for i, w in enumerate(weights):
            sq_norms[i] += w.pow(2).sum()
        for i, j in pairs:
            dots[i, j] += (weights[i] * weights[j]).sum()
            sq_difs[i, j] += (weights[i] - weights[j]).pow(2).sum()

    if not found_conv:
        # Accumulators are still plain floats (no ``.pow`` / ``.item``) and
        # the norms are zero, so there is nothing meaningful to report.
        return 0.0, 0.0

    # Same value as n * (n - 1) / 2.0, the number of unordered pairs.
    pair_weight = 1.0 / len(pairs)
    cossim = 0
    l2 = 0
    for i, j in pairs:
        cossim += pair_weight * (
            dots[i, j].pow(2) / (sq_norms[i] * sq_norms[j])
        )
        l2 += pair_weight * sq_difs[i, j]

    return cossim.item(), l2.pow(0.5).item()