in arctic_inference/vllm/ulysses.py [0:0]
def forward(self, query, key, value, **kwargs):
    """Run the wrapped attention forward with Ulysses sequence parallelism.

    When ``self.sp_size == 1`` or ``is_shift_parallel_mode()`` is truthy the
    call is delegated unchanged to ``self._orig_forward``. Otherwise query,
    key and value are exchanged across the sequence-parallel group before
    the wrapped attention runs, and the attention result is exchanged back
    afterwards.

    Args:
        query: Tensor viewable as ``(-1, sp_size, num_heads * head_size)``.
        key: Tensor viewable as ``(-1, sp_size, num_kv_heads * head_size)``,
            or with ``sp_aa_size`` in place of ``sp_size`` when
            ``self.is_kv_replicated`` is set.
        value: Tensor with the same layout as ``key``.
        **kwargs: Passed through untouched to ``self._orig_forward``.

    Returns:
        The wrapped forward's result directly on the bypass path; otherwise
        a 2-D tensor whose last dimension is
        ``sp_size * num_heads * head_size``.
    """
    # Imported at call time, as in the original code
    # (presumably to avoid an import cycle -- TODO confirm).
    from .model_runner import is_shift_parallel_mode

    if self.sp_size == 1 or is_shift_parallel_mode():
        return self._orig_forward(query, key, value, **kwargs)

    # Per-exchange feature widths of the query and of each of key / value.
    q_width = self.num_heads * self.head_size
    kv_width = self.num_kv_heads * self.head_size

    if not self.is_kv_replicated:
        # pack query, key and value side by side so one collective moves all
        fused = torch.cat(
            (
                query.view(-1, self.sp_size, q_width),
                key.view(-1, self.sp_size, kv_width),
                value.view(-1, self.sp_size, kv_width),
            ),
            dim=-1,
        )
        qkv_send = fused.transpose(0, 1).reshape(-1, q_width + 2 * kv_width)
        # Ulysses all-to-all 1/2
        qkv_recv = torch.empty_like(qkv_send)
        torch.distributed.all_to_all_single(
            qkv_recv, qkv_send, group=self.sp_device_group)
        # unpack
        q_, k_, v_ = qkv_recv.split([q_width, kv_width, kv_width], dim=-1)
    else:
        # Ulysses all-to-all 1/2 (query)
        q_send = query.view(-1, self.sp_size, q_width)
        q_send = q_send.transpose(0, 1).reshape(-1, q_width)
        q_ = torch.empty_like(q_send)
        torch.distributed.all_to_all_single(
            q_, q_send, group=self.sp_device_group)

        # Ulysses pack (key, value): split along sp_aa_size, not sp_size
        kv_pair = torch.cat(
            (
                key.view(-1, self.sp_aa_size, kv_width),
                value.view(-1, self.sp_aa_size, kv_width),
            ),
            dim=-1,
        )
        kv_send = kv_pair.transpose(0, 1).reshape(-1, 2 * kv_width)

        # Ulysses all-to-all (key, value)
        kv_recv = torch.empty_like(kv_send)
        torch.distributed.all_to_all_single(
            kv_recv, kv_send, group=self.sp_aa_device_group)

        # Ulysses all-gather (key, value); the gathered buffer has as many
        # rows as the exchanged query.
        # NOTE(review): dtype/device are taken from `query`, not `key` --
        # confirm they always match.
        kv_gathered = torch.empty(
            q_.shape[0],
            2 * kv_width,
            dtype=query.dtype,
            device=query.device,
        )
        torch.distributed.all_gather_into_tensor(
            kv_gathered, kv_recv, group=self.sp_ag_device_group)

        # reorder the sp_size row-chunks according to self.order
        shards = kv_gathered.chunk(self.sp_size)
        kv_ordered = torch.cat([shards[idx] for idx in self.order])

        # unpack (key, value)
        k_, v_ = kv_ordered.split([kv_width, kv_width], dim=-1)

    # original attention
    attn_send = self._orig_forward(q_, k_, v_, **kwargs)

    # Ulysses all-to-all 2/2
    attn_recv = torch.empty_like(attn_send)
    torch.distributed.all_to_all_single(
        attn_recv, attn_send, group=self.sp_device_group)

    # Fold the sp_size leading chunks into the feature dimension.
    stacked = attn_recv.view(self.sp_size, -1, q_width)
    return stacked.transpose(0, 1).reshape(-1, self.sp_size * q_width)