backends/gaudi/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py [51:213]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
    else:
        return _load_qkv(config, prefix, weights, head_size, num_heads)


def _load_qkv_gptq(config, prefix: str, weights):
    """Build the fused QKV projection for a GPTQ-quantized checkpoint.

    The packed weight is obtained from ``weights.get_weights_col_packed_qkv``
    for ``{prefix}.c_attn``. The bias at ``{prefix}.c_attn.bias`` is treated
    as three equally sized segments laid out back to back; this rank's
    contiguous block of each segment is read and the three blocks are
    concatenated.

    Raises:
        AssertionError: if the bias length is not a multiple of 3, or one
            segment cannot be split evenly across the process group.
    """
    group = weights.process_group
    n_shards = group.size()
    shard_idx = group.rank()

    # Weights
    qkv_weight = weights.get_weights_col_packed_qkv(
        f"{prefix}.c_attn",
        config.num_attention_heads,
        config.num_attention_heads,
    )

    # Bias
    bias_slice = weights._get_slice(f"{prefix}.c_attn.bias")
    packed_len = bias_slice.get_shape()[0]
    assert packed_len % 3 == 0, f"Prepacked is not divisible by {3}"
    segment_len = packed_len // 3
    assert segment_len % n_shards == 0
    shard_len = segment_len // n_shards

    # Offsets of this rank's block inside a single segment.
    lo = shard_idx * shard_len
    hi = lo + shard_len
    pieces = [
        bias_slice[lo + k * segment_len : hi + k * segment_len] for k in range(3)
    ]
    qkv_bias = torch.cat(pieces, dim=0).to(device=weights.device)

    return TensorParallelColumnLinear(get_linear(qkv_weight, qkv_bias))


def _load_qkv(config, prefix: str, weights, head_size, num_heads):
    """Load QKV from a single, transposed matrix.

    ``{prefix}.c_attn.weight`` is read as three equally sized segments packed
    along dim 1, and ``{prefix}.c_attn.bias`` as three segments packed along
    dim 0. For each segment only this rank's contiguous block is read; the
    three blocks are concatenated and the weight is transposed before being
    handed to ``get_linear``.

    Args:
        config: Not read by this function.
        prefix: Checkpoint name prefix of the attention module.
        weights: Loader exposing ``_get_slice``, ``process_group``, ``dtype``
            and ``device``.
        head_size: Size of one attention head; used only to validate the
            length of the loaded bias.
        num_heads: Head count used to validate the loaded bias; the result
            must have ``3 * num_heads * head_size`` elements.

    Raises:
        AssertionError: if a packed dimension is not a multiple of 3, a
            segment cannot be split evenly across the process group, or the
            loaded bias does not have the expected length.
    """
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    # Weights
    weight_slice = weights._get_slice(f"{prefix}.c_attn.weight")
    total_size = weight_slice.get_shape()[1]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    weight = torch.cat(
        [
            weight_slice[:, start + i * single_size : stop + i * single_size]
            for i in range(3)
        ],
        dim=1,
    ).T
    weight = weight.to(dtype=weights.dtype)
    weight = weight.to(device=weights.device)

    # Bias
    bias_slice = weights._get_slice(f"{prefix}.c_attn.bias")
    total_size = bias_slice.get_shape()[0]
    # Same packing check as for the weight; without it a bias whose length
    # is not a multiple of 3 would be silently truncated by the floor
    # division below.
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    # Validate divisibility before deriving the per-rank block size.
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    bias = torch.cat(
        [
            bias_slice[start + i * single_size : stop + i * single_size]
            for i in range(3)
        ],
        dim=0,
    )
    bias = bias.to(dtype=weights.dtype)
    bias = bias.to(device=weights.device)

    # Report the bias shape (the tensor actually being checked) on failure.
    expected_shape = [3 * num_heads * head_size]
    assert (
        list(bias.shape) == expected_shape
    ), f"{list(bias.shape)} != {expected_shape}"

    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_row(config, prefix: str, weights, bias: bool):
    """Build a row-parallel linear layer from a transposed weight matrix.

    With ``config.quantize == "gptq"`` the weight comes from
    ``weights.get_weights_row``; otherwise ``{prefix}.weight`` is sharded
    along dim 0 and transposed. When ``bias`` is true, ``{prefix}.bias`` is
    loaded in full, but only on rank 0; every other rank passes ``None``.
    """
    row_weight = (
        weights.get_weights_row(prefix)
        if config.quantize == "gptq"
        else weights.get_sharded(f"{prefix}.weight", dim=0).T
    )

    # Bias lives on the first rank only (presumably so it is applied once
    # when the per-rank partial outputs are combined -- confirm against
    # TensorParallelRowLinear).
    load_bias_here = bias and weights.process_group.rank() == 0
    row_bias = weights.get_tensor(f"{prefix}.bias") if load_bias_here else None

    linear = get_linear(row_weight, row_bias)
    return TensorParallelRowLinear(linear, process_group=weights.process_group)


