in src/save_dataset.py [0:0]
def save_dataset(file_dir, file_name, X, y, grouplabels, group_sets, group_types, is_binary):
"""
Saves a dataset at the appropriate file path with the appropriate filename to local filesystem
Does NOT overwrite existing dataset to avoid issues if accidentally leaving parameters in the main driver unchanged
"""
# Concatenate a .npz if the file name is specified with no extension
if '.' not in file_name:
file_name = file_name + '.npz'
if not file_name.endswith('npz'):
raise ValueError(f'Invalid file name: {file_name} \n Please use .npz format, or no extension')
# If an extension was specified, ensure that the file is a .npz
if file_name[-4:] != '.npz':
raise ValueError('To save numpy data, file name must end in .npz')
final_path = os.path.join(file_dir, file_name)
if os.path.isfile(final_path):
print(f'WARNING: Desired file {final_path} already exists. Exiting without writing to a file...')
return
if not os.path.isdir(file_dir):
os.makedirs(file_dir)
np.savez(final_path, X=X, Y=y, grouplabels=grouplabels, group_sets=group_sets, group_types=group_types,
is_binary=is_binary)
print(f'Successfully saved data to {final_path}')