in cm/karras_diffusion.py [0:0]
def get_weightings(weight_schedule, snrs, sigma_data):
    """Return per-element loss weightings for the named weight schedule.

    Args:
        weight_schedule: Name of the schedule. One of ``"snr"``,
            ``"snr+1"``, ``"karras"``, ``"truncated-snr"`` or ``"uniform"``.
        snrs: Tensor the weightings are derived from (presumably
            signal-to-noise ratios, going by the name -- confirm with the
            caller). The result has the same shape.
        sigma_data: Scalar used only by the ``"karras"`` schedule, where
            ``1 / sigma_data**2`` is added to ``snrs``.

    Returns:
        The weightings: ``snrs`` itself for ``"snr"`` (the same object, not
        a copy), otherwise a new tensor.

    Raises:
        NotImplementedError: If ``weight_schedule`` is not one of the
            supported names.
    """
    if weight_schedule == "snr":
        # Returned as-is; callers receive the very same tensor object.
        weightings = snrs
    elif weight_schedule == "snr+1":
        weightings = snrs + 1
    elif weight_schedule == "karras":
        weightings = snrs + 1.0 / sigma_data**2
    elif weight_schedule == "truncated-snr":
        # Floor every weight at 1 so small values of snrs are not
        # down-weighted below the uniform baseline.
        weightings = th.clamp(snrs, min=1.0)
    elif weight_schedule == "uniform":
        weightings = th.ones_like(snrs)
    else:
        # Name the offending value so a misconfigured run is diagnosable.
        raise NotImplementedError(
            f"unknown weight_schedule {weight_schedule!r}; expected one of "
            "'snr', 'snr+1', 'karras', 'truncated-snr', 'uniform'"
        )
    return weightings