# offline_ema — from scripts/compute_offline_ema.py

def offline_ema(args):
    """Reconstruct an EMA (exponential moving average) model offline from saved checkpoints.

    Walks the ``checkpoint-N`` directories under ``args.checkpoint_dir_path`` in
    ascending step order, feeds each checkpoint's weights into an ``EMAModel``
    (one EMA update per ``checkpoint_interval`` training steps), then writes the
    averaged weights to ``args.ema_save_path``.

    Args:
        args: Namespace with ``checkpoint_dir_path``, ``ema_save_path``,
            ``ema_decay``, and ``checkpoint_interval`` attributes.

    Raises:
        ValueError: If no checkpoint directories are found, or the saved config
            names an unknown model class.
    """
    checkpoint_dir_path = args.checkpoint_dir_path
    ema_save_path = args.ema_save_path
    ema_decay = args.ema_decay
    checkpoint_interval = args.checkpoint_interval

    # Collect "checkpoint-N" directories sorted by step number N.
    names = [d for d in os.listdir(checkpoint_dir_path) if d.startswith("checkpoint")]
    names = sorted(names, key=lambda x: int(x.split("-")[1]))
    dirs = [Path(checkpoint_dir_path) / name for name in names]
    if not dirs:
        raise ValueError(f"No checkpoint directories found in {checkpoint_dir_path}")

    # NOTE: `dirs` entries are already full paths — do NOT re-join them with
    # checkpoint_dir_path (the original did, which doubled the prefix for
    # relative paths and only worked by accident for absolute ones).
    transformer_config = MaskGitTransformer.load_config(dirs[0] / "unwrapped_model")
    if transformer_config["_class_name"] == "MaskGitTransformer":
        model_cls = MaskGitTransformer
    elif transformer_config["_class_name"] == "MaskGiTUViT":
        model_cls = MaskGiTUViT
    else:
        # Fail loudly instead of hitting a NameError on an unbound model_cls later.
        raise ValueError(f"Unknown model class: {transformer_config['_class_name']}")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model_cls.from_pretrained(dirs[0] / "unwrapped_model").to(device)
    # update_every=checkpoint_interval: one EMA update per saved checkpoint,
    # mimicking an EMA that was maintained online during training.
    ema_model = EMAModel(parameters=model.parameters(), decay=ema_decay, update_every=checkpoint_interval)
    ema_model.to(device)

    # Final training step is encoded in the last directory's own name
    # ("checkpoint-N"); use .name so parent directories containing "-" are safe.
    end_step = int(dirs[-1].name.split("-")[-1])
    for step in range(end_step):
        if (step + 1) % checkpoint_interval == 0:
            print(f"Loading checkpoint {step + 1}...")
            model = model_cls.from_pretrained(Path(checkpoint_dir_path) / f"checkpoint-{step + 1}" / "unwrapped_model")
            model.to(device)

        # Called every step; EMAModel itself only averages every update_every calls.
        ema_model.step(model.parameters())

    # Bake the EMA weights into the (last-loaded) model and save it.
    ema_model.copy_to(model.parameters())
    model.save_pretrained(ema_save_path)