#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from logging import INFO
from typing import List, Optional, Tuple, Type

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.exceptions.core import ObjectNotFoundError
from ax.exceptions.core import UnsupportedError
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import not_none


RETRY_EXCEPTION_TYPES: Tuple[Type[Exception], ...] = ()

try:  # We don't require SQLAlchemy by default.
    from ax.storage.sqa_store.db import init_engine_and_session_factory
    from ax.storage.sqa_store.decoder import Decoder
    from ax.storage.sqa_store.encoder import Encoder
    from ax.storage.sqa_store.load import (
        _get_experiment_id,
        _get_generation_strategy_id,
        _load_experiment,
        _load_generation_strategy_by_experiment_name,
    )
    from ax.storage.sqa_store.save import (
        _save_or_update_trials,
        _save_experiment,
        _save_generation_strategy,
        _update_generation_strategy,
        update_properties_on_experiment,
    )
    from ax.storage.sqa_store.sqa_config import SQAConfig
    from ax.storage.sqa_store.structs import DBSettings
    from sqlalchemy.exc import OperationalError
    from sqlalchemy.orm.exc import StaleDataError

    # We retry on `OperationalError` if saving to DB.
    RETRY_EXCEPTION_TYPES = (OperationalError, StaleDataError)
except ModuleNotFoundError:  # pragma: no cover
    DBSettings = None
    Decoder = None
    Encoder = None
    SQAConfig = None


STORAGE_MINI_BATCH_SIZE = 50
LOADING_MINI_BATCH_SIZE = 10000

logger = get_logger(__name__)


