cm/script_util.py (215 lines of code) (raw):

import argparse

import numpy as np

from .karras_diffusion import KarrasDenoiser
from .unet import UNetModel

NUM_CLASSES = 1000


def cm_train_defaults():
    """
    Defaults for the consistency-training options.

    :return: a fresh dict mapping option name to default value.
    """
    return dict(
        teacher_model_path="",
        teacher_dropout=0.1,
        training_mode="consistency_distillation",
        target_ema_mode="fixed",
        scale_mode="fixed",
        total_training_steps=600000,
        start_ema=0.0,
        start_scales=40,
        end_scales=40,
        distill_steps_per_iter=50000,
        loss_norm="lpips",
    )


def model_and_diffusion_defaults():
    """
    Defaults for image training.

    :return: a fresh dict whose keys match the keyword parameters of
             create_model_and_diffusion() (except `distillation`).
    """
    return dict(
        sigma_min=0.002,
        sigma_max=80.0,
        image_size=64,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        num_head_channels=-1,
        attention_resolutions="32,16,8",
        channel_mult="",
        dropout=0.0,
        class_cond=False,
        use_checkpoint=False,
        use_scale_shift_norm=True,
        resblock_updown=False,
        use_fp16=False,
        use_new_attention_order=False,
        learn_sigma=False,
        weight_schedule="karras",
    )


def create_model_and_diffusion(
    image_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    use_checkpoint,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    use_new_attention_order,
    weight_schedule,
    sigma_min=0.002,
    sigma_max=80.0,
    distillation=False,
):
    """
    Build a UNet (via create_model) together with a KarrasDenoiser.

    The model-related arguments are forwarded unchanged to create_model();
    `sigma_min`, `sigma_max`, `distillation` and `weight_schedule` are
    forwarded to KarrasDenoiser, which is always given sigma_data=0.5.

    :return: a (model, diffusion) tuple.
    """
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        channel_mult=channel_mult,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        use_new_attention_order=use_new_attention_order,
    )
    diffusion = KarrasDenoiser(
        sigma_data=0.5,
        sigma_max=sigma_max,
        sigma_min=sigma_min,
        distillation=distillation,
        weight_schedule=weight_schedule,
    )
    return model, diffusion


def _parse_channel_mult(spec):
    """
    Parse a comma-separated multiplier string such as "1,2,3,4".

    Integer tokens stay `int`, exactly as before. Tokens that are not valid
    integers fall back to `float`, so fractional multipliers like the 0.5
    used by the built-in 512 preset can also be given explicitly.

    :param spec: comma-separated string of numbers.
    :return: a tuple of ints/floats, one per token.
    :raises ValueError: if a token is not a number.
    """
    multipliers = []
    for token in spec.split(","):
        try:
            multipliers.append(int(token))
        except ValueError:
            multipliers.append(float(token))
    return tuple(multipliers)


def create_model(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult="",
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions="16",
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    dropout=0,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
):
    """
    Construct a UNetModel from command-line-friendly arguments.

    :param image_size: image resolution; also selects the preset channel
                       multipliers when `channel_mult` is "".
    :param channel_mult: "" to use the preset for `image_size`, otherwise a
                         comma-separated string of multipliers.
    :param attention_resolutions: comma-separated resolutions; each one is
                                  converted to the downsample rate
                                  `image_size // resolution`.
    :param learn_sigma: if True the model gets 6 output channels instead of 3.
    :param class_cond: if True the model is given NUM_CLASSES classes.
    :return: the constructed UNetModel.
    :raises ValueError: if `channel_mult` is "" and `image_size` has no
                        preset, or if a multiplier/resolution token is not
                        numeric.
    """
    if channel_mult == "":
        if image_size == 512:
            channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
        elif image_size == 256:
            channel_mult = (1, 1, 2, 2, 4, 4)
        elif image_size == 128:
            channel_mult = (1, 1, 2, 3, 4)
        elif image_size == 64:
            channel_mult = (1, 2, 3, 4)
        else:
            raise ValueError(f"unsupported image size: {image_size}")
    else:
        channel_mult = _parse_channel_mult(channel_mult)

    attention_ds = tuple(
        image_size // int(res) for res in attention_resolutions.split(",")
    )

    return UNetModel(
        image_size=image_size,
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3 if not learn_sigma else 6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=attention_ds,
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond else None),
        use_checkpoint=use_checkpoint,
        use_fp16=use_fp16,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        resblock_updown=resblock_updown,
        use_new_attention_order=use_new_attention_order,
    )


