def _split_inputs_into_batches()

in optimum/habana/diffusers/pipelines/flux/pipeline_flux.py
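
This class method splits the full set of generation inputs into fixed-size batches: torch.split cuts each tensor into chunks of batch_size samples, and because the last chunk may be smaller than batch_size, it is padded with zero-valued dummy samples so every batch ends up with the same shape. A minimal, self-contained sketch of the uneven last chunk that motivates the padding (the sizes are made up for illustration):

    import torch

    # torch.split with an integer size returns chunks of that size; the last
    # chunk holds the remainder, which is why the method below pads it.
    x = torch.randn(5, 3)
    chunks = list(torch.split(x, 2))
    print([c.shape[0] for c in chunks])  # [2, 2, 1]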


    @classmethod
    def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_prompt_embeds, guidance):
        # Use torch.split to divide the inputs into batches of size batch_size (the last batch may be smaller)
        latents_batches = list(torch.split(latents, batch_size))
        prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size))
        if pooled_prompt_embeds is not None:
            pooled_prompt_embeds_batches = list(torch.split(pooled_prompt_embeds, batch_size))
        if guidance is not None:
            guidance_batches = list(torch.split(guidance, batch_size))

        # If the last batch has fewer samples than batch_size, pad it with dummy samples
        num_dummy_samples = 0
        if latents_batches[-1].shape[0] < batch_size:
            num_dummy_samples = batch_size - latents_batches[-1].shape[0]

            # Pad latents_batches
            sequence_to_stack = (latents_batches[-1],) + tuple(
                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            latents_batches[-1] = torch.vstack(sequence_to_stack)

            # Pad prompt_embeds_batches
            sequence_to_stack = (prompt_embeds_batches[-1],) + tuple(
                torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)

            # Pad pooled_prompt_embeds if necessary
            if pooled_prompt_embeds is not None:
                sequence_to_stack = (pooled_prompt_embeds_batches[-1],) + tuple(
                    torch.zeros_like(pooled_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                pooled_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack)

            # Pad guidance if necessary
            if guidance is not None:
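                # guidance is 1-D (one scalar per sample), so add a trailing
                # dimension before vstack-ing the zero rows, then squeeze it back below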
                guidance_batches[-1] = guidance_batches[-1].unsqueeze(1)
                sequence_to_stack = (guidance_batches[-1],) + tuple(
                    torch.zeros_like(guidance_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                guidance_batches[-1] = torch.vstack(sequence_to_stack).squeeze(1)

        # Stack batches in the same tensor
        latents_batches = torch.stack(latents_batches)
        prompt_embeds_batches = torch.stack(prompt_embeds_batches)
        pooled_prompt_embeds_batches = (
            torch.stack(pooled_prompt_embeds_batches) if pooled_prompt_embeds is not None else None
        )
        guidance_batches = torch.stack(guidance_batches) if guidance is not None else None

        return (
            latents_batches,
            prompt_embeds_batches,
            pooled_prompt_embeds_batches,
            guidance_batches,
            num_dummy_samples,
        )
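
As a usage illustration, the helper can be exercised in isolation. The sketch below is assumption-laden: the tensor shapes are invented stand-ins for Flux-style inputs, and the GaudiFluxPipeline import path is assumed to match the installed optimum-habana release.

    import torch

    # Assumed import path for the pipeline class that owns this helper.
    from optimum.habana.diffusers import GaudiFluxPipeline

    num_images = 5   # total number of samples to generate
    batch_size = 2   # per-step batch size

    # Illustrative shapes only; real Flux embeddings depend on the checkpoint.
    latents = torch.randn(num_images, 1024, 64)
    prompt_embeds = torch.randn(num_images, 512, 4096)
    pooled_prompt_embeds = torch.randn(num_images, 768)
    guidance = torch.full((num_images,), 3.5)

    (
        latents_batches,
        prompt_embeds_batches,
        pooled_prompt_embeds_batches,
        guidance_batches,
        num_dummy_samples,
    ) = GaudiFluxPipeline._split_inputs_into_batches(
        batch_size, latents, prompt_embeds, pooled_prompt_embeds, guidance
    )

    # 5 samples with batch_size=2 -> 3 batches; the last one gets 1 dummy sample.
    print(latents_batches.shape)   # torch.Size([3, 2, 1024, 64])
    print(guidance_batches.shape)  # torch.Size([3, 2])
    print(num_dummy_samples)       # 1

Downstream code would typically use num_dummy_samples to discard the padded outputs from the last batch once generation is done.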