scripts/compute_offline_ema.py (42 lines of code) (raw):
import argparse
import os
from pathlib import Path
import torch
from muse import EMAModel, MaskGitTransformer, MaskGiTUViT
def offline_ema(args):
    """Recompute an exponential moving average of weights from saved checkpoints.

    Scans ``args.checkpoint_dir_path`` for ``checkpoint-<step>`` directories,
    seeds an ``EMAModel`` from the earliest one, then walks every training step
    up to the last checkpoint's step number, swapping in the checkpoint saved
    at each multiple of ``args.checkpoint_interval``. The resulting EMA weights
    are copied into the most recently loaded model, which is then written to
    ``args.ema_save_path`` via ``save_pretrained``.

    Args:
        args: Object exposing ``checkpoint_dir_path``, ``ema_save_path``,
            ``ema_decay`` and ``checkpoint_interval`` attributes (for example
            an ``argparse.Namespace``).

    Raises:
        ValueError: If ``checkpoint_interval`` is smaller than 1, or if the
            stored config names a model class this script does not handle.
        FileNotFoundError: If no ``checkpoint-<step>`` directory is found.
    """
    checkpoint_root = Path(args.checkpoint_dir_path)
    ema_save_path = args.ema_save_path
    ema_decay = args.ema_decay
    checkpoint_interval = args.checkpoint_interval

    # A zero interval would otherwise surface as ZeroDivisionError in the loop.
    if checkpoint_interval < 1:
        raise ValueError(f"checkpoint_interval must be >= 1, got {checkpoint_interval}")

    # Keep only entries shaped like "checkpoint-<integer>" so unrelated files
    # that merely start with "checkpoint" cannot break the numeric sort.
    prefix = "checkpoint-"
    checkpoint_names = [
        name
        for name in os.listdir(checkpoint_root)
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    if not checkpoint_names:
        raise FileNotFoundError(f"No '{prefix}<step>' directories found in {checkpoint_root}")
    checkpoint_names.sort(key=lambda name: int(name[len(prefix):]))

    # Join the root exactly once; joining it a second time duplicates the
    # prefix whenever the checkpoint directory is given as a relative path.
    first_model_dir = checkpoint_root / checkpoint_names[0] / "unwrapped_model"

    supported_classes = {
        "MaskGitTransformer": MaskGitTransformer,
        "MaskGiTUViT": MaskGiTUViT,
    }
    transformer_config = MaskGitTransformer.load_config(first_model_dir)
    class_name = transformer_config["_class_name"]
    if class_name not in supported_classes:
        raise ValueError(
            f"Unsupported model class {class_name!r}; expected one of {sorted(supported_classes)}"
        )
    model_cls = supported_classes[class_name]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = model_cls.from_pretrained(first_model_dir).to(device)
    ema_model = EMAModel(parameters=model.parameters(), decay=ema_decay, update_every=checkpoint_interval)
    ema_model.to(device)

    end_step = int(checkpoint_names[-1][len(prefix):])
    for step in range(0, end_step):
        if (step + 1) % checkpoint_interval == 0:
            print(f"Loading checkpoint {step + 1}...")
            model = model_cls.from_pretrained(checkpoint_root / f"checkpoint-{step + 1}" / "unwrapped_model")
            model.to(device)
        # step() is invoked once per training step; the EMA was built with
        # update_every=checkpoint_interval, so it presumably tracks the step
        # count itself and only folds weights in at that cadence
        # (NOTE(review): confirm against EMAModel.step).
        ema_model.step(model.parameters())

    # Presumably overwrites the model's parameters with the averaged values
    # before saving (NOTE(review): confirm against EMAModel.copy_to).
    ema_model.copy_to(model.parameters())
    model.save_pretrained(ema_save_path)
if __name__ == "__main__":
    # Command-line entry point: collect the options and run the offline EMA.
    cli = argparse.ArgumentParser()

    # Both path options are mandatory strings with no fallback value.
    for required_flag in ("--checkpoint_dir_path", "--ema_save_path"):
        cli.add_argument(required_flag, type=str, required=True)

    # Optional tuning knobs with their defaults.
    cli.add_argument("--ema_decay", default=0.9999, type=float)
    cli.add_argument("--checkpoint_interval", default=1000, type=int)

    offline_ema(cli.parse_args())