def load_col(config, prefix: str, weights, bias: bool):
    """Build a column-parallel linear layer from a transposed weight matrix.

    With ``config.quantize == "gptq"`` the weight comes from
    ``weights.get_multi_weights_col``; otherwise ``{prefix}.weight`` is
    sharded along dim 1 and transposed. When ``bias`` is true,
    ``{prefix}.bias`` is sharded along dim 0; otherwise no bias is used.
    """
    col_weight = (
        weights.get_multi_weights_col([prefix], dim=1)
        if config.quantize == "gptq"
        else weights.get_sharded(f"{prefix}.weight", dim=1).T
    )
    col_bias = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None

    linear = get_linear(col_weight, col_bias)
    return TensorParallelColumnLinear(linear)


class FlashGPT2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        """Load this rank's share of the attention layer.

        Args:
            prefix: Checkpoint name prefix; passed to ``load_qkv`` and
                ``get_kv_scales`` as-is and to ``load_row`` with ``.c_proj``
                appended.
            config: Model config; ``num_attention_heads`` and ``hidden_size``
                are read here.
            weights: Weight loader; ``process_group`` and ``device`` are
                read here.

        Raises:
            ValueError: if ``config.num_attention_heads`` is not divisible by
                the process group size.
        """
        super().__init__()
        total_heads = config.num_attention_heads
        world_size = weights.process_group.size()
        self.hidden_size = config.hidden_size

        # Head size and softmax scale are derived from the full head count,
        # before the heads are divided across ranks.
        self.head_size = self.hidden_size // total_heads
        self.softmax_scale = self.head_size**-0.5

        if total_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {total_heads} "
                f"and `num_shards`: {world_size})"
            )
        # From here on `num_heads` is the per-rank head count.
        self.num_heads = total_heads // world_size

        self.query_key_value = load_qkv(
            config,
            prefix=prefix,
            weights=weights,
            head_size=self.head_size,
            num_heads=self.num_heads,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        # One entry per local head: 0 .. num_heads - 1.
        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py [52:214]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )
    else:
        return _load_qkv(config, prefix, weights, head_size, num_heads)


def _load_qkv_gptq(config, prefix: str, weights):
    """Build the fused QKV projection for a GPTQ-quantized checkpoint.

    The packed weight is obtained from ``weights.get_weights_col_packed_qkv``
    for ``{prefix}.c_attn``. The bias at ``{prefix}.c_attn.bias`` is treated
    as three equally sized segments laid out back to back; this rank's
    contiguous block of each segment is read and the three blocks are
    concatenated.

    Raises:
        AssertionError: if the bias length is not a multiple of 3, or one
            segment cannot be split evenly across the process group.
    """
    group = weights.process_group
    n_shards = group.size()
    shard_idx = group.rank()

    # Weights
    qkv_weight = weights.get_weights_col_packed_qkv(
        f"{prefix}.c_attn",
        config.num_attention_heads,
        config.num_attention_heads,
    )

    # Bias
    bias_slice = weights._get_slice(f"{prefix}.c_attn.bias")
    packed_len = bias_slice.get_shape()[0]
    assert packed_len % 3 == 0, f"Prepacked is not divisible by {3}"
    segment_len = packed_len // 3
    assert segment_len % n_shards == 0
    shard_len = segment_len // n_shards

    # Offsets of this rank's block inside a single segment.
    lo = shard_idx * shard_len
    hi = lo + shard_len
    pieces = [
        bias_slice[lo + k * segment_len : hi + k * segment_len] for k in range(3)
    ]
    qkv_bias = torch.cat(pieces, dim=0).to(device=weights.device)

    return TensorParallelColumnLinear(get_linear(qkv_weight, qkv_bias))


