def reshape()

in optimum/intel/openvino/modeling_diffusion.py [0:0]


    def reshape(self, batch_size: int, height: int, width: int, num_images_per_prompt: int = -1, num_frames: int = -1):
        """Reshape every OpenVINO sub-model of the pipeline to the requested input shapes.

        A ``-1`` passed for a dimension is treated as dynamic: it sets ``self.is_dynamic``
        when it appears among the image-level dimensions.

        Args:
            batch_size: Prompt batch size, or ``-1``.
            height: Generated image height, or ``-1``.
            width: Generated image width, or ``-1``.
            num_images_per_prompt: Images produced per prompt, or ``-1``.
            num_frames: Frame count forwarded to the transformer and VAE reshape helpers (``-1`` by default).

        Returns:
            The pipeline itself, so calls can be chained.

        Raises:
            ValueError: If the pipeline was created in ``compile_only`` mode.
        """
        if self._compile_only:
            raise ValueError(
                "`reshape()` is not supported with `compile_only` mode, please initialize model without this option"
            )

        # Only the image-level dimensions decide the dynamic flag; num_frames is not considered here.
        self.is_dynamic = -1 in {batch_size, height, width, num_images_per_prompt}

        # Text sequence length used for the denoiser (unet / transformer) conditioning input.
        # The primary tokenizer wins when present; Gemma tokenizers keep the length dynamic.
        primary_tokenizer = self.tokenizer
        if primary_tokenizer is not None:
            if "Gemma" in primary_tokenizer.__class__.__name__:
                text_seq_len = -1
            else:
                text_seq_len = getattr(primary_tokenizer, "model_max_length", -1)
        elif self.tokenizer_2 is not None:
            text_seq_len = getattr(self.tokenizer_2, "model_max_length", -1)
        else:
            text_seq_len = -1

        # Denoiser: a pipeline may carry a unet, a transformer, or (in principle) both.
        if self.unet is not None:
            self.unet.model = self._reshape_unet(
                self.unet.model, batch_size, height, width, num_images_per_prompt, text_seq_len
            )
        if self.transformer is not None:
            self.transformer.model = self._reshape_transformer(
                self.transformer.model,
                batch_size,
                height,
                width,
                num_images_per_prompt,
                text_seq_len,
                num_frames=num_frames,
            )

        # The VAE decoder is reshaped unconditionally; the encoder only when it exists.
        self.vae_decoder.model = self._reshape_vae_decoder(
            self.vae_decoder.model, height, width, num_images_per_prompt, num_frames=num_frames
        )
        if self.vae_encoder is not None:
            self.vae_encoder.model = self._reshape_vae_encoder(
                self.vae_encoder.model, batch_size, height, width, num_frames=num_frames
            )

        if self.text_encoder is not None:
            # GemmaTokenizer uses inf as model_max_length, and the LTX text encoder does not pad its
            # input to model_max_length, so either case leaves the sequence dimension dynamic.
            keep_seq_dynamic = "Gemma" in self.tokenizer.__class__.__name__ or self.__class__.__name__.startswith(
                "OVLTX"
            )
            first_seq_len = -1 if keep_seq_dynamic else getattr(self.tokenizer, "model_max_length", -1)
            self.text_encoder.model = self._reshape_text_encoder(self.text_encoder.model, batch_size, first_seq_len)

        if self.text_encoder_2 is not None:
            second_seq_len = getattr(self.tokenizer_2, "model_max_length", -1)
            self.text_encoder_2.model = self._reshape_text_encoder(self.text_encoder_2.model, batch_size, second_seq_len)

        if self.text_encoder_3 is not None:
            # The third text encoder always keeps a dynamic sequence length.
            self.text_encoder_3.model = self._reshape_text_encoder(self.text_encoder_3.model, batch_size, -1)

        self.clear_requests()
        return self