in sat/diffusion_video.py [0:0]
def __init__(self, args, **kwargs):
    """Build the engine's sub-modules and settings from ``args.model_config``.

    Args:
        args: Run arguments. This method reads ``args.model_config`` (a
            mapping exposing ``.get``), the ``args.fp16`` / ``args.bf16``
            precision flags, and ``args.device``.
        **kwargs: Accepted but not used by this method.

    Note:
        ``network_config`` is effectively required: its ``"params"`` entry
        is updated in place with the selected dtype string before the
        network is instantiated, so a missing config fails at that point.
    """
    super().__init__()
    model_config = args.model_config

    # Configs for the sub-modules instantiated further down.
    network_config = model_config.get("network_config", None)
    network_wrapper = model_config.get("network_wrapper", None)
    denoiser_config = model_config.get("denoiser_config", None)
    sampler_config = model_config.get("sampler_config", None)
    conditioner_config = model_config.get("conditioner_config", None)
    first_stage_config = model_config.get("first_stage_config", None)
    loss_fn_config = model_config.get("loss_fn_config", None)
    compile_model = model_config.get("compile_model", False)

    # Settings that are set on the instance only after the sub-modules exist
    # (kept in that position in case _init_first_stage touches them).
    scale_factor = model_config.get("scale_factor", 1.0)
    latent_input = model_config.get("latent_input", False)
    disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
    # Read from its own key. Previously this was looked up under
    # "disable_first_stage_autocast", so a configured "no_cond_log" was ignored
    # and the two flags could not be set independently.
    no_cond_log = model_config.get("no_cond_log", False)

    # Training / logging settings copied straight from the config.
    self.use_pd = model_config.get("use_pd", False)  # progressive distillation
    self.trainable_modules = model_config.get("trainable_modules", ["all"])
    self.not_trainable_modules = model_config.get("not_trainable_modules", [])
    self.unfreeze_adaLN_gate = model_config.get("unfreeze_adaLN_gate", False)
    self.log_keys = model_config.get("log_keys", None)
    self.input_key = model_config.get("input_key", "mp4")
    self.not_trainable_prefixes = model_config.get(
        "not_trainable_prefixes", ["first_stage_model", "conditioner"]
    )
    self.en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
    self.lr_scale = model_config.get("lr_scale", None)
    self.lora_train = model_config.get("lora_train", False)
    self.noised_image_input = model_config.get("noised_image_input", False)
    self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
    self.noised_image_dropout = model_config.get("noised_image_dropout", 0.05)

    # Precision: fp16 takes precedence over bf16; fp32 is the fallback.
    if args.fp16:
        dtype, dtype_str = torch.float16, "fp16"
    elif args.bf16:
        dtype, dtype_str = torch.bfloat16, "bf16"
    else:
        dtype, dtype_str = torch.float32, "fp32"
    self.dtype = dtype
    self.dtype_str = dtype_str

    # The network receives the dtype as a string through its own params;
    # note that this mutates the caller's config in place.
    network_config["params"]["dtype"] = dtype_str
    model = instantiate_from_config(network_config)
    wrapper_cls = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))
    self.model = wrapper_cls(model, compile_model=compile_model, dtype=dtype)

    self.denoiser = instantiate_from_config(denoiser_config)
    # Sampler and loss are optional: they stay None when not configured.
    self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
    self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
    self._init_first_stage(first_stage_config)
    self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None

    self.latent_input = latent_input
    self.scale_factor = scale_factor
    self.disable_first_stage_autocast = disable_first_stage_autocast
    self.no_cond_log = no_cond_log
    self.device = args.device