#!/usr/bin/env python3

"""
Get the statistical distribution of a dataset.

Usage:

    python3 pipeline/data/analyze.py \
        --file_location data.en.zst \
        --output_dir ./artifacts \
        --dataset "opus_NLLB/v1" \
        --language en

For parallel corpora, add the arguments twice, separated by a `--`.
"""

import argparse
import gzip
import os
import sys
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import zstandard
from matplotlib import ticker

# Ensure the pipeline is available on the path: append the directory two levels above this
# file's directory to sys.path so that the `pipeline.*` imports below can be resolved,
# e.g. when this file is run directly.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))

from pipeline.common.datasets import Dataset
from pipeline.common.downloads import (
    RemoteGzipLineStreamer,
    RemoteZstdLineStreamer,
)
from pipeline.common.logging import get_logger

# Module-level logger, created by the pipeline's get_logger helper from this file's path.
logger = get_logger(__file__)


def get_line_streamer(file_location: str):
    """
    Streams in lines from remote locations, or from disk. Accepts zst, gz, and plain text.

    Args:
        file_location: An "http://" or "https://" URL, or a path on disk. The compression
            is chosen from the file extension.

    Returns:
        An object that yields lines of text. Local files are opened in text mode and
        decoded as UTF-8. Remote locations are handed to the pipeline's remote streamers;
        a remote location that does not end in ".zst" is assumed to be gzip, so plain
        text is only supported for local files.
    """
    if file_location.startswith(("http://", "https://")):
        if file_location.endswith(".zst"):
            return RemoteZstdLineStreamer(file_location)
        # Assume gzip.
        return RemoteGzipLineStreamer(file_location)

    # Always decode as UTF-8 rather than relying on the platform's locale encoding, so
    # that the codepoint counts are the same on every machine.
    if file_location.endswith(".gz"):
        return gzip.open(file_location, "rt", encoding="utf-8")
    if file_location.endswith(".zst"):
        return zstandard.open(file_location, "rt", encoding="utf-8")
    return open(file_location, "rt", encoding="utf-8")


