pipeline/clean/merge-parallel.py

""" Merges multiple parallel corpora into a single "source" language file, and a single "target" language file, each. For instance: dataset1.en.zst dataset1.ru.zst dataset2.en.zst dataset2.ru.zst dataset3.en.zst dataset3.ru.zst Gets merged into: corpus.en.zst corpus.ru.zst """ import argparse from contextlib import ExitStack from glob import glob from pathlib import Path from typing import Generator, Optional from pipeline.common.datasets import ( FilteringStep, Statistics, WeakStringSet, shuffle_with_max_lines, ) from pipeline.common.downloads import get_human_readable_file_size, read_lines, write_lines from pipeline.common.logging import get_logger logger = get_logger(__file__) class FilteringStatistics(Statistics): """ Gather statistics about the filtering process. """ def __init__(self, dataset_path: Path) -> None: super().__init__(dataset_path) self.parallel_corpus = FilteringStep( "The parallel corpora are merged and deduplicated", ) self.final_truncated = FilteringStep("The final result can be truncated by max_lines") self.datasets = [] def add_parallel_dataset(self, location: str): # e.g. /path/to/ada83_v1.en.zst path = Path(location) # e.g. ada83_v1 dataset_stem = Path(path.stem).stem step = FilteringStep(dataset_stem) self.datasets.append(step) return step def log_dataset(location: str): logger.info(f"Reading dataset {location}") class DeduplicateCorpus: def __init__( self, datasets_src: list[Path], datasets_trg: list[Path], src_outpath: Path, trg_outpath: Path, stats: FilteringStatistics, ) -> None: self.datasets_src: list[Path] = datasets_src self.datasets_trg: list[Path] = datasets_trg self.src_outpath: Path = src_outpath self.trg_outpath: Path = trg_outpath self.stats: FilteringStatistics = stats self.dataset_stats: FilteringStep = None def run( self, total_corpus_bytes: int, max_lines: Optional[int], ): stats = self.stats with ExitStack() as stack: src_outfile = stack.enter_context(write_lines(self.src_outpath)) trg_outfile = stack.enter_context(write_lines(self.trg_outpath)) if max_lines: for line in shuffle_with_max_lines( line_stream=self.yield_lines_string(stack), seed=38540735095, max_lines=max_lines, total_byte_size=total_corpus_bytes, ): src_line, trg_line = line.split("\t") src_outfile.write(src_line) trg_outfile.write(trg_line) stats.final_truncated.visited = stats.parallel_corpus.kept stats.final_truncated.kept = min(max_lines, stats.parallel_corpus.kept) else: for src_line, trg_line in self.yield_lines_tuple(stack): src_outfile.write(src_line) trg_outfile.write(trg_line) stats.final_truncated.kept = stats.parallel_corpus.kept stats.final_truncated.visited = stats.parallel_corpus.kept def yield_lines_tuple(self, stack: ExitStack) -> Generator[tuple[str, str], None, None]: strings_seen = WeakStringSet() stats = self.stats src_lines: Generator[str, None, None] = stack.enter_context( read_lines(self.datasets_src, on_enter_location=self.on_enter_location) ) trg_lines: Generator[str, None, None] = stack.enter_context( read_lines(self.datasets_trg, on_enter_location=log_dataset) ) for src_line, trg_line in zip(src_lines, trg_lines): # No separator is needed as the newline is included. 
            line = src_line + trg_line

            if line in strings_seen:
                stats.parallel_corpus.filtered += 1
                self.dataset_stats.filtered += 1
            else:
                stats.parallel_corpus.kept += 1
                self.dataset_stats.kept += 1
                strings_seen.add(line)
                yield src_line, trg_line

    def yield_lines_string(self, stack: ExitStack) -> Generator[str, None, None]:
        for src_line, trg_line in self.yield_lines_tuple(stack):
            if "\t" in src_line or "\t" in trg_line:
                logger.error("A line contained a tab character, skipping:")
                logger.error(f" src: {src_line}")
                logger.error(f" trg: {trg_line}")
            else:
                yield f"{src_line}\t{trg_line}"

    def on_enter_location(self, location):
        log_dataset(location)
        self.dataset_stats = self.stats.add_parallel_dataset(location)


def sample_corpus(
    artifacts: Path, name: str, sample_size: int, src_outpath: Path, trg_outpath: Path
):
    """
    Generate a sample of the corpus data with the following format:

    e.g.
    > cat artifacts/corpus.sample.txt

    Sentence 1 in source language
    Sentence 1 in target language

    Sentence 2 in source language
    Sentence 2 in target language

    Sentence 3 in source language
    Sentence 3 in target language
    ...
    """
    total_byte_size = src_outpath.stat().st_size + trg_outpath.stat().st_size

    with ExitStack() as stack:
        sample_path = artifacts / f"{name}.sample.txt"
        src_lines = stack.enter_context(read_lines(src_outpath))
        trg_lines = stack.enter_context(read_lines(trg_outpath))
        sample_outfile = stack.enter_context(
            write_lines(
                sample_path,
                # The browser won't know the encoding when viewing this sample without including
                # a "byte order mark", which python can do via this encoding.
                encoding="utf-8-sig",
            )
        )

        def join_src_trg():
            for src_line, trg_line in zip(src_lines, trg_lines):
                # The src and trg lines each have a newline at the end. This means that
                # each sentence pair will be separated by a blank line to make for easy
                # scanning of datasets.
                yield f"{src_line}{trg_line}\n"

        logger.info("Stream in:")
        logger.info(f" - {src_outpath}")
        logger.info(f" - {trg_outpath}")
        logger.info(f"Write a {sample_size:,} line sample of the merged corpus:")
        logger.info(f" - {sample_path}")

        for line in shuffle_with_max_lines(
            line_stream=join_src_trg(),
            seed=9834523434,
            max_lines=sample_size,
            total_byte_size=total_byte_size,
        ):
            sample_outfile.write(line)


def get_datasets(src: str, trg: str, datasets_glob: str):
    dataset_paths: list[str] = glob(datasets_glob)
    datasets_src: list[Path] = []
    datasets_trg: list[Path] = []
    dataset_paths.sort()
    total_corpus_bytes = 0

    for dataset in dataset_paths:
        path = Path(dataset)
        if dataset.endswith(f"{src}.zst"):
            datasets_src.append(path)
        elif dataset.endswith(f"{trg}.zst"):
            datasets_trg.append(path)
        else:
            raise Exception(f"Dataset does not match naming scheme: {dataset}")

        formatted_size, bytes = get_human_readable_file_size(path)
        logger.info(f" - {path} ({formatted_size})")
        total_corpus_bytes += bytes

    return datasets_src, datasets_trg, total_corpus_bytes


def main() -> None:
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserves whitespace in the help text.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--src",
        type=str,
        help="The source locale",
    )
    parser.add_argument(
        "--trg",
        type=str,
        help="The target locale",
    )
    parser.add_argument(
        "--datasets_glob",
        type=str,
        help="A glob-style path to the parallel datasets, e.g. /path/to/*.zst",
    )
    parser.add_argument(
        "--max_lines",
        type=str,
        default="None",
        help="The (optional) maximum number of sentences that will be merged.",
    )
    parser.add_argument(
        "--sample_size",
        type=int,
        default=10_000,
        help="The number of sentences to randomly sample from the merged corpus.",
    )
    parser.add_argument(
        "--artifacts",
        type=Path,
        help="The path to the artifacts directory.",
    )
    parser.add_argument(
        "--name",
        type=str,
        help='The final corpus name, e.g. "corpus" will output a "corpus.en.zst" file.',
    )

    args = parser.parse_args()

    datasets_src, datasets_trg, total_corpus_bytes = get_datasets(
        args.src, args.trg, args.datasets_glob
    )
    logger.info("Parallel datasets:")

    src_outpath = args.artifacts / f"{args.name}.{args.src}.zst"
    trg_outpath = args.artifacts / f"{args.name}.{args.trg}.zst"

    stats = FilteringStatistics(args.artifacts / args.name)

    max_lines: Optional[int] = None
    if args.max_lines != "None":
        max_lines = int(args.max_lines)

    deduplicate_corpus = DeduplicateCorpus(
        datasets_src,
        datasets_trg,
        src_outpath,
        trg_outpath,
        stats,
    )
    deduplicate_corpus.run(total_corpus_bytes, max_lines)

    sample_corpus(
        artifacts=args.artifacts,
        name=args.name,
        sample_size=args.sample_size,
        src_outpath=src_outpath,
        trg_outpath=trg_outpath,
    )

    stats.save_json()


if __name__ == "__main__":
    main()
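
# Example invocation (illustrative only; the pipeline normally drives this script
# through its task configuration, so the paths and values below are assumptions
# rather than taken from this file):
#
#   python3 pipeline/clean/merge-parallel.py \
#     --src en \
#     --trg ru \
#     --datasets_glob "fetches/*.zst" \
#     --max_lines None \
#     --sample_size 10000 \
#     --artifacts artifacts \
#     --name corpus
#
# With these arguments the script would write artifacts/corpus.en.zst and
# artifacts/corpus.ru.zst (merged and deduplicated), artifacts/corpus.sample.txt
# (a human-readable sample), and the statistics JSON produced by stats.save_json().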