def load_from_state_dict()

in optimum/quanto/tensor/weights/qbits.py [0:0]


    def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
        """Rebuild a ``WeightQBitsTensor`` from the entries stored in ``state_dict``.

        The ``_data`` component is loaded through ``PackedTensor.load_from_state_dict``
        under the ``prefix + "_data."`` sub-prefix. The ``_scale`` and ``_shift``
        components are popped directly from ``state_dict`` under
        ``prefix + "_scale"`` and ``prefix + "_shift"``.

        Args:
            state_dict: mapping holding the serialized entries. Entries consumed here
                are popped from it, even when the function ends up returning ``None``.
            prefix: key prefix under which this tensor's entries are stored.
            qtype: quantization type. Its ``bits`` attribute is forwarded to the
                ``_data`` loader and its ``name`` is recorded in the metadata.
            axis: axis forwarded to ``grouped_shape`` and recorded in the metadata.
            group_size: group size, or ``None`` when no grouping was applied.
            size: tensor size. It is used directly as the ``_data`` size when
                ``group_size`` is ``None``, and is recorded in the metadata.
            stride: tensor stride. It is used directly as the ``_data`` stride when
                ``group_size`` is ``None``, and is recorded in the metadata.
            missing_keys: list to which keys absent from ``state_dict`` are appended.

        Returns:
            The object built by ``WeightQBitsTensor.__tensor_unflatten__``, or ``None``
            if ``_data``, ``_scale`` or ``_shift`` could not be found.
        """
        if group_size is not None:
            packed_size = grouped_shape(size, axis, group_size)
            assert len(packed_size) == 2
            # Row-major layout: the last dimension is the contiguous one (stride 1).
            packed_stride = (packed_size[1], 1)
        else:
            packed_size, packed_stride = size, stride

        packed_data = PackedTensor.load_from_state_dict(
            state_dict, prefix + "_data.", qtype.bits, packed_size, packed_stride, missing_keys=missing_keys
        )
        inner_tensors_dict = {"_data": packed_data}
        incomplete = packed_data is None

        for component in ("_scale", "_shift"):
            key = prefix + component
            if key in state_dict:
                inner_tensors_dict[component] = state_dict.pop(key)
            else:
                missing_keys.append(key)
                incomplete = True

        # Entries that were present have already been popped. Without all three
        # components, the tensor still cannot be rebuilt.
        if incomplete:
            return None

        meta = dict(
            qtype=qtype.name,
            axis=str(axis),
            group_size=str(group_size),
            size=str(list(size)),
            stride=str(list(stride)),
        )
        return WeightQBitsTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)