in bitsandbytes/functional.py [0:0]
def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState":
    """
    Unpack the components of a state_dict into a QuantState, converting values into
    strings, torch.dtype, ints, etc. where necessary.

    Two input layouts are accepted:
      * packed: exactly one tensor item whose key contains ``quant_state`` and ends with one
        of ``cls.valid_qs_type_keys`` (``quant_state.bitsandbytes__[nf4/fp4]``). That item
        carries the minor and non-tensor quant state items and is expanded with
        ``unpack_tensor_to_dict``.
      * unpacked: no such tensor item, but a ``quant_type`` entry is present directly.

    Args:
        qs_dict: based on state_dict, with only relevant keys, stripped of prefixes.
            The caller's dict is not modified; a shallow copy is unpacked instead.
        device: device that the tensor components (absmax, quant_map, nested_*) are moved to.

    Returns:
        A new ``cls`` instance. ``offset`` and ``state2`` are set only when
        ``nested_absmax`` is present; otherwise both are None.

    Raises:
        ValueError: if neither layout is found, if there is more than one packed
            ``quant_state`` tensor, if the packed item's key has an unexpected ending,
            or if keys outside ``cls.valid_qs_keys`` remain after unpacking.
    """
    # Work on a copy so the pop/update below does not mutate the caller's dict.
    qs_dict = dict(qs_dict)

    # Locate the packed tensor (if any) that carries the non-tensor components.
    qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)]
    if not qs_key and "quant_type" not in qs_dict:
        raise ValueError("Expected packed or unpacked quant_state items, found neither")
    # Validate the packed item only when one exists: an unpacked dict legitimately has none,
    # and it must not be rejected for "not exactly one" packed item.
    if qs_key and (len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys):
        raise ValueError(
            f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.",
        )

    # Unpack minor and non-tensor quant state items if necessary.
    if qs_key:
        packed_key = qs_key[0]
        qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(packed_key)))

    qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()}  # strip prefixes

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    unexpected = set(qs_dict) - set(cls.valid_qs_keys)
    if unexpected:
        raise ValueError(f"Unexpected quant_state keys: {sorted(unexpected)}")

    if "nested_absmax" in qs_dict:
        # Nested (double) quantization: absmax was itself quantized with its own state.
        offset = torch.tensor(float(qs_dict["nested_offset"])).to(device)
        state2 = cls(
            absmax=qs_dict["nested_absmax"].to(device),
            blocksize=qs_dict["nested_blocksize"],
            code=qs_dict["nested_quant_map"].to(device),
            dtype=getattr(torch, qs_dict["nested_dtype"]),
        )
    else:
        offset, state2 = None, None

    quant_state = cls(
        quant_type=qs_dict["quant_type"],
        absmax=qs_dict["absmax"].to(device),
        blocksize=qs_dict["blocksize"],
        code=qs_dict["quant_map"].to(device),
        dtype=getattr(torch, qs_dict["dtype"]),  # dtype is stored by name, e.g. "float16"
        shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None,
        offset=offset,
        state2=state2,
    )
    return quant_state