def _progressive_scales(step, total_steps, start_scales, end_scales):
    """
    Progressive schedule shared by the "progressive" scale modes.

    Computes ceil(sqrt(t * ((end + 1)^2 - start^2) + start^2) - 1) with
    t = step / total_steps, clamped below at 1. Callers add the final +1
    themselves because the adaptive EMA uses the value before that offset.

    :return: a numpy int32 scalar >= 1.
    """
    scales = np.ceil(
        np.sqrt(
            (step / total_steps) * ((end_scales + 1) ** 2 - start_scales**2)
            + start_scales**2
        )
        - 1
    ).astype(np.int32)
    return np.maximum(scales, 1)


def create_ema_and_scales_fn(
    target_ema_mode,
    start_ema,
    scale_mode,
    start_scales,
    end_scales,
    total_steps,
    distill_steps_per_iter,
):
    """
    Create the per-step schedule for the target EMA rate and scale count.

    Supported (target_ema_mode, scale_mode) combinations:

    - ("fixed", "fixed"): always (start_ema, start_scales).
    - ("fixed", "progressive"): EMA stays at start_ema; scales follow the
      progressive schedule.
    - ("adaptive", "progressive"): scales follow the progressive schedule and
      the EMA is exp(log(start_ema) * start_scales / scales).
    - ("fixed", "progdist"): scales start at start_scales and are halved
      every distill_steps_per_iter steps; EMA is 1.0.

    The mode combination is only checked when the returned function is
    called, not when it is created.

    :return: a function mapping a step to a (float target_ema, int scales)
             tuple; it raises NotImplementedError for any other combination.
    """

    def ema_and_scales_fn(step):
        if target_ema_mode == "fixed" and scale_mode == "fixed":
            target_ema = start_ema
            scales = start_scales
        elif target_ema_mode == "fixed" and scale_mode == "progressive":
            target_ema = start_ema
            scales = (
                _progressive_scales(step, total_steps, start_scales, end_scales)
                + 1
            )
        elif target_ema_mode == "adaptive" and scale_mode == "progressive":
            scales = _progressive_scales(
                step, total_steps, start_scales, end_scales
            )
            # The EMA is derived from the scale count *before* the +1 offset.
            c = -np.log(start_ema) * start_scales
            target_ema = np.exp(-c / scales)
            scales = scales + 1
        elif target_ema_mode == "fixed" and scale_mode == "progdist":
            distill_stage = step // distill_steps_per_iter
            scales = start_scales // (2**distill_stage)
            scales = np.maximum(scales, 2)

            # Once the halving has bottomed out at 2, a second schedule with
            # stages twice as long takes the count from 2 down to 1.
            sub_stage = np.maximum(
                step - distill_steps_per_iter * (np.log2(start_scales) - 1),
                0,
            )
            sub_stage = sub_stage // (distill_steps_per_iter * 2)
            sub_scales = 2 // (2**sub_stage)
            sub_scales = np.maximum(sub_scales, 1)

            scales = np.where(scales == 2, sub_scales, scales)

            target_ema = 1.0
        else:
            raise NotImplementedError(
                "unsupported schedule combination: "
                f"target_ema_mode={target_ema_mode!r}, scale_mode={scale_mode!r}"
            )

        return float(target_ema), int(scales)

    return ema_and_scales_fn


def add_dict_to_argparser(parser, default_dict):
    """
    Add one `--<key>` option to `parser` for every entry of `default_dict`.

    The argument type is the type of the default value, except that None
    defaults are parsed as `str` and bool defaults are parsed with str2bool.
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v, type=v_type)


def args_to_dict(args, keys):
    """Return a dict of the attributes of `args` named in `keys`."""
    return {k: getattr(args, k) for k in keys}


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Bools are returned unchanged; otherwise `v` is matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    :raises argparse.ArgumentTypeError: if `v` matches neither set.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")