in optimum/intel/openvino/modeling_diffusion.py [0:0]
def reshape(self, batch_size: int, height: int, width: int, num_images_per_prompt: int = -1, num_frames: int = -1):
    """Reshape the pipeline submodels to static input shapes; a value of -1 keeps the corresponding dimension dynamic."""
    if self._compile_only:
        raise ValueError(
            "`reshape()` is not supported with `compile_only` mode, please initialize model without this option"
        )
    self.is_dynamic = -1 in {batch_size, height, width, num_images_per_prompt}
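    # Tokenized prompt length used when reshaping the denoiser inputs; -1 keeps the dimension dynamic
    # (GemmaTokenizer reports an infinite model_max_length, so it is also treated as dynamic).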
    if self.tokenizer is None and self.tokenizer_2 is None:
        tokenizer_max_len = -1
    else:
        if self.tokenizer is not None and "Gemma" in self.tokenizer.__class__.__name__:
            tokenizer_max_len = -1
        else:
            tokenizer_max_len = (
                getattr(self.tokenizer, "model_max_length", -1)
                if self.tokenizer is not None
                else getattr(self.tokenizer_2, "model_max_length", -1)
            )
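    # Reshape the denoiser: pipelines expose either a UNet or a diffusion transformer.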
    if self.unet is not None:
        self.unet.model = self._reshape_unet(
            self.unet.model, batch_size, height, width, num_images_per_prompt, tokenizer_max_len
        )
    if self.transformer is not None:
        self.transformer.model = self._reshape_transformer(
            self.transformer.model,
            batch_size,
            height,
            width,
            num_images_per_prompt,
            tokenizer_max_len,
            num_frames=num_frames,
        )
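    # The VAE decoder is always reshaped; the encoder only when the pipeline exposes one.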
    self.vae_decoder.model = self._reshape_vae_decoder(
        self.vae_decoder.model, height, width, num_images_per_prompt, num_frames=num_frames
    )
    if self.vae_encoder is not None:
        self.vae_encoder.model = self._reshape_vae_encoder(
            self.vae_encoder.model, batch_size, height, width, num_frames=num_frames
        )
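    # Reshape every text encoder that is part of the pipeline.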
    if self.text_encoder is not None:
        # GemmaTokenizer uses inf as model_max_length, and the text encoder in LTX does not pad inputs to model_max_length
        self.text_encoder.model = self._reshape_text_encoder(
            self.text_encoder.model,
            batch_size,
            (
                getattr(self.tokenizer, "model_max_length", -1)
                if "Gemma" not in self.tokenizer.__class__.__name__
                and not self.__class__.__name__.startswith("OVLTX")
                else -1
            ),
        )
    if self.text_encoder_2 is not None:
        self.text_encoder_2.model = self._reshape_text_encoder(
            self.text_encoder_2.model, batch_size, getattr(self.tokenizer_2, "model_max_length", -1)
        )
    if self.text_encoder_3 is not None:
        self.text_encoder_3.model = self._reshape_text_encoder(self.text_encoder_3.model, batch_size, -1)
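    # Drop any cached inference requests so the submodels are recompiled with the new static shapes.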
    self.clear_requests()
    return self
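# Usage sketch (illustrative, not part of the method above), assuming the usual optimum-intel
# workflow of loading the pipeline uncompiled, fixing static shapes, then compiling; the model id
# and shape values below are placeholders:
#
#     from optimum.intel import OVStableDiffusionPipeline
#
#     pipeline = OVStableDiffusionPipeline.from_pretrained("<model-id>", compile=False)
#     pipeline.reshape(batch_size=1, height=512, width=512, num_images_per_prompt=1)
#     pipeline.compile()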