in optimum/habana/diffusers/pipelines/flux/pipeline_flux.py [0:0]
def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, pooled_prompt_embeds, guidance):
    """
    Split the per-sample pipeline inputs into equally sized, stacked batches.

    Every input is split along dim 0 into chunks of `batch_size` samples. When the number of
    samples is not a multiple of `batch_size`, the last chunk of every input is padded with
    zero-valued samples so that all chunks have exactly `batch_size` samples and can be stacked
    into a single tensor.

    Args:
        batch_size (`int`): Number of samples per batch.
        latents (`torch.Tensor`): Latents to split along dim 0. The number of padding samples is
            derived from this tensor.
        prompt_embeds (`torch.Tensor`): Prompt embeddings to split along dim 0.
        pooled_prompt_embeds (`torch.Tensor`, *optional*): Pooled prompt embeddings to split
            along dim 0, or `None`.
        guidance (`torch.Tensor`, *optional*): Guidance values to split along dim 0, or `None`.

    Returns:
        `tuple`: `(latents_batches, prompt_embeds_batches, pooled_prompt_embeds_batches,
        guidance_batches, num_dummy_samples)`. Each `*_batches` entry is a tensor of shape
        `(num_batches, batch_size, *input.shape[1:])`, or `None` when the corresponding optional
        input is `None`. `num_dummy_samples` is the number of zero samples appended to the last
        batch (0 if no padding was needed).
    """

    def _split_and_pad(tensor, num_dummy_samples):
        # Split `tensor` along dim 0, zero-pad the last chunk with `num_dummy_samples` rows, and
        # stack the chunks. `None` passes through so optional inputs stay optional.
        if tensor is None:
            return None
        batches = list(torch.split(tensor, batch_size))
        if num_dummy_samples > 0:
            last = batches[-1]
            # new_zeros keeps the dtype/device of the input; concatenating along dim 0 works for
            # any rank (including 1-D guidance) without unsqueeze/squeeze tricks.
            padding = last.new_zeros((num_dummy_samples, *last.shape[1:]))
            batches[-1] = torch.cat((last, padding), dim=0)
        return torch.stack(batches)

    # The padding amount is derived from the latents; the other inputs are expected to have the
    # same number of samples along dim 0.
    remainder = latents.shape[0] % batch_size
    num_dummy_samples = batch_size - remainder if remainder else 0

    return (
        _split_and_pad(latents, num_dummy_samples),
        _split_and_pad(prompt_embeds, num_dummy_samples),
        # Previously this was stacked unconditionally, raising UnboundLocalError when None.
        _split_and_pad(pooled_prompt_embeds, num_dummy_samples),
        _split_and_pad(guidance, num_dummy_samples),
        num_dummy_samples,
    )