def create()

in optimum/quanto/tensor/weights/qbits.py [0:0]


    def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad=False):
        """Factory method to create a WeightQBitsTensor

        Selects the most appropriate WeightQBitsTensor subclass for the given
        quantization configuration and hardware, falling back to the generic
        WeightQBitsTensor when no optimized kernel applies.

        Args:
            qtype (`qtype`):
                The quantized data type (optimized subclasses are only used for qint4).
            axis (`int`):
                The axis that is preserved by quantization (usually zero for linear weights).
            group_size (`int`):
                The group size that further splits the data elements for each index along the quantization axis.
            size ():
                The Tensor size.
            stride():
                The Tensor stride.
            data (`torch.Tensor`):
                The tensor data, either as a raw uint8 torch.Tensor or as a PackedTensor.
            scale (`torch.Tensor`):
                The floating point scale expressed as a torch.Tensor.
            shift (`torch.Tensor`):
                The shift expressed as a torch.Tensor. It can be either an integer representing zero
                (i.e. zero-point) or a float value.
            requires_grad (`bool`):
                Whether the Tensor must receive a gradient or not.

        Returns:
            a `WeightQBitsTensor` (can be a subclass).
        """
        from .awq import AWQWeightQBitsTensor
        from .tinygemm import TinyGemmWeightQBitsTensor

        # AWQ kernels: qint4, float16 scales, per-group (128) quantization of a
        # 2D weight on a CUDA device with compute capability >= 8.
        if (
            qtype == qint4
            and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
            and scale.dtype == torch.float16
            and axis == 0
            and group_size == 128
            and len(size) == 2
        ):
            # Hardware checks are kept after the configuration checks so the
            # CUDA queries only run when the layout actually matches.
            awq_supported = (
                data.device.type == "cuda"
                and torch.version.cuda
                and torch.cuda.get_device_capability(data.device)[0] >= 8
                and is_extension_available("quanto_cuda")
            )
            if awq_supported:
                unpacked = data.unpack() if type(data) is PackedTensor else data
                return AWQWeightQBitsTensor(
                    qtype, axis, group_size, size, stride, unpacked, scale, shift, requires_grad
                )

        # TinyGemm kernels: same qint4 per-group layout, but bfloat16 scales,
        # on CPU or on CUDA >= 12.1 with compute capability >= 8.
        if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
            device_type = data.device.type
            tinygemm_supported = device_type == "cpu"
            if not tinygemm_supported and device_type == "cuda" and torch.version.cuda:
                tinygemm_supported = (
                    version.parse(torch.version.cuda).release >= (12, 1)
                    and torch.cuda.get_device_capability(data.device)[0] >= 8
                )
            if tinygemm_supported:
                unpacked = data.unpack() if type(data) is PackedTensor else data
                # TinyGemm stores scale and shift as a single paired argument.
                return TinyGemmWeightQBitsTensor(
                    qtype, axis, group_size, size, stride, unpacked, (scale, shift), requires_grad
                )

        return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)