"""Standard definitions for reusable script arguments."""

import fnmatch
import logging
import re
from argparse import Action
from functools import partial

from google.cloud import bigquery

from bigquery_etl.config import ConfigLoader
from bigquery_etl.util.common import TempDatasetReference
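
# Example usage (an illustrative sketch; the parser description and table name
# below are hypothetical):
#
#     from argparse import ArgumentParser
#
#     parser = ArgumentParser(description="example script")
#     add_log_level(parser)
#     add_dry_run(parser)
#     add_table_filter(parser)
#     args = parser.parse_args()
#     # args.table_filter is a predicate built from any --only/--except globs
#     if args.table_filter("telemetry_stable.main_v5"):
#         ...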


def add_argument(parser, *args, **kwargs):
    """Add default to help while adding argument to parser."""
    if "help" in kwargs:
        default = kwargs.get("default")
        if default not in (None, [], [None]):
            if kwargs.get("nargs") in ("*", "+"):
                # unnest a single default for printing, if possible
                try:
                    (default,) = default
                except ValueError:
                    pass
            kwargs["help"] += f"; Defaults to {default}"
    parser.add_argument(*args, **kwargs)


def add_billing_projects(parser, *extra_args, default=[None]):
    """Add argument for billing projects."""
    add_argument(
        parser,
        "-p",
        "--billing-projects",
        "--billing_projects",
        "--billing-project",
        "--billing_project",
        *extra_args,
        nargs="+",
        default=default,
        help="One or more billing projects over which bigquery jobs should be "
        "distributed",
    )


def add_dry_run(parser, debug_log_queries=True):
    """Add argument for dry run."""
    add_argument(
        parser,
        "--dry_run",
        "--dry-run",
        action="store_true",
        help="Do not make changes, only log actions that would be taken"
        + (
            "; Use with --log-level=DEBUG to log query contents"
            if debug_log_queries
            else ""
        ),
    )


def add_log_level(parser, default=logging.getLevelName(logging.INFO)):
    """Add argument for log level."""
    add_argument(
        parser,
        "-l",
        "--log-level",
        "--log_level",
        action=LogLevelAction,
        default=default,
        type=str.upper,
        help="Set logging level for the python root logger",
    )


def add_parallelism(parser, default=4):
    """Add argument for parallel execution."""
    add_argument(
        parser,
        "-P",
        "--parallelism",
        default=default,
        type=int,
        help="Maximum number of tasks to execute concurrently",
    )


def add_priority(parser):
    """Add argument for BigQuery job priority."""
    add_argument(
        parser,
        "--priority",
        default=bigquery.QueryPriority.INTERACTIVE,
        type=str.upper,
        choices=[bigquery.QueryPriority.BATCH, bigquery.QueryPriority.INTERACTIVE],
        help="Priority for BigQuery query jobs; BATCH priority may significantly slow "
        "down queries if reserved slots are not enabled for the billing project; "
        "INTERACTIVE priority is limited to 100 concurrent queries per project",
    )


def add_table_filter(parser, example="telemetry_stable.main_v*"):
    """Add arguments for filtering tables."""
    example_ = f"Pass names or globs like {example!r}"
    add_argument(
        parser,
        "-o",
        "--only",
        nargs="+",
        dest="table_filter",
        raw_dest="only_tables",
        action=TableFilterAction,
        help=f"Process only the given tables; {example_}",
    )
    add_argument(
        parser,
        "-x",
        "--except",
        nargs="+",
        dest="table_filter",
        raw_dest="except_tables",
        action=TableFilterAction,
        help=f"Process all tables except for the given tables; {example_}",
    )


def add_temp_dataset(parser, *extra_args):
    """Add argument for temporary dataset."""
    add_argument(
        parser,
        "--temp-dataset",
        "--temp_dataset",
        "--temporary-dataset",
        "--temporary_dataset",
        *extra_args,
        default=f"{ConfigLoader.get('default', 'project', fallback='moz-fx-data-shared-prod')}.tmp",
        type=TempDatasetReference.from_string,
        help="Dataset where intermediate query results will be temporarily stored, "
        "formatted as PROJECT_ID.DATASET_ID",
    )


class LogLevelAction(Action):
    """Custom argparse.Action for --log-level."""

    def __init__(self, *args, **kwargs):
        """Set default log level if provided."""
        super().__init__(*args, **kwargs)
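        # Apply the default eagerly so the root logger level is configured even
        # when --log-level is never passed on the command line.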
        if self.default is not None:
            logging.root.setLevel(self.default)

    def __call__(self, parser, namespace, value, option_string=None):
        """Set level for root logger."""
        logging.root.setLevel(value)


class TableFilterAction(Action):
    """Custom argparse.Action for --only and --except."""

    def __init__(self, *args, raw_dest, **kwargs):
        """Add default."""
        super().__init__(*args, default=self.default, **kwargs)
        self.raw_dest = raw_dest
        self.arg = self.option_strings[-1]
        self.invert_match = self.arg == "--except"

    @staticmethod
    def default(table):
        """Return True for default predicate."""
        return True

    @staticmethod
    def compile(values):
        """Compile a list of glob patterns into a single regex."""
        return re.compile("|".join(fnmatch.translate(pattern) for pattern in values))

    def predicate(self, table, pattern):
        """Log tables skipped due to table filter arguments."""
        matched = (pattern.match(table) is not None) != self.invert_match
        if not matched:
            logging.info(f"Skipping {table} due to {self.arg} argument")
        return matched

    def __call__(self, parser, namespace, values, option_string=None):
        """Add table filter to predicates."""
        setattr(namespace, self.raw_dest, values)
        predicates_attr = "_" + self.dest
        predicates = getattr(namespace, predicates_attr, [])
        if not hasattr(namespace, predicates_attr):
            setattr(namespace, predicates_attr, predicates)
            setattr(
                namespace,
                self.dest,
                lambda table: all(predicate(table) for predicate in predicates),
            )
        predicates.append(partial(self.predicate, pattern=self.compile(values)))
