pipeline/data/cjk.py (148 lines of code) (raw):

"""
Chinese, Japanese, Korean (CJK) specific data importing code
"""

from enum import Flag
from pathlib import Path
from typing import Optional

import hanzidentifier
import opencc

from pipeline.common.datasets import Statistics
from pipeline.common.downloads import read_lines, write_lines
from pipeline.common.logging import get_logger

logger = get_logger(__file__)

CJK_LANGS = ["zh", "ja", "ko"]


class ChineseType(Flag):
    """
    Script variant detected in (or requested for) a piece of Chinese text.

    This is a Flag so that text belonging to both variants can be expressed as
    `ChineseType.traditional | ChineseType.simplified`.
    """

    none = 0
    simplified = 1
    traditional = 2


class ConversionStep(Statistics):
    """
    When converting data, count how many sentences were converted, how many were
    filtered out, and how many were visited.
    """

    def __init__(
        self, description: str, converted=0, filtered=0, dataset_path: Optional[Path] = None
    ) -> None:
        super().__init__(dataset_path)
        self.description = description
        self.converted = converted
        self.filtered = filtered
        # Incremented by the caller once per line that is read.
        self.visited = 0


class DatasetStatistics(Statistics):
    """
    Statistics for a dataset that went through Chinese script conversion or
    filtering. The counters live in `script_conversion`.
    """

    def __init__(self, dataset_path: Path, script: ChineseType) -> None:
        super().__init__(dataset_path)
        self.script = script
        self.script_conversion = ConversionStep(
            f"How many sentences in the dataset were converted to {script.name} or filtered",
        )


class ChineseConverter:
    """
    Detects the script variant of Chinese text line by line and either converts
    the lines to a single variant or filters out the lines of other variants.
    """

    def __init__(self):
        # One OpenCC converter per direction: simplified -> traditional and back.
        self.s2t = opencc.OpenCC("s2t.json")
        self.t2s = opencc.OpenCC("t2s.json")

    def convert_file(
        self, input_path: Path, output_path: Path, to: ChineseType
    ) -> DatasetStatistics:
        """
        Convert all lines to one variant of Chinese

        Lines detected as `ChineseType.none` or as exactly the target variant are
        written unchanged. Every other line, including lines detected as the
        combined `traditional | simplified` flag, is converted and counted in
        `script_conversion.converted`.

        Raises:
            ValueError: if `to` is neither simplified nor traditional and a line
                needs converting.
        """
        stats = DatasetStatistics(output_path, to)
        with write_lines(output_path) as out_file, read_lines(input_path) as lines:
            for line in lines:
                stats.script_conversion.visited += 1
                ch_type = self._detect(line)
                # Refer to the member through the class: reading it through the
                # detected value (`ch_type.none`) is deprecated member-from-member
                # access and obscures that this is a constant.
                if ch_type in (ChineseType.none, to):
                    new_line = line
                else:
                    new_line = self._convert_line(line, to)
                    stats.script_conversion.converted += 1
                out_file.write(new_line)
        return stats

    def filter_file(
        self, input_path: Path, output_path: Path, variant: ChineseType
    ) -> DatasetStatistics:
        """
        Filter everything except the specified variant of Chinese

        Only lines detected as exactly `variant` are written to `output_path`;
        all others are counted in `script_conversion.filtered`.
        """
        stats = DatasetStatistics(output_path, variant)
        with write_lines(output_path) as out_file, read_lines(input_path) as lines:
            for line in lines:
                stats.script_conversion.visited += 1
                ch_type = self._detect(line)
                if ch_type == variant:
                    out_file.write(line)
                else:
                    stats.script_conversion.filtered += 1
        return stats

    def filter_parallel_corpus(
        self,
        zh_path: Path,
        other_path: Path,
        zh_output_path: Path,
        other_output_path: Path,
        variant: ChineseType,
    ) -> DatasetStatistics:
        """
        Filter everything except the specified variant of Chinese in a parallel corpus

        The two inputs are read in lockstep. A pair is written to the outputs only
        when the Chinese line is detected as exactly `variant`; otherwise the pair
        is dropped and counted in `script_conversion.filtered`.

        Raises:
            ValueError: if the two input files do not have the same number of
                lines. The pairs read before the mismatch was noticed have
                already been processed at that point.
        """
        stats = DatasetStatistics(zh_output_path, variant)
        with (
            write_lines(zh_output_path) as zh_out_file,
            write_lines(other_output_path) as other_out_file,
            read_lines(zh_path) as zh_lines,
            read_lines(other_path) as other_lines,
        ):
            # strict=True: a plain zip() would silently stop at the shorter file
            # and drop the tail of the longer one, hiding a broken corpus.
            for zh_line, other_line in zip(zh_lines, other_lines, strict=True):
                stats.script_conversion.visited += 1
                ch_type = self._detect(zh_line)
                if ch_type == variant:
                    zh_out_file.write(zh_line)
                    other_out_file.write(other_line)
                else:
                    stats.script_conversion.filtered += 1
        return stats

    @staticmethod
    def _detect(text: str) -> ChineseType:
        """
        Map the result of `hanzidentifier.identify` to a `ChineseType`.

        SIMPLIFIED and TRADITIONAL map to the matching single flag, BOTH and
        MIXED map to `traditional | simplified`, and anything else maps to
        `ChineseType.none`.
        """
        res = hanzidentifier.identify(text)
        if res == hanzidentifier.SIMPLIFIED:
            return ChineseType.simplified
        if res == hanzidentifier.TRADITIONAL:
            return ChineseType.traditional
        if res in (hanzidentifier.BOTH, hanzidentifier.MIXED):
            return ChineseType.traditional | ChineseType.simplified
        return ChineseType.none

    def _convert_line(self, text: str, to: ChineseType) -> str:
        """
        Convert `text` to the `to` variant with the matching OpenCC converter.

        Raises:
            ValueError: if `to` is neither simplified nor traditional.
        """
        if to == ChineseType.simplified:
            return self.t2s.convert(text)
        elif to == ChineseType.traditional:
            return self.s2t.convert(text)
        raise ValueError(f"Unsupported type: {to}")


