def get_stats()

in curve_utils.py [0:0]


def get_stats(model, num_points=2):
    """Compare the per-endpoint Conv2d weights of ``model``.

    For every ``nn.Conv2d`` module in ``model.modules()``, the weight of each
    endpoint ``k`` in ``range(num_points)`` is obtained with
    ``get_weight(m, k)``. Across all Conv2d layers, these statistics are
    accumulated:

    * ``||v_k||^2`` for each endpoint ``k``;
    * ``<v_i, v_j>`` and ``||v_i - v_j||^2`` for each pair ``i < j``.

    Args:
        model: module whose submodules are scanned for ``nn.Conv2d`` layers.
        num_points: number of endpoints to compare. Defaults to 2, the value
            previously hard-coded here.

    Returns:
        ``(cossim, l2)`` as Python floats:

        * ``cossim`` is the sum over pairs of the *squared* cosine similarity
          ``<v_i, v_j>^2 / (||v_i||^2 * ||v_j||^2)``. It is the square of the
          cosine, not the cosine itself.
        * ``l2`` is ``sqrt(sum over pairs of ||v_i - v_j||^2)``. For
          ``num_points == 2`` this is the Euclidean distance between the two
          endpoints' Conv2d weights.

        If an endpoint's weights are all zero, ``cossim`` is NaN/inf, because
        the division is not guarded.

    Raises:
        ValueError: if ``num_points < 2`` (there are no pairs to compare), or if
            ``model`` contains no ``nn.Conv2d`` modules (nothing to measure).
    """
    import torch  # local import: the module may only bind ``nn`` at top level

    if num_points < 2:
        raise ValueError(f"num_points must be >= 2, got {num_points}")

    pairs = [(i, j) for i in range(num_points) for j in range(i + 1, num_points)]
    norms = [0.0] * num_points
    dots = dict.fromkeys(pairs, 0.0)
    sq_dists = dict.fromkeys(pairs, 0.0)
    found_conv = False

    # Results are only returned as floats via .item(), so skip recording an
    # autograd graph over every conv weight.
    with torch.no_grad():
        for m in model.modules():
            if not isinstance(m, nn.Conv2d):
                continue
            found_conv = True
            # Fetch each endpoint's weight once per layer, not once per pair.
            weights = [get_weight(m, k) for k in range(num_points)]
            for k, w in enumerate(weights):
                norms[k] += w.pow(2).sum()
            for i, j in pairs:
                vi, vj = weights[i], weights[j]
                dots[i, j] += (vi * vj).sum()
                sq_dists[i, j] += (vi - vj).pow(2).sum()

        if not found_conv:
            # Without this check the accumulators stay plain floats and the
            # division / .pow() below would fail with an obscure error.
            raise ValueError("get_stats: model contains no nn.Conv2d modules")

        cossim = sum(dots[i, j].pow(2) / (norms[i] * norms[j]) for i, j in pairs)
        l2 = sum(sq_dists[i, j] for i, j in pairs).pow(0.5)

    return cossim.item(), l2.item()