pipeline/data/hplt.py
import json
import random
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path
import icu
from pipeline.common.datasets import (
CountingStep,
FilteringStep,
Statistics,
WeakStringSet,
)
from pipeline.common.downloads import location_exists, read_lines, write_lines
from pipeline.common.logging import get_logger
from pipeline.common.memory import log_memory
logger = get_logger(__name__)
random.seed(38947598475)
@dataclass
class HPLTDocument:
"""
A structured type for the HPLT document entry in a jsonl file.
https://hplt-project.org/datasets/v2.0
"""
    # The list of detected document languages, where the first language is the most probable.
    # For example: [zho_Hans, zho_Hant, eng_Latn]
    lang: list[str]
    # The list of document scores from web-docs-scorer, where the first score is the overall
    # document score (WDS_score) followed by 8 subscores. All scores range from 0 to 10.
    # See https://github.com/pablop16n/web-docs-scorer/
    # For example: [8.3, 10, 10, 9.9, 10, 10, 10, 4, 0]
    doc_scores: list[float]
    # The detected language for each line (segment).
    # For example: [yue_Hant, zho_Hans, zho_Hans, zho_Hant, unk, ... ]
    seg_langs: list[str]
    # All of the text, split by newlines.
    lines: list[str]

    def __init__(self, **json):
        self.lang = json["lang"]
        self.doc_scores = json["doc_scores"]
        self.seg_langs = json["seg_langs"]
        # The sentences in the text, which are separated by newlines.
        self.lines = json["text"].split("\n")
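
# A minimal sketch of how one jsonl entry maps onto HPLTDocument. The payload below is
# invented for illustration; real entries come from the HPLT 2.0 cleaned shards.
#
#     entry = json.loads(
#         '{"lang": ["rus_Cyrl"], "doc_scores": [8.3, 10, 10, 9.9, 10, 10, 10, 4, 0],'
#         ' "seg_langs": ["rus_Cyrl", "rus_Cyrl"], "text": "Line one\\nLine two"}'
#     )
#     doc = HPLTDocument(**entry)
#     doc.lines  # ["Line one", "Line two"]
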
class FilteringStatistics(Statistics):
"""
Gather statistics about the filtering process.
"""
def __init__(self, dataset_path: Path) -> None:
super().__init__(dataset_path)
self.shards = FilteringStep(
"How many shards were sampled from. Each shard contains a subset of the "
"total datasets available.",
)
self.visited_lines = FilteringStep(
"How many lines were visited and kept from the HPLT documents.",
)
self.document_count = CountingStep(
"How many documents were visited. This can help represent data diversity.",
)
self.duplicate_lines = CountingStep(
"Of the collected lines, this counts how many were duplicates and discarded.",
)
self.final_lines = CountingStep(
"How many lines were actually written.",
)
self.filtered_doc_locale = CountingStep(
"How many lines were filtered based on document locale.",
)
self.filtered_line_locale = CountingStep(
"How many lines were filtered based on line locales.",
)
self.filtered_doc_score = CountingStep(
"How many lines were filtered based on document scores.",
)
self.filtered_too_long = CountingStep(
"How many lines were filtered based on length.",
)
def count_shards_visited(self, *_args):
self.shards.filtered -= 1
self.shards.kept += 1
def get_hplt_locale(lang_iso6931: str) -> str:
"""
Converts language in ISO-693-1 format to the HPLT format.
For example, ru -> rus_Cyrl
"""
# icu return Kore by default which is a mix of Hang and Hani
if lang_iso6931 == "ko":
return "kor_Hang"
locale = icu.Locale(lang_iso6931)
# add default script
locale = icu.Locale.addLikelySubtags(locale)
hplt_locale = f"{locale.getISO3Language()}_{locale.getScript()}"
return hplt_locale
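
# Illustrative conversions (outputs assume ICU's likely-subtag data; verify against the
# installed ICU version):
#
#     >>> get_hplt_locale("ru")
#     'rus_Cyrl'
#     >>> get_hplt_locale("ko")
#     'kor_Hang'
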
def get_hplt_map_url(hplt_locale: str) -> str:
return f"https://data.hplt-project.org/two/cleaned/{hplt_locale}_map.txt"
def language_has_hplt_support(language: str) -> bool:
hplt_locale = get_hplt_locale(language)
hplt_map = get_hplt_map_url(hplt_locale)
return location_exists(hplt_map)
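
# Example check (requires network access; the result depends on which locales HPLT
# currently publishes, so treat the value as illustrative):
#
#     >>> language_has_hplt_support("ru")
#     True
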
def load_shuffled_shard_urls(hplt_locale: str) -> list[str]:
"""
Download the list of shards, e.g.
https://data.hplt-project.org/two/cleaned/rus_Cyrl/1.jsonl.zst
https://data.hplt-project.org/two/cleaned/rus_Cyrl/2.jsonl.zst
...
https://data.hplt-project.org/two/cleaned/rus_Cyrl/10.jsonl.zst
"""
url = get_hplt_map_url(hplt_locale)
logger.info(f"Downloading shard list: {url}")
with read_lines(url) as lines:
shard_urls = []
for line in lines:
shard_urls.append(line.strip())
random.Random(url).shuffle(shard_urls)
logger.info(f"Available shards for {hplt_locale}:")
    for shard_url in shard_urls:
        logger.info(f" - {shard_url}")
return shard_urls
class HpltDownloader:
"""
Downloads and filters the HPLT dataset.
https://hplt-project.org/datasets/v2.0
Parameters:
- language: The BCP 47 language code to filter the documents.
- hplt_min_doc_score: The minimum score a document must have to be included in the final dataset.
    - max_characters: The maximum number of characters to accumulate per merged line before
      writing, when merging is enabled. Lines longer than this are also filtered out.
- max_lines: The maximum number of lines to include in the final dataset.
- file_destination: The destination path where the final dataset will be written.
    - merge_lines: Whether to accumulate lines of the same document into one segment until
      max_characters is reached.
"""
def __init__(
self,
language: str,
hplt_min_doc_score: float,
max_characters: int,
max_lines: int,
file_destination: Path,
merge_lines: bool,
) -> None:
self.merge_lines = merge_lines
self.max_lines = max_lines
self.max_characters = max_characters
self.hplt_min_doc_score = hplt_min_doc_score
self.hplt_locale = get_hplt_locale(language)
self.accumulated_text = ""
self.cumulative_char_count = 0
self.visited_lines = 0
self.file_destination = file_destination
self.stats = FilteringStatistics(file_destination)
self.strings_seen = WeakStringSet()
self.stack = ExitStack()
self.outfile = self.stack.enter_context(write_lines(file_destination))
def close(self):
self.stack.close()
def download(self):
try:
self._run_download()
finally:
self.close()
def _run_download(self):
logger.info(f"Using HPLT locale {self.hplt_locale}")
shuffled_shard_urls = load_shuffled_shard_urls(self.hplt_locale)
self.stats.shards.filtered = len(shuffled_shard_urls)
        # The shard URLs are shuffled, and then streamed into the read_lines iterator.
        # This iterator can read across multiple shards: the first shard is loaded and its
        # documents are read in order, then the iterator continues with the next shards
        # until enough lines have been collected. At that point the remaining shards are
        # never visited.
document_stream = self.stack.enter_context(
read_lines(shuffled_shard_urls, on_enter_location=self.stats.count_shards_visited)
)
for document_json in document_stream:
self.stats.document_count.value += 1
document = HPLTDocument(**json.loads(document_json))
overall_doc_score = document.doc_scores[0]
doc_lang = document.lang[0]
self._maybe_write_accumulated_text()
# HPLT 2.0 uses document level scores
if overall_doc_score < self.hplt_min_doc_score:
self.stats.filtered_doc_score.value += 1
continue
# We want only documents written primarily in the target language
if doc_lang != self.hplt_locale:
self.stats.filtered_doc_locale.value += 1
continue
# Visit the lines in the document.
for line_locale, line in zip(document.seg_langs, document.lines):
self.visited_lines += 1
self._process_line(line_locale, line)
if self.visited_lines % 5_000_000 == 0:
logger.info(f"Visited {self.visited_lines:,} lines")
logger.info(f"Kept {self.stats.visited_lines.kept:,}.")
logger.info(
f"Wrote {self.stats.final_lines.value:,} out of {self.max_lines:,}."
)
log_memory()
if self.stats.final_lines.value == self.max_lines:
break
if self.stats.final_lines.value == self.max_lines:
break
self._maybe_write_accumulated_text()
self.stats.visited_lines.filtered = self.visited_lines - self.stats.visited_lines.kept
logger.info(f"Wrote {self.stats.final_lines.value:,} lines to: {self.file_destination}")
stat_path = self.stats.save_json()
logger.info(f"Saved filtering stats to: {stat_path}")
def _process_line(self, line_locale: str, line: str):
# Line locale does not match expected locale, filter
if line_locale != self.hplt_locale:
self.stats.filtered_line_locale.value += 1
self._maybe_write_accumulated_text()
return
char_count = len(line)
# Filter long segments
if char_count > self.max_characters:
self.stats.filtered_too_long.value += 1
self._maybe_write_accumulated_text()
return
# Just write the current line if merging is disabled
if not self.merge_lines:
self.accumulated_text = line
self.stats.visited_lines.kept += 1
self._maybe_write_accumulated_text()
return
# Text accumulation mode starts here
self.stats.visited_lines.kept += 1
# Determine if this sentence should be added to the previous one or
# written out as a new line.
        if self.cumulative_char_count + char_count + 1 > self.max_characters:
            # Adding this line would exceed max_characters, so flush the accumulated text first.
            self._maybe_write_accumulated_text()
self.cumulative_char_count += char_count
# Collect this line to write.
if self.accumulated_text:
self.accumulated_text = f"{self.accumulated_text} {line}"
# count the whitespace
self.cumulative_char_count += 1
else:
self.accumulated_text = line
def _maybe_write_accumulated_text(self):
"""
Since the loop below is building up paragraphs of text, we only want to write
out a line when enough text has been accumulated. The paragraph should be
written out when either the text gets too long, or the next line is discarded.
"""
self.cumulative_char_count = 0
if self.accumulated_text:
if self.accumulated_text in self.strings_seen:
self.stats.duplicate_lines.value += 1
else:
self.outfile.write(self.accumulated_text + "\n")
self.stats.final_lines.value += 1
self.strings_seen.add(self.accumulated_text)
self.accumulated_text = ""
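
# Illustrative end-to-end usage. The real entry point and its parameter values live
# elsewhere in the pipeline; the numbers below are placeholders, not recommendations.
#
#     downloader = HpltDownloader(
#         language="ru",
#         hplt_min_doc_score=7.0,
#         max_characters=600,
#         max_lines=100_000_000,
#         file_destination=Path("hplt-mono.ru.zst"),
#         merge_lines=True,
#     )
#     downloader.download()  # closes the output file via close() when finished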