def _load_qkv(config, prefix: str, weights, head_size, num_heads):
    """Load QKV from a single, transposed matrix.

    ``{prefix}.c_attn.weight`` is read as three equally sized segments packed
    along dim 1, and ``{prefix}.c_attn.bias`` as three segments packed along
    dim 0. For each segment only this rank's contiguous block is read; the
    three blocks are concatenated and the weight is transposed before being
    handed to ``get_linear``.

    Args:
        config: Not read by this function.
        prefix: Checkpoint name prefix of the attention module.
        weights: Loader exposing ``_get_slice``, ``process_group``, ``dtype``
            and ``device``.
        head_size: Size of one attention head; used only to validate the
            length of the loaded bias.
        num_heads: Head count used to validate the loaded bias; the result
            must have ``3 * num_heads * head_size`` elements.

    Raises:
        AssertionError: if a packed dimension is not a multiple of 3, a
            segment cannot be split evenly across the process group, or the
            loaded bias does not have the expected length.
    """
    world_size = weights.process_group.size()
    rank = weights.process_group.rank()

    # Weights
    weight_slice = weights._get_slice(f"{prefix}.c_attn.weight")
    total_size = weight_slice.get_shape()[1]
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    weight = torch.cat(
        [
            weight_slice[:, start + i * single_size : stop + i * single_size]
            for i in range(3)
        ],
        dim=1,
    ).T
    weight = weight.to(dtype=weights.dtype)
    weight = weight.to(device=weights.device)

    # Bias
    bias_slice = weights._get_slice(f"{prefix}.c_attn.bias")
    total_size = bias_slice.get_shape()[0]
    # Same packing check as for the weight; without it a bias whose length
    # is not a multiple of 3 would be silently truncated by the floor
    # division below.
    assert total_size % 3 == 0, f"Prepacked is not divisible by {3}"
    single_size = total_size // 3
    # Validate divisibility before deriving the per-rank block size.
    assert single_size % world_size == 0
    block_size = single_size // world_size
    start = rank * block_size
    stop = start + block_size
    bias = torch.cat(
        [
            bias_slice[start + i * single_size : stop + i * single_size]
            for i in range(3)
        ],
        dim=0,
    )
    bias = bias.to(dtype=weights.dtype)
    bias = bias.to(device=weights.device)

    # Report the bias shape (the tensor actually being checked) on failure.
    expected_shape = [3 * num_heads * head_size]
    assert (
        list(bias.shape) == expected_shape
    ), f"{list(bias.shape)} != {expected_shape}"

    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_row(config, prefix: str, weights, bias: bool):
    """Build a row-parallel linear layer from a transposed weight matrix.

    With ``config.quantize == "gptq"`` the weight comes from
    ``weights.get_weights_row``; otherwise ``{prefix}.weight`` is sharded
    along dim 0 and transposed. When ``bias`` is true, ``{prefix}.bias`` is
    loaded in full, but only on rank 0; every other rank passes ``None``.
    """
    row_weight = (
        weights.get_weights_row(prefix)
        if config.quantize == "gptq"
        else weights.get_sharded(f"{prefix}.weight", dim=0).T
    )

    # Bias lives on the first rank only (presumably so it is applied once
    # when the per-rank partial outputs are combined -- confirm against
    # TensorParallelRowLinear).
    load_bias_here = bias and weights.process_group.rank() == 0
    row_bias = weights.get_tensor(f"{prefix}.bias") if load_bias_here else None

    linear = get_linear(row_weight, row_bias)
    return TensorParallelRowLinear(linear, process_group=weights.process_group)


def load_col(config, prefix: str, weights, bias: bool):
    """Build a column-parallel linear layer from a transposed weight matrix.

    With ``config.quantize == "gptq"`` the weight comes from
    ``weights.get_multi_weights_col``; otherwise ``{prefix}.weight`` is
    sharded along dim 1 and transposed. When ``bias`` is true,
    ``{prefix}.bias`` is sharded along dim 0; otherwise no bias is used.
    """
    col_weight = (
        weights.get_multi_weights_col([prefix], dim=1)
        if config.quantize == "gptq"
        else weights.get_sharded(f"{prefix}.weight", dim=1).T
    )
    col_bias = weights.get_sharded(f"{prefix}.bias", dim=0) if bias else None

    linear = get_linear(col_weight, col_bias)
    return TensorParallelColumnLinear(linear)


class FlashGPT2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        """Load this rank's share of the attention layer.

        Args:
            prefix: Checkpoint name prefix; passed to ``load_qkv`` and
                ``get_kv_scales`` as-is and to ``load_row`` with ``.c_proj``
                appended.
            config: Model config; ``num_attention_heads`` and ``hidden_size``
                are read here.
            weights: Weight loader; ``process_group`` and ``device`` are
                read here.

        Raises:
            ValueError: if ``config.num_attention_heads`` is not divisible by
                the process group size.
        """
        super().__init__()
        total_heads = config.num_attention_heads
        world_size = weights.process_group.size()
        self.hidden_size = config.hidden_size

        # Head size and softmax scale are derived from the full head count,
        # before the heads are divided across ranks.
        self.head_size = self.hidden_size // total_heads
        self.softmax_scale = self.head_size**-0.5

        if total_heads % world_size != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {total_heads} "
                f"and `num_shards`: {world_size})"
            )
        # From here on `num_heads` is the per-rank head count.
        self.num_heads = total_heads // world_size

        self.query_key_value = load_qkv(
            config,
            prefix=prefix,
            weights=weights,
            head_size=self.head_size,
            num_heads=self.num_heads,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = load_row(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=True,
        )

        # One entry per local head: 0 .. num_heads - 1.
        self.kv_head_mapping = torch.arange(
            0, self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



