# threestudio/models/guidance/stable_diffusion_bsd_guidance.py
def configure(self) -> None:
    """Load and wire up all Stable Diffusion components for BSD guidance.

    Side effects (all stored on ``self``):
      - ``submodules``: the frozen ``pipe`` / ``pipe_lora`` pipelines plus
        ``pipe_fix`` (an alias of ``pipe`` — same object).
      - ``train_unet`` / ``train_unet_lora``: fully-trainable UNet copies.
      - ``scheduler`` / ``scheduler_lora`` (DDPM, for training noise) and
        ``scheduler_sample`` / ``scheduler_lora_sample`` (DPM-Solver, for
        fast sampling).
      - bookkeeping: ``weights_dtype``, ``num_train_timesteps``, ``alphas``,
        ``grad_clip_val``, ``cache_frames``, and (if ``cfg.use_du``)
        ``perceptual_loss``.
    """
    threestudio.info("Loading Stable Diffusion ...")

    self.weights_dtype = (
        torch.float16 if self.cfg.half_precision_weights else torch.float32
    )

    # Both pipelines load with identical options; the tokenizer and safety
    # checker are dropped because guidance never uses them (saves memory).
    common_pipe_kwargs = {
        "tokenizer": None,
        "safety_checker": None,
        "feature_extractor": None,
        "requires_safety_checker": False,
        "torch_dtype": self.weights_dtype,
        "cache_dir": self.cfg.cache_dir,
        "local_files_only": self.cfg.local_files_only,
    }
    pipe_kwargs = dict(common_pipe_kwargs)
    pipe_lora_kwargs = dict(common_pipe_kwargs)

    @dataclass
    class SubModules:
        # Plain container so the pipelines are not registered as
        # nn.Module children (keeps them out of state_dict / .parameters()).
        pipe: StableDiffusionPipeline
        pipe_lora: StableDiffusionPipeline
        pipe_fix: StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        self.cfg.pretrained_model_name_or_path,
        **pipe_kwargs,
    ).to(self.device)
    self.single_model = False
    pipe_lora = StableDiffusionPipeline.from_pretrained(
        self.cfg.pretrained_model_name_or_path_lora,
        **pipe_lora_kwargs,
    ).to(self.device)

    # Share one VAE between the two pipelines: drop the LoRA pipeline's
    # copy, reclaim its memory, then point it at the base pipeline's VAE.
    del pipe_lora.vae
    cleanup()
    pipe_lora.vae = pipe.vae

    # pipe_fix aliases pipe (same object); kept as a named field so the
    # rest of the class can refer to a "fixed" pipeline uniformly.
    pipe_fix = pipe
    self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora, pipe_fix=pipe_fix)

    if self.cfg.enable_memory_efficient_attention:
        if parse_version(torch.__version__) >= parse_version("2"):
            threestudio.info(
                "PyTorch2.0 uses memory efficient attention by default."
            )
        elif not is_xformers_available():
            threestudio.warn(
                "xformers is not available, memory efficient attention is not enabled."
            )
        else:
            self.pipe.enable_xformers_memory_efficient_attention()
            self.pipe_lora.enable_xformers_memory_efficient_attention()

    if self.cfg.enable_sequential_cpu_offload:
        self.pipe.enable_sequential_cpu_offload()
        self.pipe_lora.enable_sequential_cpu_offload()
    if self.cfg.enable_attention_slicing:
        self.pipe.enable_attention_slicing(1)
        self.pipe_lora.enable_attention_slicing(1)
    if self.cfg.enable_channels_last_format:
        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe_lora.unet.to(memory_format=torch.channels_last)

    # Text encoders are not needed after prompt embeddings are computed
    # elsewhere; free them to reduce GPU memory.
    del self.pipe.text_encoder
    if not self.single_model:
        del self.pipe_lora.text_encoder
    cleanup()

    # Freeze the guidance-only (non-trained) modules.
    for p in self.vae.parameters():
        p.requires_grad_(False)
    for p in self.vae_fix.parameters():
        p.requires_grad_(False)
    for p in self.unet_fix.parameters():
        p.requires_grad_(False)

    # FIXME: hard-coded dims — 16 camera-condition features projected to a
    # 1280-wide embedding (presumably the SD UNet time-embed width; verify
    # against the loaded checkpoint).
    self.camera_embedding = ToWeightsDType(
        TimestepEmbedding(16, 1280), self.weights_dtype
    ).to(self.device)

    # Trainable UNet copies (full fine-tuning rather than LoRA adapters).
    self.train_unet = UNet2DConditionModel.from_pretrained(
        self.cfg.pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=self.weights_dtype,
    )
    self.train_unet_lora = UNet2DConditionModel.from_pretrained(
        self.cfg.pretrained_model_name_or_path_lora,
        subfolder="unet",
        torch_dtype=self.weights_dtype,
    )
    for unet in (self.train_unet, self.train_unet_lora):
        # BUGFIX: was called unconditionally, which raises when xformers is
        # not installed; guard on availability (matches the pipeline logic).
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        unet.enable_gradient_checkpointing()
        for p in unet.parameters():
            p.requires_grad_(True)

    # DDPM schedulers for training-time noise addition.
    self.scheduler = DDPMScheduler.from_pretrained(
        self.cfg.pretrained_model_name_or_path,
        subfolder="scheduler",
        torch_dtype=self.weights_dtype,
        cache_dir=self.cfg.cache_dir,
        local_files_only=self.cfg.local_files_only,
    )
    self.scheduler_lora = DDPMScheduler.from_pretrained(
        self.cfg.pretrained_model_name_or_path_lora,
        subfolder="scheduler",
        torch_dtype=self.weights_dtype,
        cache_dir=self.cfg.cache_dir,
        local_files_only=self.cfg.local_files_only,
    )
    # Fast multistep solvers for sampling images during distillation.
    self.scheduler_sample = DPMSolverMultistepScheduler.from_config(
        self.pipe.scheduler.config
    )
    self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config(
        self.pipe_lora.scheduler.config
    )

    self.pipe.scheduler = self.scheduler
    self.pipe_lora.scheduler = self.scheduler_lora
    self.pipe_fix.scheduler = self.scheduler  # pipe_fix aliases pipe

    self.num_train_timesteps = self.scheduler.config.num_train_timesteps
    self.set_min_max_steps()  # set to default value
    self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
        self.device
    )
    # Keep the scheduler's own table on-device too, so later indexing by
    # timestep tensors does not trigger host/device mismatches.
    self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

    self.grad_clip_val: Optional[float] = None

    if self.cfg.use_du:
        # Perceptual loss for denoised-update supervision; frozen.
        self.perceptual_loss = PerceptualLoss().eval().to(self.device)
        for p in self.perceptual_loss.parameters():
            p.requires_grad_(False)
    self.cache_frames = []

    threestudio.info("Loaded Stable Diffusion!")