in optimum/quanto/tensor/weights/qbits.py [0:0]
def load_from_state_dict(state_dict, prefix, qtype, axis, group_size, size, stride, missing_keys):
    """Rebuild a ``WeightQBitsTensor`` from the entries of a serialized state dict.

    The packed integer data is loaded by ``PackedTensor.load_from_state_dict``
    under the ``prefix + "_data."`` sub-prefix. The ``_scale`` and ``_shift``
    tensors are read directly from ``state_dict``.

    Args:
        state_dict: mapping of serialized tensors. The ``_scale`` / ``_shift``
            entries found here are popped (consumed) from it.
        prefix: key prefix identifying this tensor's entries in ``state_dict``.
        qtype: quantization type; its ``bits`` and ``name`` attributes are used.
        axis: quantization axis, passed to ``grouped_shape`` and recorded in
            the metadata.
        group_size: group size, or ``None`` when the data is not grouped.
        size: logical size of the weight (converted with ``list``).
        stride: logical stride of the weight (converted with ``list``).
        missing_keys: list that receives the absent ``_scale`` / ``_shift``
            keys. It is also forwarded to ``PackedTensor.load_from_state_dict``.

    Returns:
        The result of ``WeightQBitsTensor.__tensor_unflatten__``. Returns
        ``None`` when the packed data, ``_scale`` or ``_shift`` could not be
        found.
    """
    if group_size is not None:
        data_size = grouped_shape(size, axis, group_size)
        # Grouped packed data is expected to be two-dimensional
        assert len(data_size) == 2
        # Row-major layout: the last dimension is the contiguous (stride 1) one
        data_stride = (data_size[1], 1)
    else:
        data_size, data_stride = size, stride

    packed_data = PackedTensor.load_from_state_dict(
        state_dict, prefix + "_data.", qtype.bits, data_size, data_stride, missing_keys=missing_keys
    )
    inner_tensors_dict = {"_data": packed_data}
    is_incomplete = packed_data is None

    for tensor_name in ("_scale", "_shift"):
        key = prefix + tensor_name
        if key in state_dict:
            inner_tensors_dict[tensor_name] = state_dict.pop(key)
        else:
            missing_keys.append(key)
            is_incomplete = True

    if is_incomplete:
        # At least one required entry is absent: the tensor cannot be deserialized
        return None

    # Metadata values are serialized as strings for __tensor_unflatten__
    meta = {
        "qtype": qtype.name,
        "axis": str(axis),
        "group_size": str(group_size),
        "size": str(list(size)),
        "stride": str(list(stride)),
    }
    return WeightQBitsTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)