in lib/torch_util.py [0:0]
def load_average_with_metadata(paths, overrides):
    """Load several checkpoints and return their parameter-wise average.

    The first checkpoint supplies the model object and the returned metadata.
    Every parameter of that model is replaced by the arithmetic mean of the
    corresponding parameters across all checkpoints in ``paths``.

    Args:
        paths: Non-empty sequence (or iterable) of checkpoint paths accepted by
            ``load_with_metadata``.
        overrides: Forwarded unchanged to every ``load_with_metadata`` call.

    Returns:
        ``(model, metadata)``: the model loaded from ``paths[0]`` with its
        parameters averaged in place, and the metadata loaded from ``paths[0]``.

    Raises:
        ValueError: if ``paths`` is empty.
        AssertionError: if a checkpoint's parameter names do not match (by name
            and order) those of the first checkpoint.

    Note:
        Only ``named_parameters()`` are averaged; anything not returned by
        ``parameters()`` (e.g. registered buffers) keeps the value loaded from
        ``paths[0]``.
    """
    import torch  # local import keeps this block self-contained

    paths = list(paths)
    n_models = len(paths)
    if n_models == 0:
        raise ValueError("load_average_with_metadata requires at least one path")
    scale = 1.0 / n_models

    model, metadata = load_with_metadata(paths[0], overrides=overrides)
    # Parameters are autograd leaves that (by default) require grad; in-place
    # updates on them are only permitted with gradient tracking disabled.
    with torch.no_grad():
        for param in model.parameters():
            param.mul_(scale)
        for path in paths[1:]:
            new_model, _ = load_with_metadata(path, overrides=overrides)
            pairs = misc.safezip(model.named_parameters(), new_model.named_parameters())
            for (n1, p1), (n2, p2) in pairs:
                assert n1 == n2, f"names {n1} and {n2} don't match"
                # Accumulate scale * p2 without mutating the loaded model's tensors.
                p1.add_(p2, alpha=scale)
            # Drop the secondary model before loading the next one to bound peak memory.
            del new_model
    return model, metadata