def load_average_with_metadata()

in lib/torch_util.py [0:0]


def load_average_with_metadata(paths, overrides):
    """Load several checkpoints and return one model with averaged parameters.

    Every entry of ``paths`` is loaded with ``load_with_metadata`` (passing
    ``overrides`` through unchanged). The model loaded from ``paths[0]`` is
    reused as the accumulator: its parameters are scaled by ``1 / len(paths)``
    and the equally scaled parameters of each remaining model are added in
    place, yielding the uniform mean.

    Only tensors yielded by ``parameters()`` / ``named_parameters()`` are
    averaged; any other state on the model stays as loaded from ``paths[0]``.

    Args:
        paths: Indexable, sliceable collection of checkpoint locations. Must
            be non-empty, since ``paths[0]`` is read unconditionally.
        overrides: Forwarded as ``overrides=`` to every ``load_with_metadata``
            call.

    Returns:
        A ``(model, metadata)`` tuple where ``model`` holds the averaged
        parameters and ``metadata`` is the metadata loaded from ``paths[0]``
        (metadata from the other paths is discarded).

    Raises:
        AssertionError: If parameter names of a later model do not line up
            with those of the first model.
    """
    # Function-scope import so this block does not rely on the module's
    # import section for the autograd guard below.
    import torch

    n_models = len(paths)
    model, metadata = load_with_metadata(paths[0], overrides=overrides)
    # Computed after the first load so an empty ``paths`` fails on the
    # ``paths[0]`` lookup exactly as before, not on a division by zero.
    scale = 1 / n_models

    # In-place updates of leaf tensors that require grad are rejected by
    # autograd; disabling grad makes the averaging work for trainable
    # parameters and keeps it out of any autograd graph.
    with torch.no_grad():
        for param in model.parameters():
            param.mul_(scale)

    for path in paths[1:]:
        new_model, _ = load_with_metadata(path, overrides=overrides)
        with torch.no_grad():
            for (n1, p1), (n2, p2) in misc.safezip(
                model.named_parameters(), new_model.named_parameters()
            ):
                assert n1 == n2, f"names {n1} and {n2} don't match"
                # new_model is discarded after this iteration, so scaling
                # its parameter in place avoids an extra temporary tensor.
                p1.add_(p2.mul_(scale))
    return model, metadata