class WithDBSettingsBase:
    """Helper class providing methods for saving changes made to an experiment
    if `db_settings` property is set to a non-None value on the instance.
    """

    _db_settings: Optional[DBSettings] = None

    def __init__(
        self,
        db_settings: Optional[DBSettings] = None,
        logging_level: int = INFO,
        suppress_all_errors: bool = False,
    ) -> None:
        if db_settings and (not DBSettings or not isinstance(db_settings, DBSettings)):
            raise ValueError(
                "`db_settings` argument should be of type ax.storage.sqa_store."
                f"(Got: {db_settings} of type {type(db_settings)}. "
                "structs.DBSettings. To use `DBSettings`, you will need SQLAlchemy "
                "installed in your environment (can be installed through pip)."
            )
        self._db_settings = db_settings or self._get_default_db_settings()
        self._suppress_all_errors = suppress_all_errors
        if self.db_settings_set:
            init_engine_and_session_factory(
                creator=self.db_settings.creator, url=self.db_settings.url
            )
        logger.setLevel(logging_level)

    @staticmethod
    def _get_default_db_settings() -> Optional[DBSettings]:
        """Overridable method to get default db_settings
        if none are passed in __init__
        """
        return None

    @property
    def db_settings_set(self) -> bool:
        """Whether non-None DB settings are set on this instance."""
        return self._db_settings is not None

    @property
    def db_settings(self) -> DBSettings:
        """DB settings set on this instance; guaranteed to be non-None."""
        if self._db_settings is None:
            raise ValueError("No DB settings are set on this instance.")
        return not_none(self._db_settings)

    def _get_experiment_and_generation_strategy_db_id(
        self, experiment_name: str
    ) -> Tuple[Optional[int], Optional[int]]:
        """Retrieve DB ids of experiment by the given name and the associated
        generation strategy. Each ID is None if corresponding object is not
        found.
        """
        if not self.db_settings_set:
            return None, None

        exp_id = _get_experiment_id(
            experiment_name=experiment_name, config=self.db_settings.decoder.config
        )
        if not exp_id:
            return None, None
        gs_id = _get_generation_strategy_id(
            experiment_name=experiment_name, decoder=self.db_settings.decoder
        )
        return exp_id, gs_id

    def _maybe_save_experiment_and_generation_strategy(
        self, experiment: Experiment, generation_strategy: GenerationStrategy
    ) -> Tuple[bool, bool]:
        """If DB settings are set on this `WithDBSettingsBase` instance, checks
        whether given experiment and generation strategy are already saved and
        saves them, if not.

        Returns:
            Tuple of two booleans: whether experiment was saved in the course of
                this function's execution and whether generation strategy was
                saved.
        """
        saved_exp, saved_gs = False, False
        if self.db_settings_set:
            if experiment._name is None:
                raise ValueError(
                    "Experiment must specify a name to use storage functionality."
                )
            exp_name = not_none(experiment.name)
            exp_id, gs_id = self._get_experiment_and_generation_strategy_db_id(
                experiment_name=exp_name
            )
            if exp_id:  # Experiment in DB.
                logger.info(f"Experiment {exp_name} is in DB, updating it.")
                self._save_experiment_to_db_if_possible(experiment=experiment)
                saved_exp = True
            else:  # Experiment not yet in DB.
                logger.info(f"Experiment {exp_name} is not yet in DB, storing it.")
                self._save_experiment_to_db_if_possible(experiment=experiment)
                saved_exp = True
            if gs_id and generation_strategy._db_id != gs_id:
                raise UnsupportedError(
                    "Experiment was associated with generation strategy in DB, "
                    f"but a new generation strategy {generation_strategy.name} "
                    "was provided. To use the generation strategy currently in DB,"
                    " instantiate scheduler via: `Scheduler.from_stored_experiment`."
                )
            if not gs_id or generation_strategy._db_id is None:
                # There is no GS associated with experiment or the generation
                # strategy passed in is different from the one associated with
                # experiment currently.
                logger.info(
                    f"Generation strategy {generation_strategy.name} is not yet in DB, "
                    "storing it."
                )
                # If generation strategy does not yet have an experiment attached,
                # attach the current experiment to it, as otherwise it will not be
                # possible to retrieve by experiment name.
                if generation_strategy._experiment is None:
                    generation_strategy.experiment = experiment
                self._save_generation_strategy_to_db_if_possible(
                    generation_strategy=generation_strategy
                )
                saved_gs = True
        return saved_exp, saved_gs

    def _load_experiment_and_generation_strategy(
        self,
        experiment_name: str,
        reduced_state: bool = False,
    ) -> Tuple[Optional[Experiment], Optional[GenerationStrategy]]:
        """Loads experiment and its corresponding generation strategy from database
        if DB settings are set on this `WithDBSettingsBase` instance.

        Args:
            experiment_name: Name of the experiment to load, used as unique
                identifier by which to find the experiment.
            reduced_state: Whether to load experiment and generation strategy
                with a slightly reduced state (without abandoned arms on experiment
                and model state on each generator run in experiment and generation
                strategy; last generator run on generation strategy will still
                have its model state).

        Returns:
            - Tuple of `None` and `None` if `DBSettings` are set and no experiment
              exists by the given name.
            - Tuple of `Experiment` and `None` if experiment exists but does not
              have a generation strategy attached to it.
            - Tuple of `Experiment` and `GenerationStrategy` if experiment exists
              and has a generation strategy attached to it.
        """
        if not self.db_settings_set:
            raise ValueError("Cannot load from DB in absence of DB settings.")

        logger.info(
            "Loading experiment and generation strategy (with reduced state: "
            f"{reduced_state})..."
        )
        start_time = time.time()
        experiment = _load_experiment(
            experiment_name,
            decoder=self.db_settings.decoder,
            reduced_state=reduced_state,
            load_trials_in_batches_of_size=LOADING_MINI_BATCH_SIZE,
        )
        if not isinstance(experiment, Experiment):
            raise ValueError("Service API only supports `Experiment`.")
        logger.info(
            f"Loaded experiment {experiment_name} in "
            f"{_round_floats_for_logging(time.time() - start_time)} seconds, "
            f"loading trials in mini-batches of {LOADING_MINI_BATCH_SIZE}."
        )

        try:
            start_time = time.time()
            generation_strategy = _load_generation_strategy_by_experiment_name(
                experiment_name=experiment_name,
                decoder=self.db_settings.decoder,
                experiment=experiment,
                reduced_state=reduced_state,
            )
            logger.info(
                f"Loaded generation strategy for experiment {experiment_name} in "
                f"{_round_floats_for_logging(time.time() - start_time)} seconds."
            )
        except ObjectNotFoundError:
            logger.info(
                "There is no generation strategy associated with experiment "
                f"{experiment_name}."
            )

            return experiment, None

        return experiment, generation_strategy

    def _save_experiment_to_db_if_possible(self, experiment: Experiment) -> bool:
        """Saves attached experiment and generation strategy if DB settings are
        set on this `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment to save new trials in DB.

        Returns:
            bool: Whether the experiment was saved.
        """
        if self.db_settings_set:
            _save_experiment_to_db_if_possible(
                experiment=experiment,
                encoder=self.db_settings.encoder,
                decoder=self.db_settings.decoder,
                suppress_all_errors=self._suppress_all_errors,
            )
            return True
        return False

    def _save_or_update_trials_and_generation_strategy_if_possible(
        self,
        experiment: Experiment,
        trials: List[BaseTrial],
        generation_strategy: GenerationStrategy,
        new_generator_runs: List[GeneratorRun],
        reduce_state_generator_runs: bool = False,
    ) -> None:
        """Saves new trials (and updates existing ones) on given experiment
        and updates the given generation strategy, if DB settings are set on
        this `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save new trials in DB.
            trials: Newly added or updated trials to save or update in DB.
            generation_strategy: Generation strategy to update in DB.
            new_generator_runs: Generator runs to add to generation strategy.
        """
        logger.debug(f"Saving or updating {len(trials)} trials in DB.")
        self._save_or_update_trials_in_db_if_possible(
            experiment=experiment,
            trials=trials,
            reduce_state_generator_runs=reduce_state_generator_runs,
        )
        logger.debug(
            "Updating generation strategy in DB with "
            f"{len(new_generator_runs)} generator runs."
        )
        self._update_generation_strategy_in_db_if_possible(
            generation_strategy=generation_strategy,
            new_generator_runs=new_generator_runs,
            reduce_state_generator_runs=reduce_state_generator_runs,
        )
        return

    # No retries needed, covered in `self._save_or_update_trials_in_db_if_possible`
    def _save_or_update_trial_in_db_if_possible(
        self,
        experiment: Experiment,
        trial: BaseTrial,
    ) -> bool:
        """Saves new trial on given experiment if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save new trial in DB.
            trials: Newly added trial to save.

        Returns:
            bool: Whether the trial was saved.
        """
        return self._save_or_update_trials_in_db_if_possible(
            experiment=experiment,
            trials=[trial],
        )

    def _save_or_update_trials_in_db_if_possible(
        self,
        experiment: Experiment,
        trials: List[BaseTrial],
        reduce_state_generator_runs: bool = False,
    ) -> bool:
        """Saves new trials or update existing trials on given experiment if DB
        settings are set on this `WithDBSettingsBase` instance.

        Args:
            experiment: Experiment, on which to save new trials in DB.
            trials: Newly added trials to save.

        Returns:
            bool: Whether the trials were saved.
        """
        if self.db_settings_set:
            _save_or_update_trials_in_db_if_possible(
                experiment=experiment,
                trials=trials,
                encoder=self.db_settings.encoder,
                decoder=self.db_settings.decoder,
                suppress_all_errors=self._suppress_all_errors,
                reduce_state_generator_runs=reduce_state_generator_runs,
            )
            return True
        return False

    def _save_generation_strategy_to_db_if_possible(
        self, generation_strategy: GenerationStrategy, suppress_all_errors: bool = False
    ) -> bool:
        """Saves given generation strategy if DB settings are set on this
        `WithDBSettingsBase` instance.

        Args:
            generation_strategy: Generation strategy to save in DB.

        Returns:
            bool: Whether the generation strategy was saved.
        """
        if self.db_settings_set:
            _save_generation_strategy_to_db_if_possible(
                generation_strategy=generation_strategy,
                encoder=self.db_settings.encoder,
                decoder=self.db_settings.decoder,
                suppress_all_errors=self._suppress_all_errors,
            )
            return True
        return False

    def _update_generation_strategy_in_db_if_possible(
        self,
        generation_strategy: GenerationStrategy,
        new_generator_runs: List[GeneratorRun],
        reduce_state_generator_runs: bool = False,
    ) -> bool:
        """Updates the given generation strategy with new generator runs (and with
        new current generation step if applicable) if DB settings are set
        on this `WithDBSettingsBase` instance.

        Args:
            generation_strategy: Generation strategy to update in DB.
            new_generator_runs: New generator runs of this generation strategy
                since its last save.

        Returns:
            bool: Whether the experiment was saved.
        """
        if self.db_settings_set:
            _update_generation_strategy_in_db_if_possible(
                generation_strategy=generation_strategy,
                new_generator_runs=new_generator_runs,
                encoder=self.db_settings.encoder,
                decoder=self.db_settings.decoder,
                suppress_all_errors=self._suppress_all_errors,
                reduce_state_generator_runs=reduce_state_generator_runs,
            )
            return True
        return False

    def _update_experiment_properties_in_db(
        self,
        experiment_with_updated_properties: Experiment,
    ) -> bool:
        exp = experiment_with_updated_properties
        if self.db_settings_set:
            _update_experiment_properties_in_db(
                experiment_with_updated_properties=exp,
                sqa_config=self.db_settings.encoder.config,
                suppress_all_errors=self._suppress_all_errors,
            )
            return True
        return False


