def _split_inputs_into_batches()

in optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

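Splits the pipeline inputs along the batch dimension into batches of size batch_size, zero-padding the last batch when it is incomplete, and stacks the results into single tensors with a leading num_batches dimension. When unconditional embeddings are provided (classifier-free guidance), each text embeddings batch is concatenated with its unconditional counterpart so both prompts run in a single forward pass. The number of zero-filled dummy samples added by padding is returned alongside the batched tensors.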

    @classmethod
    def _split_inputs_into_batches(cls, batch_size, latents, text_embeddings, uncond_embeddings, image, noise_level):
        # Use torch.split to chunk each input along dim 0 into batches of size batch_size
        # (the last chunk may be smaller if the sample count is not a multiple of batch_size)
        latents_batches = list(torch.split(latents, batch_size))
        text_embeddings_batches = list(torch.split(text_embeddings, batch_size))
        image_batches = list(torch.split(image, batch_size))
        # View noise_level as a column vector so it can be padded and stacked like the other inputs
        noise_level_batches = list(torch.split(noise_level.view(-1, 1), batch_size))
        if uncond_embeddings is not None:
            uncond_embeddings_batches = list(torch.split(uncond_embeddings, batch_size))

        # If the last batch has fewer samples than batch_size, pad it with dummy samples
        num_dummy_samples = 0
        if latents_batches[-1].shape[0] < batch_size:
            num_dummy_samples = batch_size - latents_batches[-1].shape[0]
            # Pad latents_batches: each dummy sample is a zero tensor shaped like a
            # single sample with the batch dimension kept ([0][None, :])
            sequence_to_stack = (latents_batches[-1],) + tuple(
                torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            latents_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad image_batches
            sequence_to_stack = (image_batches[-1],) + tuple(
                torch.zeros_like(image_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            image_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad noise_level_batches
            sequence_to_stack = (noise_level_batches[-1],) + tuple(
                torch.zeros_like(noise_level_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            noise_level_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad text_embeddings_batches
            sequence_to_stack = (text_embeddings_batches[-1],) + tuple(
                torch.zeros_like(text_embeddings_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
            )
            text_embeddings_batches[-1] = torch.vstack(sequence_to_stack)
            # Pad uncond_embeddings_batches if necessary
            if uncond_embeddings is not None:
                sequence_to_stack = (uncond_embeddings_batches[-1],) + tuple(
                    torch.zeros_like(uncond_embeddings_batches[-1][0][None, :]) for _ in range(num_dummy_samples)
                )
                uncond_embeddings_batches[-1] = torch.vstack(sequence_to_stack)

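        # Every batch now holds exactly batch_size samples; num_dummy_samples records
        # how many zero-filled dummies were appended so the caller can discard them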
        # Stack batches in the same tensor
        latents_batches = torch.stack(latents_batches)
        if uncond_embeddings is not None:
            # For classifier-free guidance, two forward passes would normally be needed.
            # Concatenate the unconditional and text embeddings into a single batch
            # so that one forward pass covers both.
            for i, (uncond_embeddings_batch, text_embeddings_batch) in enumerate(
                zip(uncond_embeddings_batches, text_embeddings_batches[:])
            ):
                text_embeddings_batches[i] = torch.cat([uncond_embeddings_batch, text_embeddings_batch])
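            # Each text embeddings batch now has 2 * batch_size rows: the
            # unconditional embeddings first, then the prompt embeddings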
        text_embeddings_batches = torch.stack(text_embeddings_batches)
        image_batches = torch.stack(image_batches)
        noise_level_batches = torch.stack(noise_level_batches).squeeze(-1)  # (num_batches, batch_size)

        return latents_batches, text_embeddings_batches, image_batches, noise_level_batches, num_dummy_samples
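
For reference, here is a minimal sketch of how the splitting behaves. It assumes the method is exposed as a classmethod on GaudiStableDiffusionUpscalePipeline imported from optimum.habana.diffusers (the cls parameter suggests this, but the class name and import path are assumptions), and it uses hypothetical tensor shapes: five prompts with batch_size=2 leave one zero-filled dummy sample in the last batch.

    import torch

    from optimum.habana.diffusers import GaudiStableDiffusionUpscalePipeline

    # Hypothetical shapes: 5 samples split into batches of 2, so the last batch
    # is padded with one all-zero dummy sample (num_dummy_samples == 1)
    latents = torch.randn(5, 4, 32, 32)
    text_embeddings = torch.randn(5, 77, 1024)
    uncond_embeddings = torch.randn(5, 77, 1024)
    image = torch.randn(5, 3, 128, 128)
    noise_level = torch.full((5,), 20, dtype=torch.long)

    (
        latents_batches,
        text_embeddings_batches,
        image_batches,
        noise_level_batches,
        num_dummy_samples,
    ) = GaudiStableDiffusionUpscalePipeline._split_inputs_into_batches(
        2, latents, text_embeddings, uncond_embeddings, image, noise_level
    )

    print(latents_batches.shape)          # torch.Size([3, 2, 4, 32, 32])
    print(text_embeddings_batches.shape)  # torch.Size([3, 4, 77, 1024]) -- 2 * batch_size rows per batch (CFG)
    print(image_batches.shape)            # torch.Size([3, 2, 3, 128, 128])
    print(noise_level_batches.shape)      # torch.Size([3, 2])
    print(num_dummy_samples)              # 1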