def forward()

in arctic_inference/vllm/ulysses.py [0:0]


    def forward(self, query, key, value, **kwargs):
        """Run the wrapped attention forward around Ulysses-style exchanges.

        When ``self.sp_size == 1`` or ``is_shift_parallel_mode()`` is true the
        call is forwarded unchanged to ``self._orig_forward``.  Otherwise the
        inputs are re-laid-out, exchanged over the sequence-parallel process
        groups, passed to ``self._orig_forward``, and the result is exchanged
        back and re-laid-out before being returned.

        Args:
            query: tensor viewable as ``(-1, sp_size, num_heads * head_size)``.
            key: tensor viewable as ``(-1, sp_size, num_kv_heads * head_size)``
                (``sp_aa_size`` in place of ``sp_size`` when
                ``self.is_kv_replicated`` is set).
            value: same layout as ``key``.
            **kwargs: passed through untouched to ``self._orig_forward``.

        Returns:
            A 2-D tensor with ``num_heads * sp_size * head_size`` columns, or
            whatever ``self._orig_forward`` returns on the bypass path.

        NOTE(review): ``num_heads`` / ``num_kv_heads`` look like per-rank head
        counts and the exchanges presumably trade sequence sharding for head
        sharding -- confirm against the class definition.
        """
        # Function-scope import; presumably avoids an import cycle with
        # model_runner -- TODO confirm.
        from .model_runner import is_shift_parallel_mode

        # Bypass: nothing to exchange.
        if self.sp_size == 1 or is_shift_parallel_mode():
            return self._orig_forward(query, key, value, **kwargs)

        # Row widths (in elements) of one query row and one key/value row.
        q_width = self.num_heads * self.head_size
        kv_width = self.num_kv_heads * self.head_size

        if self.is_kv_replicated:
            # Ulysses all-to-all 1/2 (query only): move the sp_size axis to
            # the front so each contiguous row block is one peer's slice.
            q_send = (query.view(-1, self.sp_size, q_width)
                      .transpose(0, 1)
                      .reshape(-1, q_width))
            attn_q = torch.empty_like(q_send)
            torch.distributed.all_to_all_single(
                attn_q, q_send, group=self.sp_device_group)

            # Pack key and value side by side so they share one exchange.
            packed_kv = torch.cat(
                (key.view(-1, self.sp_aa_size, kv_width),
                 value.view(-1, self.sp_aa_size, kv_width)),
                dim=-1)
            kv_send = packed_kv.transpose(0, 1).reshape(-1, 2 * kv_width)

            # All-to-all over the (smaller) sp_aa group ...
            kv_recv = torch.empty_like(kv_send)
            torch.distributed.all_to_all_single(
                kv_recv, kv_send, group=self.sp_aa_device_group)

            # ... then all-gather over the sp_ag group into a buffer with as
            # many rows as the exchanged query.
            kv_gathered = torch.empty(attn_q.shape[0],
                                      2 * kv_width,
                                      dtype=query.dtype,
                                      device=query.device)
            torch.distributed.all_gather_into_tensor(
                kv_gathered, kv_recv, group=self.sp_ag_device_group)

            # Reorder the sp_size row chunks according to self.order
            # (defined outside this method; presumably aligns the gathered
            # chunks with the query's row order -- TODO confirm).
            kv_chunks = kv_gathered.chunk(self.sp_size)
            kv_ordered = torch.cat([kv_chunks[idx] for idx in self.order])

            # Unpack the side-by-side key/value columns.
            attn_k, attn_v = kv_ordered.split([kv_width, kv_width], dim=-1)
        else:
            # Pack query, key and value side by side so a single all-to-all
            # moves all three.
            packed_qkv = torch.cat(
                (query.view(-1, self.sp_size, q_width),
                 key.view(-1, self.sp_size, kv_width),
                 value.view(-1, self.sp_size, kv_width)),
                dim=-1)
            qkv_send = (packed_qkv.transpose(0, 1)
                        .reshape(-1, q_width + 2 * kv_width))

            # Ulysses all-to-all 1/2
            qkv_recv = torch.empty_like(qkv_send)
            torch.distributed.all_to_all_single(
                qkv_recv, qkv_send, group=self.sp_device_group)

            # Unpack back into separate query / key / value column groups.
            attn_q, attn_k, attn_v = qkv_recv.split(
                [q_width, kv_width, kv_width], dim=-1)

        # Original (wrapped) attention on the exchanged tensors.
        attn_out = self._orig_forward(attn_q, attn_k, attn_v, **kwargs)

        # Ulysses all-to-all 2/2: send the attention output back.
        exchanged = torch.empty_like(attn_out)
        torch.distributed.all_to_all_single(
            exchanged, attn_out, group=self.sp_device_group)

        # Fold the leading sp_size axis back into the column dimension.
        return (exchanged.view(self.sp_size, -1, q_width)
                .transpose(0, 1)
                .reshape(-1, self.sp_size * q_width))