# ------------- Utils for storage that assume `DBSettings` are provided --------


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _save_experiment_to_db_if_possible(
    experiment: Experiment,
    encoder: Encoder,
    decoder: Decoder,
    suppress_all_errors: bool,
) -> None:
    start_time = time.time()
    _save_experiment(
        experiment,
        encoder=encoder,
        decoder=decoder,
    )
    logger.debug(
        f"Saved experiment {experiment.name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds."
    )


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _save_or_update_trials_in_db_if_possible(
    experiment: Experiment,
    trials: List[BaseTrial],
    encoder: Encoder,
    decoder: Decoder,
    suppress_all_errors: bool,
    reduce_state_generator_runs: bool = False,
) -> None:
    start_time = time.time()
    _save_or_update_trials(
        experiment=experiment,
        trials=trials,
        encoder=encoder,
        decoder=decoder,
        batch_size=STORAGE_MINI_BATCH_SIZE,
        reduce_state_generator_runs=reduce_state_generator_runs,
    )
    logger.debug(
        f"Saved or updated trials {[trial.index for trial in trials]} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds "
        f"in mini-batches of {STORAGE_MINI_BATCH_SIZE}."
    )


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _save_generation_strategy_to_db_if_possible(
    generation_strategy: GenerationStrategy,
    encoder: Encoder,
    decoder: Decoder,
    suppress_all_errors: bool,
) -> None:
    start_time = time.time()
    _save_generation_strategy(
        generation_strategy=generation_strategy,
        encoder=encoder,
        decoder=decoder,
    )
    logger.debug(
        f"Saved generation strategy {generation_strategy.name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds."
    )


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _update_generation_strategy_in_db_if_possible(
    generation_strategy: GenerationStrategy,
    new_generator_runs: List[GeneratorRun],
    encoder: Encoder,
    decoder: Decoder,
    suppress_all_errors: bool,
    reduce_state_generator_runs: bool = False,
) -> None:
    start_time = time.time()
    _update_generation_strategy(
        generation_strategy=generation_strategy,
        generator_runs=new_generator_runs,
        encoder=encoder,
        decoder=decoder,
        batch_size=STORAGE_MINI_BATCH_SIZE,
        reduce_state_generator_runs=reduce_state_generator_runs,
    )
    logger.debug(
        f"Updated generation strategy {generation_strategy.name} in "
        f"{_round_floats_for_logging(time.time() - start_time)} seconds in "
        f"mini-batches of {STORAGE_MINI_BATCH_SIZE} generator runs."
    )


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _update_experiment_properties_in_db(
    experiment_with_updated_properties: Experiment,
    sqa_config: SQAConfig,
    suppress_all_errors: bool,
) -> None:
    update_properties_on_experiment(
        experiment_with_updated_properties=experiment_with_updated_properties,
        config=sqa_config,
    )
