# benchmark/model_quality.py (115 lines of code) (raw):
import matplotlib.pyplot as plt
import argparse
# Model-quality benchmark results for muse-256 and muse-512.
# ``cfg`` holds the cfg-scale values (the x-axis of every plot); each metric
# list below is index-aligned with it, i.e. ``fid_512[i]`` was measured at
# ``cfg[i]``.  The "(10k)" in the plot labels presumably means the scores were
# computed over 10k samples -- NOTE(review): confirm.
cfg = [1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 10.0, 15.0, 20.0]

# --- muse-512 -------------------------------------------------------------
fid_512 = [
    56.13683, 48.3625, 43.13792, 42.07286, 41.21331, 41.21309,
    40.76164, 40.51427, 40.22781, 39.66504, 38.57083,
]
clip_512 = [
    23.168075, 24.3268, 25.29295, 25.67775, 25.93075, 26.068925,
    26.15145, 26.151175, 26.26665, 26.3845, 26.402225,
]
isc_512 = [
    20.32828279489911, 23.092083811105134, 25.34707454898865,
    25.782333543568505, 26.779519535473717, 26.72532414371535,
    26.8378182891666, 27.02354446351334, 27.235757940256587,
    27.461719798190302, 27.37252925955596,
]

# --- muse-256 -------------------------------------------------------------
fid_256 = [
    43.64503, 40.57112, 39.38306, 39.29915, 40.10225, 41.97274,
    45.10721, 49.11104, 59.13854, 81.46585, 96.3426,
]
clip_256 = [
    24.191875, 25.035825, 25.689725, 26.0217, 26.1032, 26.048225,
    25.90045, 25.691, 25.319, 24.49525, 23.915725,
]
isc_256 = [
    21.247120913990408, 23.008063867685443, 23.49288416726619,
    24.13530452474164, 23.197031957136875, 21.741427950979876,
    20.435789339047123, 18.84057076723702, 15.793238717380486,
    10.74857386855099, 8.62769427725863,
]
if __name__ == "__main__":
    # Only the script entry point needs pathlib, so import it locally.
    from pathlib import Path

    # Per-metric plot settings:
    #   flag name -> (title, y-axis label, muse-256 series, muse-512 series, output file)
    # Dict insertion order is also the flag precedence when several flags are
    # given (--fid, then --isc, then --clip); only the first one is plotted.
    metrics = {
        "fid": ("FID", "FID Score (10k)", fid_256, fid_512, "fid.png"),
        "isc": ("Inception Score", "Inception Score (10k)", isc_256, isc_512, "isc.png"),
        "clip": ("CLIP Score", "CLIP Score (10k)", clip_256, clip_512, "clip.png"),
    }

    parser = argparse.ArgumentParser(
        description="Plot muse-256 and muse-512 quality metrics against cfg scale."
    )
    parser.add_argument("--fid", action="store_true", help="plot FID")
    parser.add_argument("--isc", action="store_true", help="plot Inception Score")
    parser.add_argument("--clip", action="store_true", help="plot CLIP score")
    args = parser.parse_args()

    selected = next((name for name in metrics if getattr(args, name)), None)
    if selected is None:
        # parser.error prints usage and exits with status 2; unlike an assert it
        # is not stripped when Python runs with -O.
        parser.error("one of --fid, --isc or --clip is required")
    title, ylabel, series_256, series_512, filename = metrics[selected]

    plt.title(title)
    plt.ylabel(ylabel)
    plt.plot(cfg, series_256, marker="o", label="muse-256")
    plt.plot(cfg, series_512, marker="o", label="muse-512")
    plt.xlabel("cfg scale")
    plt.legend()
    plt.grid(True)

    # Save (rather than display) the figure. Create the artifacts directory
    # first so savefig does not fail with FileNotFoundError when it is missing.
    # The path is relative to the current working directory.
    out_dir = Path("./benchmark/artifacts")
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_dir / filename)