def main()

in avg_checkpoints.py [0:0]


def main():
    """Average the weights of several checkpoints and save the result.

    Options are read from the module-level ``parser``. Checkpoint files are
    collected by globbing ``args.input`` joined with ``args.filter``. When
    sorting is enabled (the default), only checkpoints for which
    ``checkpoint_metric`` returns a value are considered and the ``args.n``
    entries with the largest metric are kept; otherwise every matched file is
    used. State dicts are accumulated in float64, averaged per key, clamped to
    the float32 range and stored as float32 via ``torch.save`` or
    ``safetensors.torch.save_file``. Finally the SHA256 of the written file is
    printed.

    Exits with status 1 (after printing an error) when:
      * safetensors output is requested but the package is unavailable,
      * sorting is enabled and ``args.n`` is smaller than 1,
      * the output file already exists,
      * no checkpoint is selected, or none of the selected ones can be loaded.
    """
    args = parser.parse_args()
    # by default use the EMA weights (if present)
    args.use_ema = not args.no_use_ema
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

    # Fail fast instead of discovering a missing dependency only after all
    # checkpoints were loaded and averaged. An explicit check (rather than an
    # ``assert``) also survives ``python -O``.
    if args.safetensors and not _has_safetensors:
        print("Error: `pip install safetensors` to use .safetensors")
        raise SystemExit(1)

    # ``lst[-0:]`` is the whole list and a negative n would drop the *lowest*
    # entries, so anything below 1 cannot mean "top n" and is rejected.
    if args.sort and args.n < 1:
        print(f"Error: Number of checkpoints to average (n) must be at least 1, got {args.n}.")
        raise SystemExit(1)

    if args.safetensors and args.output == DEFAULT_OUTPUT:
        # Default path changes if using safetensors
        args.output = DEFAULT_SAFE_OUTPUT

    # Append an extension matching the output format when none was given.
    output, output_ext = os.path.splitext(args.output)
    if not output_ext:
        output_ext = '.safetensors' if args.safetensors else '.pth'
    output = output + output_ext

    if args.safetensors and output_ext != ".safetensors":
        print(
            "Warning: saving weights as safetensors but output file extension is not "
            f"set to '.safetensors': {args.output}"
        )

    # Never overwrite an existing file.
    if os.path.exists(output):
        print("Error: Output filename ({}) already exists.".format(output))
        raise SystemExit(1)

    # Build the glob pattern, inserting exactly one path separator between the
    # input directory and the filter when neither side already provides it.
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        # Checkpoints without a metric cannot be ranked and are left out.
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        # Ascending sort, so the tail holds the n largest metric values.
        checkpoint_metrics.sort()
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        if checkpoint_metrics:
            print("Selected checkpoints:")
            for m, c in checkpoint_metrics:
                print(m, c)
        avg_checkpoints = [c for _, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        if avg_checkpoints:
            print("Selected checkpoints:")
            for c in avg_checkpoints:
                print(c)

    if not avg_checkpoints:
        print('Error: No checkpoints found to average.')
        raise SystemExit(1)

    # Accumulate per-key sums in float64; counts are tracked per key because a
    # key is only averaged over the checkpoints that actually contain it.
    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            # Best effort: report and keep going with the remaining files.
            print(f"Error: Checkpoint ({c}) doesn't exist")
            continue
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                # clone() so the in-place additions below never touch the
                # loaded tensor, even when it is already float64.
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    # Every selected checkpoint failed to load: do not write an empty file and
    # report success.
    if not avg_state_dict:
        print('Error: No checkpoints found to average.')
        raise SystemExit(1)

    # float32 overflow seems unlikely based on weights seen to date, but who knows
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    if args.safetensors:
        safetensors.torch.save_file(final_state_dict, output)
    else:
        torch.save(final_state_dict, output)

    # Hash in fixed-size chunks so the whole file is never held in memory.
    sha256 = hashlib.sha256()
    with open(output, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)
    sha_hash = sha256.hexdigest()
    print(f"=> Saved state_dict to '{output}', SHA256: {sha_hash}")