avg_checkpoints.py

import glob
import hashlib
import os

import torch

# `parser`, DEFAULT_OUTPUT, DEFAULT_SAFE_OUTPUT, checkpoint_metric,
# load_state_dict, _has_safetensors (and the optional `safetensors` import)
# are defined elsewhere in avg_checkpoints.py and assumed available here.


def main():
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    if args.safetensors and args.output == DEFAULT_OUTPUT:
        # default output path changes if saving as safetensors
        args.output = DEFAULT_SAFE_OUTPUT

    output, output_ext = os.path.splitext(args.output)
    if not output_ext:
        output_ext = '.safetensors' if args.safetensors else '.pth'
    output = output + output_ext

    if args.safetensors and output_ext != '.safetensors':
        print(
            "Warning: saving weights as safetensors but output file extension is not "
            f"set to '.safetensors': {args.output}"
        )

    if os.path.exists(output):
        print(f"Error: Output filename ({output}) already exists.")
        exit(1)
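
    # Build the glob pattern from the input dir plus the filename filter.
    # Note: recursive=True below only descends into subdirectories when the
    # pattern contains a '**' component.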
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # keep the n checkpoints with the highest metric
        checkpoint_metrics = sorted(checkpoint_metrics)[-args.n:]
        if checkpoint_metrics:
            print("Selected checkpoints:")
            for m, c in checkpoint_metrics:
                print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        if avg_checkpoints:
            print("Selected checkpoints:")
            for c in avg_checkpoints:
                print(c)

    if not avg_checkpoints:
        print('Error: No checkpoints found to average.')
        exit(1)
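
    # Averaging strategy: accumulate per-key running sums in float64 alongside
    # a per-key count of contributing checkpoints, then divide sum by count.
    # Keys missing from some checkpoints are averaged over only the
    # checkpoints that contain them.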
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print(f"Error: No state_dict found in checkpoint ({c}), skipping")
            continue
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    if args.safetensors:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(final_state_dict, output)
    else:
        torch.save(final_state_dict, output)

    with open(output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print(f"=> Saved state_dict to '{output}', SHA256: {sha_hash}")
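
# Example invocation (a sketch, not taken from this file: flag spellings are
# inferred from the `args.*` attributes used in main(); the actual argparse
# `parser` is defined elsewhere in avg_checkpoints.py):
#
#   python avg_checkpoints.py --input ./output/train/ --filter 'checkpoint-*.pth.tar' \
#       --n 10 --output ./averaged.pth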