in point_e/evals/npz_stream.py [0:0]
def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
    """
    Describe every array stored in an npz archive without loading its data.

    For each ``.npy`` member only the format magic and the array header are
    parsed; the array payload itself is never read.

    :param npz_path: filesystem path of the ``.npz`` archive.
    :return: dict mapping each array key (the member name minus its ``.npy``
        suffix) to ``cls(name=key, dtype=..., shape=...)``, in the order the
        members appear in the archive. Members that do not end in ``.npy``
        are skipped.
    :raises FileNotFoundError: if ``npz_path`` does not exist.
    :raises ValueError: if a member uses an ``.npy`` format version other
        than 1.0 or 2.0.
    """
    if not os.path.exists(npz_path):
        raise FileNotFoundError(f"batch of samples was not found: {npz_path}")

    suffix = ".npy"
    # Header parsers for the .npy format versions understood here.
    header_readers = {
        (1, 0): np.lib.format.read_array_header_1_0,
        (2, 0): np.lib.format.read_array_header_2_0,
    }

    infos = {}
    with open(npz_path, "rb") as raw_file, zipfile.ZipFile(raw_file, "r") as archive:
        array_members = [m for m in archive.namelist() if m.endswith(suffix)]
        for member in array_members:
            key = member[: -len(suffix)]
            with archive.open(member, "r") as member_file:
                fmt_version = np.lib.format.read_magic(member_file)
                read_header = header_readers.get(fmt_version)
                if read_header is None:
                    raise ValueError(f"unknown numpy array version: {fmt_version}")
                # Header tuple is (shape, fortran_order, dtype); layout order is unused.
                arr_shape, _fortran_order, arr_dtype = read_header(member_file)
            infos[key] = cls(name=key, dtype=arr_dtype, shape=arr_shape)
    return infos