Excerpt: the `forward()` method of the ROCm fast-linear layer, from
server/text_generation_server/layers/linear.py

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        weight = self.weight
        bias = self.bias

        if (
            self.use_skinny_gemm
            and inp.dtype == torch.float16
            and inp.shape[-1] % 8 == 0
        ):
            batched = False
            inp_shape = inp.shape

            if inp.dim() == 3:
                inp = inp.view(-1, inp_shape[-1])
                batched = True

            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
            if m > 8 and n <= 4:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                ops.wvSpltK(weight, inp, out, n, self.cu_count)
            elif m % 4 == 0 and n == 1 and k <= 8192:
                out = torch.empty(
                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
                )
                ops.LLMM1(weight, inp, out, 4)
            else:
                out = F.linear(inp, weight)

            if batched:
                out.view(*inp_shape[:-1], out.shape[-1])

            if bias is not None:
                out = out + bias
            return out
        return F.linear(inp, self.weight, self.bias)