pipeline/data/analyze.py
#!/usr/bin/env python3
"""
Get the statistical distribution of a dataset.
Usage:
python3 pipeline/data/analyze.py \
--file_location data.en.zst
--output ./artifacts
--dataset "opus_NLLB/v1"
--language en
For parallel corpora, add the arguments twice, separated by a `--`.
"""

import argparse
import gzip
import os
import sys
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import zstandard
from matplotlib import ticker

# Ensure the pipeline is available on the path.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))

from pipeline.common.datasets import Dataset
from pipeline.common.downloads import (
    RemoteGzipLineStreamer,
    RemoteZstdLineStreamer,
)
from pipeline.common.logging import get_logger

logger = get_logger(__file__)


def get_line_streamer(file_location: str):
    """Streams in lines from remote locations, or from disk. Accepts zst, gz, and plain text."""
    if file_location.startswith("http://") or file_location.startswith("https://"):
        if file_location.endswith(".zst"):
            return RemoteZstdLineStreamer(file_location)
        # Assume gzip.
        return RemoteGzipLineStreamer(file_location)

    if file_location.endswith(".gz"):
        return gzip.open(file_location, "rt")
    if file_location.endswith(".zst"):
        return zstandard.open(file_location, "rt")
    return open(file_location, "rt")


def main(args: Optional[list[str]] = None) -> None:
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,  # Preserves whitespace in the help text.
    )
    parser.add_argument(
        "--file_location", type=str, required=True, help="A URL or file path to analyze."
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="The directory for the output."
    )
    parser.add_argument("--dataset", type=str, required=True, help="The name of the dataset")
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="The dataset language, as a BCP-47 language tag",
    )

    # Allow the use of "--" to add more arguments.
    parser.add_argument("next_dataset_args", nargs=argparse.REMAINDER)
    parsed_args = parser.parse_args(args)

    # Defer parsing any options after "--", and recurse below if there are some.
    next_dataset_args: Optional[list[str]] = None
    if len(parsed_args.next_dataset_args):
        if parsed_args.next_dataset_args[0] != "--":
            print(parsed_args.next_dataset_args)
            raise Exception("Unexpected arguments. Use -- to pass in multiple datasets.")
        next_dataset_args = parsed_args.next_dataset_args[1:]
logger.info(f"file_location: {parsed_args.file_location}")
logger.info(f"output_dir: {parsed_args.output_dir}")
logger.info(f"dataset: {parsed_args.dataset}")
logger.info(f"language: {parsed_args.language}")
dataset = Dataset(parsed_args.dataset)
graph_prefix = f"{dataset.file_safe_name()}.{parsed_args.language}"
# Compute the distributions for both the codepoints, and word size.
codepoints_distribution = Histogram()
word_distribution = Histogram()
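    # Stream the corpus one line at a time so that arbitrarily large datasets can be
    # analyzed without holding them in memory.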
    with get_line_streamer(parsed_args.file_location) as lines:
        for line in lines:
            codepoints_distribution.count(len(line))
            word_distribution.count(len(line.split()))

    plot_logarithmic_histogram(
        word_distribution,
        max_size=5_000,  # words
        title="\n".join(
            [
                "Word Count Distribution",
                f"{parsed_args.dataset} - {parsed_args.language}",
            ]
        ),
        x_axis_label="Words (log scale)",
        filename=os.path.join(parsed_args.output_dir, f"{graph_prefix}.distribution-words.png"),
    )

    plot_logarithmic_histogram(
        codepoints_distribution,
        max_size=10_000,  # codepoints
        title="\n".join(
            [
                "Codepoints per Sentence Distribution",
                f"{parsed_args.dataset} - {parsed_args.language}",
            ]
        ),
        x_axis_label="Codepoints (log scale)",
        filename=os.path.join(
            parsed_args.output_dir, f"{graph_prefix}.distribution-codepoints.png"
        ),
    )

    if next_dataset_args:
        # Apply the arguments again after "--".
        main(next_dataset_args)


class Histogram:
    """Computes a histogram based on counts."""

    def __init__(self) -> None:
        # The keys are the bins, the values are the counts.
        self.data: dict[int, int] = {}

    def count(self, count: int):
        if count not in self.data:
            self.data[count] = 0
        self.data[count] += 1

    def log_scale_bins(self, max_size: int, bin_count: int = 30) -> list[float]:
        """Converts the linear bins of the histogram into logscale bins."""
        # Start with a few small value bins, since it's easy to start with some small fractional
        # values on a log scale.
        bins = [1.0, 2.0]
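        # np.logspace(0, log10(max_size), bin_count) yields bin_count edges evenly spaced in
        # log space from 1 to max_size; edges at or below 2.0 are skipped since they are
        # already covered by the seed bins above.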
        for edge in np.logspace(np.log10(1), np.log10(max_size), bin_count):
            if edge > 2.0:
                bins.append(edge)
        return bins


def plot_logarithmic_histogram(
    histogram: Histogram, max_size: int, title: str, x_axis_label: str, filename: str
):
    """
    Converts a histogram of values into a logscale graph, where the x axis is logarithmic,
    and the y scale is linear. The x axis represents the bins of the histogram.
    """
    bins = np.array(histogram.log_scale_bins(max_size))

    # Plot a histogram with logarithmic bins.
    plt.title(title)
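    # The data is already aggregated: each distinct value is passed once as a sample and
    # weighted by its count, so the raw lines never need to be materialized for plotting.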
    plt.hist(
        list(histogram.data.keys()),
        bins=bins,
        weights=list(histogram.data.values()),
        alpha=0.7,
    )
    plt.xlabel(x_axis_label)
    plt.xscale("log")
    plt.xticks(ticks=bins, labels=[f"{int(edge)}" for edge in bins], rotation="vertical")
    plt.ylabel("Frequency (linear)")
    plt.yscale("linear")
    plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))

    # Ensure no labels are cut off.
    plt.tight_layout()

    logger.info(f"Saving plot to: {filename}")
    plt.savefig(filename, dpi=150)
    plt.close()
if __name__ == "__main__":
main()