in src/nanotron/data/nanoset.py
def new_build_nanoset_index(self) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build a dataset index that enables sequential reading while respecting the dataset weights.
    Uses a cached index if one exists and the parameters match.
    """
    # Create a cache key based on the parameters that affect the index
    cache_params = {
        "dataset_folders": self.dataset_folders,
        "dataset_lengths": self.dataset_lengths,
        "dataset_weights": self.dataset_weights.tolist(),
        "train_split_num_samples": self.train_split_num_samples,
        "random_seed": self.random_seed,
        "token_size": self.token_size,
        "sequence_length": self.sequence_length,
    }
    # Create a deterministic cache key by hashing the sorted JSON serialization of the parameters
    cache_key = hashlib.md5(json.dumps(cache_params, sort_keys=True).encode()).hexdigest()
    cache_file = os.path.join(self.cache_dir, f"index_{cache_key}.npz")
    # Try to load from cache
    if os.path.exists(cache_file):
        try:
            logger.info(f"[Nanoset] Loading index from cache: {cache_file}")
            cached_data = np.load(cache_file)
            return cached_data["dataset_index"], cached_data["dataset_sample_index"]
        except Exception as e:
            logger.warning(f"[Nanoset] Failed to load cache, rebuilding index: {e}")
logger.info(f"[Nanoset] Building sequential Nanoset index for {len(self.dataset_folders)} datasets")
# Original index building logic
total_weighted_samples = np.array(self.dataset_weights) * self.train_split_num_samples
samples_per_dataset = np.floor(total_weighted_samples).astype(np.int64)
remaining = self.train_split_num_samples - samples_per_dataset.sum()
if remaining > 0:
fractional_parts = total_weighted_samples - samples_per_dataset
indices = np.argsort(fractional_parts)[-remaining:]
samples_per_dataset[indices] += 1
dataset_positions = np.zeros(len(self.dataset_folders), dtype=np.int64)
dataset_index = np.zeros(self.train_split_num_samples, dtype=np.int64)
dataset_sample_index = np.zeros(self.train_split_num_samples, dtype=np.int64)
dataset_order = np.repeat(np.arange(len(self.dataset_folders)), samples_per_dataset)
rng = np.random.RandomState(self.random_seed)
rng.shuffle(dataset_order)
for idx, dataset_idx in tqdm(enumerate(dataset_order), desc="Building Nanoset index"):
dataset_index[idx] = dataset_idx
dataset_sample_index[idx] = dataset_positions[dataset_idx]
dataset_positions[
dataset_idx
] += 1 # Read samples sequentially from each datatrove_dataset assuming they're already shuffled
    # Save to cache
    try:
        os.makedirs(self.cache_dir, exist_ok=True)
        np.savez(cache_file, dataset_index=dataset_index, dataset_sample_index=dataset_sample_index)
        logger.info(f"[Nanoset] Saved index to cache: {cache_file}")
    except Exception as e:
        logger.warning(f"[Nanoset] Failed to save cache: {e}")

    return dataset_index, dataset_sample_index
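
The per-dataset sample counts are a largest-remainder allocation: the weighted totals are floored, and any leftover samples go to the datasets with the largest fractional parts. A minimal standalone sketch of that step (the helper name `allocate_samples` and the example weights are illustrative, not part of the repo):

import numpy as np

def allocate_samples(weights, total_samples):
    # Illustrative helper: same floor-then-top-up allocation as in new_build_nanoset_index
    weighted = np.asarray(weights, dtype=np.float64) * total_samples
    counts = np.floor(weighted).astype(np.int64)
    remaining = total_samples - counts.sum()
    if remaining > 0:
        # Hand the leftover samples to the datasets with the largest fractional parts
        fractional = weighted - counts
        counts[np.argsort(fractional)[-remaining:]] += 1
    return counts

# 7 samples at weights 0.5 / 0.3 / 0.2 -> weighted totals [3.5, 2.1, 1.4],
# which floor to [3, 2, 1]; the single leftover sample goes to the largest
# fractional part (0.5), giving [4, 2, 1].
print(allocate_samples([0.5, 0.3, 0.2], 7))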
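
The cache itself is a single .npz holding both index arrays, named after the MD5 of the JSON-serialized parameters, so changing any weight, seed, or length produces a fresh file. A quick roundtrip sketch of that format, with made-up parameter values and a temporary cache directory (the real dict also carries dataset lengths, token size, and sequence length):

import hashlib
import json
import os

import numpy as np

# Illustrative parameters only, to show how the key is derived
cache_params = {"dataset_folders": ["datasets/ds_a", "datasets/ds_b"], "dataset_weights": [0.7, 0.3], "random_seed": 42}
cache_key = hashlib.md5(json.dumps(cache_params, sort_keys=True).encode()).hexdigest()
cache_file = os.path.join("/tmp/nanoset_cache", f"index_{cache_key}.npz")

os.makedirs("/tmp/nanoset_cache", exist_ok=True)
np.savez(cache_file, dataset_index=np.array([0, 1, 0]), dataset_sample_index=np.array([0, 0, 1]))

cached = np.load(cache_file)
print(cached["dataset_index"], cached["dataset_sample_index"])  # -> [0 1 0] [0 0 1]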