in optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py [0:0]
def _split_inputs_into_batches(cls, batch_size, latents, text_embeddings, uncond_embeddings, image, noise_level):
    """Split the inputs into batches of ``batch_size`` samples and stack them.

    Every input is cut along its first dimension with ``torch.split``. When the
    last chunk of ``latents`` holds fewer than ``batch_size`` samples, the last
    chunk of every input is padded with all-zero dummy samples so that all
    chunks can be stacked into one tensor per input.

    NOTE(review): the amount of padding is derived from ``latents`` only, so
    all inputs are assumed to hold the same number of samples -- confirm
    against the caller.

    Args:
        cls: Not used by this function.
        batch_size: Number of samples per batch.
        latents: Tensor whose first dimension indexes samples.
        text_embeddings: Tensor whose first dimension indexes samples.
        uncond_embeddings: Optional tensor whose first dimension indexes
            samples. When it is not ``None``, each of its chunks is
            concatenated in front of the matching ``text_embeddings`` chunk
            (classifier-free guidance in a single forward pass).
        image: Tensor whose first dimension indexes samples.
        noise_level: Tensor that is reshaped to a single column
            (``view(-1, 1)``) before splitting; presumably one noise level
            per sample.

    Returns:
        A tuple ``(latents_batches, text_embeddings_batches, image_batches,
        noise_level_batches, num_dummy_samples)``. The first four items are
        tensors with a new leading "batch index" dimension; the trailing
        size-1 column of the noise levels is squeezed away again.
        ``num_dummy_samples`` is the number of zero samples appended to the
        last batch (``0`` when no padding was needed).
    """

    def split_into_batches(tensor):
        # torch.split returns a tuple; a list is needed so that the last
        # chunk can be replaced by its padded version.
        return list(torch.split(tensor, batch_size))

    def pad_last_batch(batches, num_missing):
        # Append `num_missing` all-zero samples to the last chunk. The filler
        # is allocated in one go with the chunk's dtype/device and per-sample
        # shape, so no existing sample has to be indexed to build it.
        last = batches[-1]
        filler = last.new_zeros((num_missing, *last.shape[1:]))
        batches[-1] = torch.cat([last, filler])

    latents_batches = split_into_batches(latents)
    text_embeddings_batches = split_into_batches(text_embeddings)
    image_batches = split_into_batches(image)
    # One column per sample so that noise levels split and pad like the others
    noise_level_batches = split_into_batches(noise_level.view(-1, 1))
    uncond_embeddings_batches = None
    if uncond_embeddings is not None:
        uncond_embeddings_batches = split_into_batches(uncond_embeddings)

    # If the last batch has less samples than batch_size, pad it with dummy samples
    num_dummy_samples = 0
    num_samples_in_last_batch = latents_batches[-1].shape[0]
    if num_samples_in_last_batch < batch_size:
        num_dummy_samples = batch_size - num_samples_in_last_batch
        batches_to_pad = [latents_batches, image_batches, noise_level_batches, text_embeddings_batches]
        if uncond_embeddings_batches is not None:
            batches_to_pad.append(uncond_embeddings_batches)
        for batches in batches_to_pad:
            pad_last_batch(batches, num_dummy_samples)

    if uncond_embeddings_batches is not None:
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        for i, (uncond_embeddings_batch, text_embeddings_batch) in enumerate(
            zip(uncond_embeddings_batches, text_embeddings_batches)
        ):
            text_embeddings_batches[i] = torch.cat([uncond_embeddings_batch, text_embeddings_batch])

    # Stack batches in the same tensor
    latents_batches = torch.stack(latents_batches)
    text_embeddings_batches = torch.stack(text_embeddings_batches)
    image_batches = torch.stack(image_batches)
    # Drop the helper column that was added before splitting
    noise_level_batches = torch.stack(noise_level_batches).squeeze(-1)

    return latents_batches, text_embeddings_batches, image_batches, noise_level_batches, num_dummy_samples