optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py [178:297]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    @classmethod
    def _split_inputs_into_batches(
        cls,
        batch_size,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
        add_text_embeds,
        negative_pooled_prompt_embeds,
        add_time_ids,
        negative_add_time_ids,
    ):
        # Use torch.split to generate num_batches batches of size batch_size
        latents_batches = list(torch.split(latents, batch_size))
        prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
        if negative_prompt_embeds is not None:
            negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
        if add_text_embeds is not None:
            add_text_embeds_batches = list(torch.split(add_text_embeds, batch_size))
        if negative_pooled_prompt_embeds is not None:
            negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size))
        if add_time_ids is not None:
            add_time_ids_batches = list(torch.split(add_time_ids, batch_size))
        if negative_add_time_ids is not None:
            negative_add_time_ids_batches = list(torch.split(negative_add_time_ids, batch_size))

        # If the last batch has fewer samples than batch_size, pad it with dummy samples
        num_dummy_samples = 0
        if latents_batches[-1].shape[0] < batch_size:
            num_dummy_samples = batch_size - latents_batches[-1].shape[0]
            # Pad latents_batches
            sequence_to_stack = (latents_batches[-1],) + tuple(
                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            latents_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad prompt_embeds_batches
            sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
                torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_prompt_embeds_batches if necessary
            if negative_prompt_embeds is not None:
                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad add_text_embeds_batches if necessary
            if add_text_embeds is not None:
                sequence_to_stack = (add_text_embeds_batches[-1],) + tuple(
                    torch.zeros_like(add_text_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                add_text_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_pooled_prompt_embeds_batches if necessary
            if negative_pooled_prompt_embeds is not None:
                sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :])
                    for _ in range(num_dummy_samples)
                )
                negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad add_time_ids_batches if necessary
            if add_time_ids is not None:
                sequence_to_stack = (add_time_ids_batches[-1],) + tuple(
                    torch.zeros_like(add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_add_time_ids_batches if necessary
            if negative_add_time_ids is not None:
                sequence_to_stack = (negative_add_time_ids_batches[-1],) + tuple(
                    torch.zeros_like(negative_add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                negative_add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)

        # Stack batches in the same tensor
        latents_batches = torch.stack(latents_batches)

        if negative_prompt_embeds is not None:
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
                zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
            ):
                prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
        prompt_embeds_batches = torch.stack(prompt_embeds_batches)

        if add_text_embeds is not None:
            if negative_pooled_prompt_embeds is not None:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and pooled text embeddings into a single batch
                # to avoid doing two forward passes
                for i, (negative_pooled_prompt_embeds_batch, add_text_embeds_batch) in enumerate(
                    zip(negative_pooled_prompt_embeds_batches, add_text_embeds_batches[:])
                ):
                    add_text_embeds_batches[i] = torch.cat(
                        [negative_pooled_prompt_embeds_batch, add_text_embeds_batch]
                    )
            add_text_embeds_batches = torch.stack(add_text_embeds_batches)
        else:
            add_text_embeds_batches = None

        if add_time_ids is not None:
            if negative_add_time_ids is not None:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and conditional additional time IDs into a single batch
                # to avoid doing two forward passes
                for i, (negative_add_time_ids_batch, add_time_ids_batch) in enumerate(
                    zip(negative_add_time_ids_batches, add_time_ids_batches[:])
                ):
                    add_time_ids_batches[i] = torch.cat([negative_add_time_ids_batch, add_time_ids_batch])
            add_time_ids_batches = torch.stack(add_time_ids_batches)
        else:
            add_time_ids_batches = None

        return latents_batches, prompt_embeds_batches, add_text_embeds_batches, add_time_ids_batches, num_dummy_samples

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
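
For reference, a minimal usage sketch of this helper on toy tensors. It assumes three prompts with classifier-free guidance and a requested batch size of two; the latent and embedding shapes are illustrative only, and the enclosing class and import path are assumed to be GaudiStableDiffusionXLPipeline from optimum.habana.diffusers (the pipeline defined in this file):

import torch

from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline  # assumed import path

num_prompts, batch_size = 3, 2
latents = torch.randn(num_prompts, 4, 128, 128)                # illustrative latent shape
prompt_embeds = torch.randn(num_prompts, 77, 2048)             # illustrative SDXL text-embedding shape
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
add_text_embeds = torch.randn(num_prompts, 1280)               # pooled text embeddings
negative_pooled_prompt_embeds = torch.zeros_like(add_text_embeds)
add_time_ids = torch.randn(num_prompts, 6)
negative_add_time_ids = torch.zeros_like(add_time_ids)

(
    latents_batches,
    prompt_embeds_batches,
    add_text_embeds_batches,
    add_time_ids_batches,
    num_dummy_samples,
) = GaudiStableDiffusionXLPipeline._split_inputs_into_batches(
    batch_size,
    latents,
    prompt_embeds,
    negative_prompt_embeds,
    add_text_embeds,
    negative_pooled_prompt_embeds,
    add_time_ids,
    negative_add_time_ids,
)

# 3 prompts with batch_size=2 -> 2 batches; the last batch is padded with 1 dummy sample
assert num_dummy_samples == 1
assert latents_batches.shape == (2, 2, 4, 128, 128)
# Negative and positive embeddings are concatenated along the batch dimension for CFG
assert prompt_embeds_batches.shape == (2, 4, 77, 2048)
assert add_text_embeds_batches.shape == (2, 4, 1280)
assert add_time_ids_batches.shape == (2, 4, 6)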



optimum/habana/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py [151:270]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    @classmethod
    def _split_inputs_into_batches(
        cls,
        batch_size,
        latents,
        prompt_embeds,
        negative_prompt_embeds,
        add_text_embeds,
        negative_pooled_prompt_embeds,
        add_time_ids,
        negative_add_time_ids,
    ):
        # Use torch.split to generate num_batches batches of size batch_size
        latents_batches = list(torch.split(latents, batch_size))
        prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
        if negative_prompt_embeds is not None:
            negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size))
        if add_text_embeds is not None:
            add_text_embeds_batches = list(torch.split(add_text_embeds, batch_size))
        if negative_pooled_prompt_embeds is not None:
            negative_pooled_prompt_embeds_batches = list(torch.split(negative_pooled_prompt_embeds, batch_size))
        if add_time_ids is not None:
            add_time_ids_batches = list(torch.split(add_time_ids, batch_size))
        if negative_add_time_ids is not None:
            negative_add_time_ids_batches = list(torch.split(negative_add_time_ids, batch_size))

        # If the last batch has fewer samples than batch_size, pad it with dummy samples
        num_dummy_samples = 0
        if latents_batches[-1].shape[0] < batch_size:
            num_dummy_samples = batch_size - latents_batches[-1].shape[0]
            # Pad latents_batches
            sequence_to_stack = (latents_batches[-1],) + tuple(
                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            latents_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad prompt_embeds_batches
            sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
                torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_prompt_embeds_batches if necessary
            if negative_prompt_embeds is not None:
                sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad add_text_embeds_batches if necessary
            if add_text_embeds is not None:
                sequence_to_stack = (add_text_embeds_batches[-1],) + tuple(
                    torch.zeros_like(add_text_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                add_text_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_pooled_prompt_embeds_batches if necessary
            if negative_pooled_prompt_embeds is not None:
                sequence_to_stack = (negative_pooled_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(negative_pooled_prompt_embeds_batches[-1][0][None, :])
                    for _ in range(num_dummy_samples)
                )
                negative_pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad add_time_ids_batches if necessary
            if add_time_ids is not None:
                sequence_to_stack = (add_time_ids_batches[-1],) + tuple(
                    torch.zeros_like(add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad negative_add_time_ids_batches if necessary
            if negative_add_time_ids is not None:
                sequence_to_stack = (negative_add_time_ids_batches[-1],) + tuple(
                    torch.zeros_like(negative_add_time_ids_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                negative_add_time_ids_batches[-1] = torch.vstack(sequence_to_stack)

        # Stack batches in the same tensor
        latents_batches = torch.stack(latents_batches)

        if negative_prompt_embeds is not None:
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate(
                zip(negative_prompt_embeds_batches, prompt_embeds_batches[:])
            ):
                prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch])
        prompt_embeds_batches = torch.stack(prompt_embeds_batches)

        if add_text_embeds is not None:
            if negative_pooled_prompt_embeds is not None:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and pooled text embeddings into a single batch
                # to avoid doing two forward passes
                for i, (negative_pooled_prompt_embeds_batch, add_text_embeds_batch) in enumerate(
                    zip(negative_pooled_prompt_embeds_batches, add_text_embeds_batches[:])
                ):
                    add_text_embeds_batches[i] = torch.cat(
                        [negative_pooled_prompt_embeds_batch, add_text_embeds_batch]
                    )
            add_text_embeds_batches = torch.stack(add_text_embeds_batches)
        else:
            add_text_embeds_batches = None

        if add_time_ids is not None:
            if negative_add_time_ids is not None:
                # For classifier free guidance, we need to do two forward passes.
                # Here we concatenate the unconditional and conditional additional time IDs into a single batch
                # to avoid doing two forward passes
                for i, (negative_add_time_ids_batch, add_time_ids_batch) in enumerate(
                    zip(negative_add_time_ids_batches, add_time_ids_batches[:])
                ):
                    add_time_ids_batches[i] = torch.cat([negative_add_time_ids_batch, add_time_ids_batch])
            add_time_ids_batches = torch.stack(add_time_ids_batches)
        else:
            add_time_ids_batches = None

        return latents_batches, prompt_embeds_batches, add_text_embeds_batches, add_time_ids_batches, num_dummy_samples

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
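
The img2img pipeline above duplicates the same helper, so the padding arithmetic is identical in both files. A small self-contained sketch of that arithmetic (batching_stats is a hypothetical helper written for illustration, not part of either pipeline):

import math

def batching_stats(num_prompts: int, batch_size: int):
    # Number of batches produced by torch.split, and the number of dummy
    # samples appended to the last batch so every batch holds batch_size samples.
    num_batches = math.ceil(num_prompts / batch_size)
    num_dummy_samples = num_batches * batch_size - num_prompts
    return num_batches, num_dummy_samples

assert batching_stats(5, 2) == (3, 1)   # last batch padded with one dummy sample
assert batching_stats(4, 2) == (2, 0)   # exact multiple: no padding needed
assert batching_stats(1, 4) == (1, 3)   # a single prompt is padded up to the batch size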
