pipeline/clean/merge-mono.py (190 lines of code) (raw):

import argparse
import glob
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Generator

from pipeline.common.datasets import (
    CountingStep,
    FilteringStep,
    Statistics,
    WeakStringSet,
    shuffle_with_max_lines,
)
from pipeline.common.downloads import (
    format_bytes,
    get_human_readable_file_size,
    read_lines,
    write_lines,
)
from pipeline.common.logging import get_logger
from pipeline.common.memory import log_memory

logger = get_logger(__file__)


@dataclass
class FilteringStatistics(Statistics):
    """
    Gather statistics about the filtering process.

    Each attribute is a step object whose description string is passed to its
    constructor. The values are filled in by `filter_and_write_monolingual_data`.
    """

    def __init__(self, dataset_path: Path) -> None:
        super().__init__(dataset_path)
        self.final_truncated_monolingual_lines = CountingStep(
            "After truncation via the config's `experiment.mono-max-sentences-src.total`, "
            "how many lines are left."
        )
        self.final_truncated_monolingual_codepoints = CountingStep(
            "The amount of codepoints in the final monolingual corpus."
        )
        self.parallel_corpus_lines = CountingStep(
            "The size of the merged parallel corpus before truncation."
        )
        self.duplicates_of_parallel_corpus = CountingStep(
            "How much of the monolingual data was duplicated in the merged parallel corpus."
        )
        self.duplicates_of_monolingual_corpus = CountingStep(
            "How much of the monolingual data was duplicated across the monolingual datasets."
        )
        self.deduplicated_size = FilteringStep(
            "What was the size of the monolingual data and how much was deduplicated. This "
            "doesn't count the truncation of datasets at the datasets gathering time."
        )
        self.deduplicated_monolingual_lines = CountingStep(
            "After deduplication, how much monolingual data is left."
        )


def filter_and_write_monolingual_data(
    mono_datasets: list[str],
    output_path: Path,
    parallel_hashes: WeakStringSet,
    max_lines: int,
    sample_size: int,
    stats: FilteringStatistics,
) -> None:
    """
    Deduplicate, shuffle, truncate and write the monolingual data, then write a sample
    of it and save the statistics.

    A line is dropped when it is already in `parallel_hashes` (the parallel corpus) or
    when it was already seen in the monolingual data. Membership is tracked with a
    WeakStringSet. NOTE(review): the original author's note says the set stores hashes
    rather than the strings so that the sentences are not retained in memory; confirm
    against the WeakStringSet implementation.

    Args:
        mono_datasets: Paths of the monolingual dataset files to merge.
        output_path: Where the merged file is written. The sample is written next to
            it as "<stem>.sample.txt".
        parallel_hashes: The set built from the lines of the parallel corpus.
        max_lines: Passed to the shuffler as the maximum number of lines to keep.
        sample_size: Passed to the shuffler as the maximum number of lines in the sample.
        stats: The statistics object that is filled in and saved as JSON.
    """

    mono_hashes = WeakStringSet()

    def deduplicate_lines(lines: Generator[str, None, None]) -> Generator[str, None, None]:
        """
        This is the generator that will perform the deduplication on a line stream. It's
        passed into the shuffler, so needs to be its own function.
        """
        parallel_discards = 0
        mono_discards = 0
        retained = 0
        for line in lines:
            # Don't add this sentence if it's in the original parallel corpus, or if it's
            # already present in the monolingual data, perhaps from another source.
            if line in parallel_hashes:
                parallel_discards += 1
            elif line in mono_hashes:
                mono_discards += 1
            else:
                retained += 1
                mono_hashes.add(line)  # Don't add this sentence again.

                # Report progress periodically.
                if retained % 1_000_000 == 0:
                    discards = parallel_discards + mono_discards
                    log_memory()
                    logger.info(f"{retained:,} kept, {discards:,} discarded")

                yield line

        # The counters are only final once the input stream is exhausted, so the stats
        # are recorded after the loop rather than incrementally.
        stats.deduplicated_size.kept = retained
        stats.deduplicated_size.filtered = parallel_discards + mono_discards
        stats.deduplicated_monolingual_lines.value = retained
        stats.duplicates_of_parallel_corpus.value = parallel_discards
        stats.duplicates_of_monolingual_corpus.value = mono_discards
        stats.parallel_corpus_lines.value = len(parallel_hashes)

    # Estimate the byte size. The better the estimate, the better the data distribution
    # will be. When filtering mono NLLB data against parallel NLLB data, roughly 70% is
    # kept.
    byte_size_estimate = 0
    for dataset in mono_datasets:
        byte_size_estimate += os.path.getsize(dataset)
    byte_size_estimate *= 0.7

    log_memory(gc_collect=True)
    logger.info("Deduplicated and shuffling lines in memory.")
    with read_lines(mono_datasets) as mono_dataset_lines:
        final_lines = shuffle_with_max_lines(
            line_stream=deduplicate_lines(
                mono_dataset_lines,
            ),
            seed=347489345,
            max_lines=max_lines,
            total_byte_size=byte_size_estimate,
        )

    log_memory(gc_collect=True)
    logger.info(f"Write the final file: {output_path}")
    with write_lines(output_path) as outfile:
        stats.final_truncated_monolingual_lines.value = len(final_lines)
        for i, line in enumerate(final_lines):
            stats.final_truncated_monolingual_codepoints.value += len(line)
            outfile.write(line)
            if i % 1_000_000 == 999_999:
                logger.info(f"Wrote line {i+1:,} to {output_path}")

    log_memory(gc_collect=True)
    sample_path = output_path.parent / f"{output_path.stem}.sample.txt"
    # Report the requested sample size rather than a hard-coded "10,000", so the message
    # stays accurate when --sample_size is overridden.
    logger.info(f"Write a {sample_size:,} line sample of the final: {sample_path}")
    with write_lines(
        sample_path,
        # The browser won't know the encoding when viewing this sample without including
        # a "byte order mark", which python can do via this encoding.
        encoding="utf-8-sig",
    ) as outfile:
        # The output file has been closed above, so its on-disk size can be read here.
        for line in shuffle_with_max_lines(
            line_stream=final_lines,
            seed=9834523434,
            max_lines=sample_size,
            total_byte_size=os.path.getsize(output_path),
        ):
            outfile.write(line)

    log_memory(gc_collect=True)
    stats_path = stats.save_json()
    logger.info(f"Saved the stats: {stats_path}")


