in arctic_inference/vllm/ulysses.py [0:0]
def __init__(self, num_heads, *args, **kwargs):
    """Shard attention heads across the sequence-parallel group, then
    delegate to the original ``__init__`` (``self._orig_init``).

    Records the sequence-parallel (SP) world size and device group from
    ``parallel_state._SP``. Unless shift-parallel mode is active,
    ``num_heads`` and the keyword argument ``num_kv_heads`` are divided by
    the SP world size before being forwarded.

    When there are fewer KV heads than SP ranks, the KV heads cannot be
    split. In that case ``self.is_kv_replicated`` is set to True, each rank
    is given a single KV head, and the SP_AA / SP_AG sub-group handles,
    sizes and a reordering permutation (``self.order``) are stored on the
    instance.

    Args:
        num_heads: Number of query heads before SP sharding.
        *args: Forwarded unchanged to the original ``__init__``.
        **kwargs: Forwarded to the original ``__init__``. Outside
            shift-parallel mode it must contain ``num_kv_heads``, which is
            rewritten to the per-rank KV head count.

    Raises:
        KeyError: If, outside shift-parallel mode, ``num_kv_heads`` was not
            passed as a keyword argument.
        RuntimeError: If KV heads are replicated but the SP_AA or SP_AG
            parallel groups have not been initialized.
    """
    # Local import; presumably avoids an import cycle with model_runner.
    from .model_runner import is_shift_parallel_mode

    self.sp_size = parallel_state._SP.world_size
    self.sp_device_group = parallel_state._SP.device_group

    if not is_shift_parallel_mode():
        # NOTE(review): floor division silently truncates when num_heads is
        # not a multiple of sp_size; assumed to be validated upstream.
        num_heads //= self.sp_size
        num_kv_heads = kwargs["num_kv_heads"]

        # Fewer KV heads than SP ranks: they cannot be split evenly.
        self.is_kv_replicated = num_kv_heads < self.sp_size
        if self.is_kv_replicated:
            num_kv_heads = 1
            # Explicit check instead of `assert`, which is stripped under -O
            # and would leave a confusing AttributeError on None below.
            if parallel_state._SP_AA is None or parallel_state._SP_AG is None:
                raise RuntimeError(
                    "UlyssesAttentionPatch requires SP_AA and SP_AG groups "
                    "to be initialized.")
            self.sp_aa_device_group = parallel_state._SP_AA.device_group
            self.sp_ag_device_group = parallel_state._SP_AG.device_group
            self.sp_aa_size = parallel_state._SP_AA.world_size
            self.sp_ag_size = parallel_state._SP_AG.world_size
            # Permutation that reorders the all-gathered sequence: index
            # j * sp_aa_size + i, with i as the outer loop. For example,
            # sp_aa_size=2 and sp_ag_size=3 give [0, 2, 4, 1, 3, 5].
            self.order = [j * self.sp_aa_size + i
                          for i in range(self.sp_aa_size)
                          for j in range(self.sp_ag_size)]
        else:
            # NOTE(review): also floor division; assumes divisibility.
            num_kv_heads //= self.sp_size
        kwargs["num_kv_heads"] = num_kv_heads

    # __init__ must return None, so call the original without returning it.
    self._orig_init(num_heads, *args, **kwargs)