def __call__()

in optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py [0:0]


    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
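        """Attention processor call that slices the attention computation along the
        query sequence so each attention matrix stays within the memory budget given
        by ``self._attn_matrix_target_mem_mb`` (IPU modification)."""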
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Begin IPU modifications.
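        # Estimate the size in bytes of the full attention matrix
        # (batch * heads, query_len, key_len) and derive how many query slices are
        # needed so that each slice's matrix fits within the target memory budget
        # (self._attn_matrix_target_mem_mb). _nearest_divisor rounds the slice count
        # to a divisor of the query length so every slice has the same size.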
        attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
        num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
        num_slices = max(num_slices, 1)
        num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
        slice_size = query.shape[1] // num_slices

        hidden_states = []

        key = key.transpose(1, 2)
        for i in range(num_slices):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size

            attn_slice = torch.matmul(query[:, start_idx:end_idx], key) * attn.scale
            if attention_mask is not None:
                attn_slice = attn_slice + attention_mask[:, start_idx:end_idx]
            attn_slice = attn_slice.softmax(dim=-1)
            attn_slice = torch.matmul(attn_slice, value)

            hidden_states.append(attn_slice)

        hidden_states = torch.cat(hidden_states, dim=1)
        # End IPU modifications.

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
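
The helper `self._nearest_divisor` is used above but not shown in this extract. Below is a minimal sketch of what it could look like, assuming it is a static helper on the same class that returns the smallest divisor of `target` in the range `[start, end]` (so the query length divides evenly into `num_slices` equal slices and `torch.cat` reassembles the full sequence exactly); the actual implementation in the repository may differ.

def _nearest_divisor(target: int, start: int, end: int) -> int:
    # Return the smallest value in [start, end] that divides `target` exactly.
    for candidate in range(start, end + 1):
        if target % candidate == 0:
            return candidate
    raise ValueError(f"No divisor of {target} found in [{start}, {end}].")

As an illustrative worked example (numbers are not from the source): with fp16 activations and a head-folded query of shape (16, 4096, 64) attending to 4096 keys, attn_matrix_mem = 2 * 16 * 4096 * 4096 = 536,870,912 bytes (512 MiB). With _attn_matrix_target_mem_mb = 100, num_slices = 536,870,912 // (100 * 1024 * 1024) = 5, which _nearest_divisor rounds to 8 (the smallest divisor of 4096 in [5, 10]), so each of the 8 slices covers 512 query positions and its attention matrix occupies 64 MiB.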