def compute_line_hashes(path: Path) -> WeakStringSet:
    """
    Read every line of the file at `path` and add it to a new WeakStringSet, which is
    returned so that other data can be de-duplicated against it.

    NOTE(review): the original author's note says this avoids keeping the full sentences
    in memory and costs about 10 bytes per stored entry; confirm against the
    WeakStringSet implementation.
    """
    line_hashes = WeakStringSet()
    sentences_visited = 0

    with read_lines(path) as lines:
        for line in lines:
            sentences_visited += 1
            if sentences_visited % 1_000_000 == 0:
                logger.info(f"Hashing sentence {sentences_visited:,}")
            line_hashes.add(line)

    return line_hashes


def main() -> None:
    """
    Command line entry point: parse the arguments, log the input sizes, build the set
    of parallel corpus lines, then filter and write the merged monolingual data.

    Raises:
        FileNotFoundError: When no file matches the --datasets_glob pattern.
    """
    parser = argparse.ArgumentParser(description="Merge monolingual datasets.")
    # The three path arguments are required: every code path below dereferences them,
    # so omitting one previously failed later with an opaque error on None.
    parser.add_argument(
        "--parallel_corpus",
        type=Path,
        required=True,
        help="The path to the parallel corpus of this language, e.g. $MOZ_FETCHES_DIR/corpus.ca.zst",
    )
    parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="The path to the output compressed file, e.g. /builds/worker/artifacts/mono.ca.zst",
    )
    parser.add_argument(
        "--max_sentences", type=int, help="The maximum number of sentences that will be merged."
    )
    parser.add_argument(
        "--datasets_glob",
        type=str,
        required=True,
        help="A glob-style path to the mono datasets, e.g. /path/to/*.zst",
    )
    parser.add_argument(
        "--sample_size", type=int, default=10_000, help="Generate a random sample of sentences."
    )

    args = parser.parse_args()

    output_path: Path = args.output
    max_sentences: int = args.max_sentences
    parallel_corpus: Path = args.parallel_corpus
    mono_dataset_paths: list[str] = glob.glob(args.datasets_glob)

    if not mono_dataset_paths:
        raise FileNotFoundError(f"No files found matching glob pattern: {args.datasets_glob}")

    logger.info("Monolingual datasets:")
    total_mono_bytes = 0
    for path in mono_dataset_paths:
        formatted_size, size_in_bytes = get_human_readable_file_size(path)
        logger.info(f" - {path} ({formatted_size})")
        total_mono_bytes += size_in_bytes

    logger.info(f" - {format_bytes(total_mono_bytes)} total")

    # Measure the parallel corpus itself, not the last monolingual dataset left over in
    # the loop variable above.
    formatted_size = get_human_readable_file_size(parallel_corpus)[0]
    logger.info("Parallel corpus:")
    logger.info(f" - {parallel_corpus} ({formatted_size})")

    # Ensure output directory exists
    output_dir = output_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Compute the line hashes so that the monolingual data can be de-duplicated.
    # It's about 10 bytes per hash in a set, so for a 100 million sentence corpus,
    # it would be ~1G in memory.
    log_memory()
    logger.info(f"Compute hashes of the parallel data: {parallel_corpus}")
    line_hashes = compute_line_hashes(parallel_corpus)

    stats = FilteringStatistics(output_path)

    filter_and_write_monolingual_data(
        mono_datasets=mono_dataset_paths,
        output_path=output_path,
        parallel_hashes=line_hashes,
        max_lines=max_sentences,
        sample_size=args.sample_size,
        stats=stats,
    )

    logger.info("Done: Merging monolingual datasets")


if __name__ == "__main__":
    main()
    log_memory()