in bindings/python/py_src/safetensors/torch.py [0:0]
def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
    """Group the keys of ``state_dict`` by the storage their tensors live in.

    Two keys land in the same group when their tensors have the same device,
    the same ``storage_ptr`` and the same ``storage_size``. Some tensors are
    skipped entirely and appear in no group:

    - tensors on the ``meta`` device;
    - tensors whose storage pointer is 0;
    - tensors whose storage size is 0.

    The raw groups, in first-seen order, are passed to
    ``_filter_shared_not_shared`` together with ``state_dict``, and its result
    is returned unchanged.

    Args:
        state_dict: Mapping from parameter name to tensor.

    Returns:
        A list of sets of keys, as produced by ``_filter_shared_not_shared``.
        Groups are handed to it in the order in which ``state_dict`` first
        yields a key for each distinct (device, pointer, size) triple.
    """
    meta_device = torch.device("meta")
    groups: Dict[tuple, Set[str]] = defaultdict(set)
    for name, tensor in state_dict.items():
        # Each check only runs when the previous one passed. This matches the
        # short-circuit order of the original single `and` condition.
        if tensor.device == meta_device:
            continue
        ptr = storage_ptr(tensor)
        if ptr == 0:
            continue
        size = storage_size(tensor)
        if size == 0:
            continue
        # The device is part of the key because of multiple GPUs: the storage
        # pointer alone is not treated as unique across devices.
        groups[(tensor.device, ptr, size)].add(name)
    # Groups stay in insertion order. Plain ``sorted()`` must not be used to
    # order them: ``<`` on sets is a subset test, not a total order. For
    # these disjoint groups it never returns True, so sorting does nothing.
    return _filter_shared_not_shared(list(groups.values()), state_dict)