def main(args: Optional[list[str]] = None) -> None:
    """
    Parse the arguments, compute the word count and codepoint distributions of a dataset,
    and save both as histogram images in the output directory.

    Args:
        args: The command line arguments to parse. When None, argparse reads them from
            sys.argv. Arguments that follow a "--" describe a further dataset, which is
            processed by calling main() again with those arguments.

    Raises:
        Exception: When there are trailing arguments that are not introduced by "--".
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,  # Preserves whitespace in the help text.
    )
    parser.add_argument(
        "--file_location", type=str, required=True, help="A url or file path for analyzing."
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="The directory for the output."
    )
    parser.add_argument("--dataset", type=str, required=True, help="The name of the dataset")
    parser.add_argument(
        "--language",
        type=str,
        required=True,
        help="The dataset language, as a BCP-47 language tag",
    )
    # Allow the use of "--" to add more arguments.
    parser.add_argument("next_dataset_args", nargs=argparse.REMAINDER)

    parsed_args = parser.parse_args(args)

    # Defer parsing any options after "--", and recurse below if there are some.
    next_dataset_args: Optional[list[str]] = None
    if parsed_args.next_dataset_args:
        if parsed_args.next_dataset_args[0] != "--":
            # Report the arguments that were actually received to make the error actionable.
            logger.error(f"Unexpected arguments: {parsed_args.next_dataset_args}")
            raise Exception("Unexpected arguments. Use -- to pass in multiple datasets.")
        next_dataset_args = parsed_args.next_dataset_args[1:]

    logger.info(f"file_location: {parsed_args.file_location}")
    logger.info(f"output_dir: {parsed_args.output_dir}")
    logger.info(f"dataset: {parsed_args.dataset}")
    logger.info(f"language: {parsed_args.language}")

    # Create the output directory before streaming the data, so that a missing directory
    # does not cause a failure only after the whole dataset has been processed.
    os.makedirs(parsed_args.output_dir, exist_ok=True)

    dataset = Dataset(parsed_args.dataset)
    graph_prefix = f"{dataset.file_safe_name()}.{parsed_args.language}"

    # Compute the distributions for both the codepoints, and word size.
    codepoints_distribution = Histogram()
    word_distribution = Histogram()
    with get_line_streamer(parsed_args.file_location) as lines:
        for line in lines:
            # Lines read in text mode keep their line terminator, which is not part of
            # the sentence and must not be counted as a codepoint.
            sentence = line.rstrip("\r\n")
            codepoints_distribution.count(len(sentence))
            word_distribution.count(len(sentence.split()))

    # The second line of each graph title identifies the dataset and language.
    subtitle = f"{parsed_args.dataset} - {parsed_args.language}"

    plot_logarithmic_histogram(
        word_distribution,
        max_size=5_000,  # words
        title="\n".join(["Word Count Distribution", subtitle]),
        x_axis_label="Words (log scale)",
        filename=os.path.join(parsed_args.output_dir, f"{graph_prefix}.distribution-words.png"),
    )

    plot_logarithmic_histogram(
        codepoints_distribution,
        max_size=10_000,  # codepoints
        title="\n".join(["Codepoints per Sentence Distribution", subtitle]),
        x_axis_label="Codepoints (log scale)",
        filename=os.path.join(
            parsed_args.output_dir, f"{graph_prefix}.distribution-codepoints.png"
        ),
    )

    if next_dataset_args:
        # Apply the arguments again after "--".
        main(next_dataset_args)


class Histogram:
    """Computes a histogram based on counts."""

    def __init__(self) -> None:
        # The keys are the bins, the values are the counts.
        self.data: dict[int, int] = {}

    def count(self, count: int) -> None:
        """Record one more occurrence of the value `count`."""
        self.data[count] = self.data.get(count, 0) + 1

    def log_scale_bins(self, max_size: int, bin_count: int = 30) -> list[float]:
        """
        Builds the bin edges for plotting the histogram on a logarithmic scale.

        Args:
            max_size: The upper edge of the last bin.
            bin_count: The number of logarithmically spaced edges to generate between 1
                and max_size, before the edges that are not above 2 are discarded.

        Returns:
            The bin edges as floats in ascending order, starting with 1.0 and 2.0.
        """
        # Start with a few small value bins, since it's easy to start with some small fractional
        # values on a log scale.
        bins = [1.0, 2.0]
        # Convert the numpy scalars to plain floats so the list has a single element type.
        bins.extend(
            float(edge)
            for edge in np.logspace(np.log10(1), np.log10(max_size), bin_count)
            if edge > 2.0
        )
        return bins


def plot_logarithmic_histogram(
    histogram: Histogram, max_size: int, title: str, x_axis_label: str, filename: str
) -> None:
    """
    Converts a histogram of values into a logscale graph, where the x axis is logarithmic,
    and the y scale is linear. The x axis represents the bins of the histogram.

    Args:
        histogram: The counted values to plot.
        max_size: The upper edge of the last bin. The bin edges start at 1, so values
            outside of the range from 1 to max_size are not included in the graph.
        title: The title of the graph.
        x_axis_label: The label for the x axis.
        filename: The path where the image is saved.
    """

    bins = np.array(histogram.log_scale_bins(max_size))

    try:
        # Plot a histogram with logarithmic bins. The dict views are copied into lists,
        # as pyplot documents array-like inputs for the values and their weights.
        plt.title(title)
        plt.hist(
            list(histogram.data.keys()),
            bins=bins,
            weights=list(histogram.data.values()),
            alpha=0.7,
        )

        plt.xlabel(x_axis_label)
        plt.xscale("log")
        plt.xticks(ticks=bins, labels=[f"{int(edge)}" for edge in bins], rotation="vertical")

        plt.ylabel("Frequency (linear)")
        plt.yscale("linear")
        plt.gca().yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:,.0f}"))

        # Ensure no labels are cut off.
        plt.tight_layout()

        logger.info(f"Saving plot to: {filename}")
        plt.savefig(filename, dpi=150)
    finally:
        # Always close the current figure, even when plotting or saving fails, so that a
        # half-drawn figure is not reused by the next plot.
        plt.close()


# Run the analysis only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