def handle_chinese_mono(file_destination: Path, is_src: bool, variant: ChineseType):
    """
    Normalize a monolingual Chinese file in place.

    When `is_src` is true the lines are converted to `variant`, otherwise the
    lines that are not `variant` are filtered out. The result is written to a
    sibling file whose last suffix is replaced with ".converted.zst", which then
    replaces `file_destination`. The counters are printed and the statistics are
    saved with `save_json()`.
    """
    converted_path = file_destination.with_suffix(".converted.zst")
    chinese_converter = ChineseConverter()
    if is_src:
        logger.info(f"Converting the output file to {variant}")
        stats = chinese_converter.convert_file(file_destination, converted_path, variant)
    else:
        logger.info(f"Filtering out everything except {variant} in the output file")
        stats = chinese_converter.filter_file(file_destination, converted_path, variant)
    converted_path.replace(file_destination)
    print(
        f"Converted {stats.script_conversion.converted}, "
        f"Filtered: {stats.script_conversion.filtered} "
        f"Visited: {stats.script_conversion.visited}"
    )
    stats.save_json()


def handle_chinese_parallel(output_prefix: str, src: str, trg: str, variant: ChineseType):
    """
    Normalize the Chinese side of a parallel corpus in place.

    The corpus files are expected at "{output_prefix}.{src}.zst" and
    "{output_prefix}.{trg}.zst". When Chinese is the source language, the source
    file is converted to `variant` and the target file is left untouched. When
    Chinese is the target language, both files are filtered together so that only
    the pairs whose Chinese line is `variant` remain. The counters are printed and
    the statistics are saved with `save_json()`.

    Raises:
        ValueError: if neither `src` nor `trg` is "zh", or if the two files of the
            corpus have different numbers of lines when filtering.
    """
    if "zh" not in (src, trg):
        raise ValueError("Run only for Chinese")

    chinese_converter = ChineseConverter()
    is_src = src == "zh"
    if is_src:
        logger.info(f"Converting the output file to {variant}")
        input_path = Path(f"{output_prefix}.{src}.zst")
        converted_path = Path(f"{output_prefix}.converted.{src}.zst")
        stats = chinese_converter.convert_file(
            input_path=input_path,
            output_path=converted_path,
            to=variant,
        )
        converted_path.replace(input_path)
    else:
        logger.info(f"Filtering out everything except {variant} from a parallel corpus")
        trg_path = Path(f"{output_prefix}.{trg}.zst")
        src_path = Path(f"{output_prefix}.{src}.zst")
        trg_filtered_path = Path(f"{output_prefix}.filtered.{trg}.zst")
        src_filtered_path = Path(f"{output_prefix}.filtered.{src}.zst")
        stats = chinese_converter.filter_parallel_corpus(
            zh_path=trg_path,
            other_path=src_path,
            zh_output_path=trg_filtered_path,
            other_output_path=src_filtered_path,
            variant=variant,
        )
        # Only reached when filtering finished without raising, so the originals
        # are never overwritten by a partially written result.
        src_filtered_path.replace(src_path)
        trg_filtered_path.replace(trg_path)

    print(
        f"Converted {stats.script_conversion.converted}, "
        f"Filtered: {stats.script_conversion.filtered} "
        f"Visited: {stats.script_conversion.visited}"
    )
    stats.save_json()