def __init__()

in arctic_inference/vllm/ulysses.py [0:0]


    def __init__(self, num_heads, *args, **kwargs):
        """Adjust per-rank head counts, then run the original initializer.

        The sequence-parallel (SP) group's world size and device group are
        always stored on the instance as ``sp_size`` and ``sp_device_group``.

        In shift-parallel mode the arguments are forwarded unchanged; note
        that ``is_kv_replicated`` and the SP_AA / SP_AG attributes are *not*
        set on that path.

        Otherwise ``num_heads`` is floor-divided by ``sp_size`` and the
        ``num_kv_heads`` keyword is rewritten:

        * fewer KV heads than ``sp_size``: ``is_kv_replicated`` is True,
          ``num_kv_heads`` becomes 1, and the SP_AA / SP_AG groups (which
          must already be initialized) are recorded together with the
          ``order`` index list;
        * otherwise: ``is_kv_replicated`` is False and ``num_kv_heads`` is
          floor-divided by ``sp_size``.

        Args:
            num_heads: Head count before any per-rank adjustment.
            *args: Forwarded untouched to the original initializer.
            **kwargs: Forwarded to the original initializer. Outside
                shift-parallel mode it must contain ``num_kv_heads``.

        Raises:
            KeyError: ``num_kv_heads`` was not supplied as a keyword
                argument (only outside shift-parallel mode).
            AssertionError: KV replication is needed but the SP_AA or SP_AG
                group is None.
        """
        # NOTE(review): function-scope import, presumably to avoid a circular
        # import with model_runner -- confirm before hoisting to module level.
        from .model_runner import is_shift_parallel_mode

        sp_group = parallel_state._SP
        self.sp_size = sp_group.world_size
        self.sp_device_group = sp_group.device_group

        # Shift-parallel mode: head counts are passed through as given.
        if is_shift_parallel_mode():
            return self._orig_init(num_heads, *args, **kwargs)

        heads_per_rank = num_heads // self.sp_size
        kv_heads = kwargs["num_kv_heads"]
        self.is_kv_replicated = kv_heads < self.sp_size

        if not self.is_kv_replicated:
            kwargs["num_kv_heads"] = kv_heads // self.sp_size
            return self._orig_init(heads_per_rank, *args, **kwargs)

        aa_group = parallel_state._SP_AA
        ag_group = parallel_state._SP_AG
        assert aa_group is not None and ag_group is not None, (
            "UlyssesAttentionPatch requires SP_AA and SP_AG groups to be initialized.")
        self.sp_aa_device_group = aa_group.device_group
        self.sp_ag_device_group = ag_group.device_group
        self.sp_aa_size = aa_group.world_size
        self.sp_ag_size = ag_group.world_size

        # Index list used to reorder the all-gathered sequence: it walks a
        # row-major (sp_ag_size x sp_aa_size) grid of indices column by column.
        self.order = [
            ag_idx * self.sp_aa_size + aa_idx
            for aa_idx in range(self.sp_aa_size)
            for ag_idx in range(self.sp_ag_size)
        ]

        # With replication, the wrapped initializer is told there is one KV head.
        kwargs["num_kv_heads"] = 1
        return self._orig_init(heads_per_rank, *args, **kwargs)