in inference/model.py [0:0]
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Compute ``y = x @ weight.T + bias``, choosing a GEMM path from the weight's storage width.

    The path is selected as follows:

    1. ``weight.element_size() > 1``: the weight is used as-is and the call is
       forwarded to ``F.linear``.
    2. Otherwise, when the module-level ``gemm_impl`` equals ``"bf16"``: the weight
       is first passed through ``weight_dequant`` together with ``weight.scale``,
       and the result is forwarded to ``F.linear``.
    3. Otherwise: ``x`` is quantized with ``act_quant`` (using the module-level
       ``block_size``) and multiplied via ``fp8_gemm``; ``bias``, if given, is
       then added to the product in place.

    Args:
        x (torch.Tensor): The input activations.
        weight (torch.Tensor): The weight matrix. When its elements are one byte
            wide it is treated as quantized and must expose a ``scale`` attribute.
        bias (Optional[torch.Tensor]): Optional additive bias. Defaults to None.

    Returns:
        torch.Tensor: The output of the selected linear path.
    """
    # One-byte elements are the marker used here for a quantized weight
    # (presumably FP8 -- confirm against the checkpoint loader).
    stored_as_single_byte = weight.element_size() <= 1
    if not stored_as_single_byte:
        return F.linear(x, weight, bias)

    if gemm_impl == "bf16":
        # Expand the quantized weight with its scale, then use the regular GEMM.
        return F.linear(x, weight_dequant(weight, weight.scale), bias)

    # Quantized GEMM: quantize the activations, then multiply using both scales.
    x_quant, x_scale = act_quant(x, block_size)
    out = fp8_gemm(x_quant, x_scale, weight, weight.scale)
    if bias is None:
        return out
    out += bias  # in-place add on the freshly produced GEMM output
    return out