def new_build_nanoset_index()

in src/nanotron/data/nanoset.py [0:0]


    def new_build_nanoset_index(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build dataset index that enables sequential reading while respecting weights.
        Uses cache if available and parameters match.
        """
        # Create a cache key based on the parameters that affect the index
        cache_params = {
            "dataset_folders": self.dataset_folders,
            "dataset_lengths": self.dataset_lengths,
            "dataset_weights": self.dataset_weights.tolist(),
            "train_split_num_samples": self.train_split_num_samples,
            "random_seed": self.random_seed,
            "token_size": self.token_size,
            "sequence_length": self.sequence_length,
        }
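        # Every parameter above affects the resulting index; hashing all of them keeps a
        # stale cached index from being reused after a configuration change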

        # Create a deterministic cache key
        cache_key = hashlib.md5(json.dumps(cache_params, sort_keys=True).encode()).hexdigest()
        cache_file = os.path.join(self.cache_dir, f"index_{cache_key}.npz")

        # Try to load from cache
        if os.path.exists(cache_file):
            try:
                logger.info(f"[Nanoset] Loading index from cache: {cache_file}")
                cached_data = np.load(cache_file)
                return cached_data["dataset_index"], cached_data["dataset_sample_index"]
            except Exception as e:
                logger.warning(f"[Nanoset] Failed to load cache, rebuilding index: {e}")

        logger.info(f"[Nanoset] Building sequential Nanoset index for {len(self.dataset_folders)} datasets")

        # Allocate samples to each dataset in proportion to its weight
        total_weighted_samples = np.array(self.dataset_weights) * self.train_split_num_samples
        samples_per_dataset = np.floor(total_weighted_samples).astype(np.int64)

        # Largest-remainder rounding: hand the samples lost to flooring back to the
        # datasets with the largest fractional parts so the counts sum to train_split_num_samples
        remaining = self.train_split_num_samples - samples_per_dataset.sum()
        if remaining > 0:
            fractional_parts = total_weighted_samples - samples_per_dataset
            indices = np.argsort(fractional_parts)[-remaining:]
            samples_per_dataset[indices] += 1

        dataset_positions = np.zeros(len(self.dataset_folders), dtype=np.int64)
        dataset_index = np.zeros(self.train_split_num_samples, dtype=np.int64)
        dataset_sample_index = np.zeros(self.train_split_num_samples, dtype=np.int64)

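        # Interleave the datasets: each dataset id appears samples_per_dataset[i] times,
        # and the seeded shuffle makes the interleaving deterministic for a given random_seed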
        dataset_order = np.repeat(np.arange(len(self.dataset_folders)), samples_per_dataset)
        rng = np.random.RandomState(self.random_seed)
        rng.shuffle(dataset_order)

        for idx, dataset_idx in tqdm(enumerate(dataset_order), total=len(dataset_order), desc="Building Nanoset index"):
            dataset_index[idx] = dataset_idx
            dataset_sample_index[idx] = dataset_positions[dataset_idx]
            # Read samples sequentially from each datatrove_dataset, assuming they're already shuffled
            dataset_positions[dataset_idx] += 1

        # Save to cache
        try:
            os.makedirs(self.cache_dir, exist_ok=True)
            np.savez(cache_file, dataset_index=dataset_index, dataset_sample_index=dataset_sample_index)
            logger.info(f"[Nanoset] Saved index to cache: {cache_file}")
        except Exception as e:
            logger.warning(f"[Nanoset] Failed to save cache: {e}")

        return dataset_index, dataset_sample_index
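
The per-dataset counts come from largest-remainder rounding of the weighted totals, and the seeded shuffle fixes how the datasets are interleaved. Below is a minimal, self-contained sketch of that arithmetic; the weights, sample count, and seed are made-up example values, not ones taken from a Nanoset config.

    import numpy as np

    # Hypothetical example configuration (illustration only)
    dataset_weights = np.array([0.62, 0.25, 0.13])
    train_split_num_samples = 10
    random_seed = 1234

    # Floor the weighted counts, then hand the leftover samples to the datasets
    # with the largest fractional parts (largest-remainder rounding)
    total_weighted_samples = dataset_weights * train_split_num_samples          # [6.2, 2.5, 1.3]
    samples_per_dataset = np.floor(total_weighted_samples).astype(np.int64)     # [6, 2, 1] -> sums to 9
    remaining = train_split_num_samples - samples_per_dataset.sum()             # 1 sample still unassigned
    if remaining > 0:
        fractional_parts = total_weighted_samples - samples_per_dataset         # [0.2, 0.5, 0.3]
        samples_per_dataset[np.argsort(fractional_parts)[-remaining:]] += 1     # dataset 1 gets it -> [6, 3, 1]

    # Repeat each dataset id by its sample count, then shuffle in place with a fixed seed
    dataset_order = np.repeat(np.arange(len(dataset_weights)), samples_per_dataset)
    np.random.RandomState(random_seed).shuffle(dataset_order)
    print(samples_per_dataset)  # [6 3 1]
    print(dataset_order)        # a deterministic interleaving of six 0s, three 1s and one 2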