# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import abstractproperty, ABC
from typing import Any, Callable, Optional, Type, Dict

from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.storage.json_store.registry import (
    CORE_ENCODER_REGISTRY,
    CORE_CLASS_DECODER_REGISTRY,
    CORE_DECODER_REGISTRY,
    CORE_CLASS_ENCODER_REGISTRY,
)
from ax.storage.metric_registry import register_metrics
from ax.storage.runner_registry import register_runners
from ax.storage.sqa_store.decoder import Decoder
from ax.storage.sqa_store.encoder import Encoder
from ax.storage.sqa_store.sqa_config import SQAConfig


class RegistryBundleBase(ABC):
    """An abstraction to help with storing experiments with custom Metrics and Runners.

    Rather than managing registries individually, the RegistryBundle consumes custom
    metrics, runners, and configuration information and is lazily creates the storage
    registries needed for saving and loading.

    Args:
        metric_clss: A dictionary from Metric classes to the int their type should be
            encoded as in the associated SQAMetric. If no None is passed for the int
            a hash will be generated.
        runner_clss: A dictionary from Runner classes to the int their type should be
            encoded as in the associated SQARunner. If no None is passed for the int
            a hash will be generated.
        json_encoder_registry: A dictionary from Types to methods from an instance of
            the type to JSON.
        json_class_encoder_registry: A dictionary from Types to methods from the type's
            class to JSON.
        json_decoder_registry: A dictionary from str class labels to their associated
            Type.
        json_class_decoder_registry: A dictionary from str class labels to an
            associated method for reconstruction.
    """

    def __init__(
        self,
        metric_clss: Dict[Type[Metric], Optional[int]],
        runner_clss: Dict[Type[Runner], Optional[int]],
        json_encoder_registry: Dict[Type, Callable[[Any], Dict[str, Any]]],
        json_class_encoder_registry: Dict[Type, Callable[[Any], Dict[str, Any]]],
        json_decoder_registry: Dict[str, Type],
        json_class_decoder_registry: Dict[str, Callable[[Dict[str, Any]], Any]],
    ) -> None:
        self._metric_registry, encoder_registry, decoder_registry = register_metrics(
            metric_clss=metric_clss,
            encoder_registry=json_encoder_registry,
            decoder_registry=json_decoder_registry,
        )
        (
            self._runner_registry,
            self._encoder_registry,
            self._decoder_registry,
        ) = register_runners(
            runner_clss=runner_clss,
            encoder_registry=encoder_registry,
            decoder_registry=decoder_registry,
        )

        self._json_class_encoder_registry = json_class_encoder_registry
        self._json_class_decoder_registry = json_class_decoder_registry

    @property
    def metric_registry(self) -> Dict[Type[Metric], int]:
        return self._metric_registry

    @property
    def runner_registry(self) -> Dict[Type[Runner], int]:
        return self._runner_registry

    @property
    def encoder_registry(self) -> Dict[Type, Callable[[Any], Dict[str, Any]]]:
        return self._encoder_registry

    @property
    def decoder_registry(self) -> Dict[str, Type]:
        return self._decoder_registry

    @property
    def class_encoder_registry(self) -> Dict[Type, Callable[[Any], Dict[str, Any]]]:
        return self._json_class_encoder_registry

    @property
    def class_decoder_registry(self) -> Dict[str, Callable[[Dict[str, Any]], Any]]:
        return self._json_class_decoder_registry

    @abstractproperty
    def sqa_config(self) -> SQAConfig:
        pass

    @abstractproperty
    def encoder(self) -> Encoder:
        pass

    @abstractproperty
    def decoder(self) -> Decoder:
        pass


class RegistryBundle(RegistryBundleBase):
    """A concrete implementation of RegistryBundleBase with sensible defaults."""

    def __init__(
        self,
        metric_clss: Dict[Type[Metric], Optional[int]],
        runner_clss: Dict[Type[Runner], Optional[int]],
        json_encoder_registry: Dict[
            Type, Callable[[Any], Dict[str, Any]]
        ] = CORE_ENCODER_REGISTRY,
        json_class_encoder_registry: Dict[
            Type, Callable[[Any], Dict[str, Any]]
        ] = CORE_CLASS_ENCODER_REGISTRY,
        json_decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
        json_class_decoder_registry: Dict[
            str, Callable[[Dict[str, Any]], Any]
        ] = CORE_CLASS_DECODER_REGISTRY,
    ) -> None:
        super().__init__(
            metric_clss=metric_clss,
            runner_clss=runner_clss,
            json_encoder_registry=json_encoder_registry,
            json_class_encoder_registry=json_class_encoder_registry,
            json_decoder_registry=json_decoder_registry,
            json_class_decoder_registry=json_class_decoder_registry,
        )
        self._sqa_config = SQAConfig(
            json_encoder_registry={**self.encoder_registry, **CORE_ENCODER_REGISTRY},
            json_decoder_registry={**self.decoder_registry, **CORE_DECODER_REGISTRY},
            metric_registry=self.metric_registry,
            runner_registry=self.runner_registry,
            json_class_encoder_registry=self.class_encoder_registry,
            json_class_decoder_registry=self.class_decoder_registry,
        )

        self._encoder = Encoder(self._sqa_config)
        self._decoder = Decoder(self._sqa_config)

    # TODO[mpolson64] change @property to @cached_property once we deprecate 3.7
    @property
    def sqa_config(self) -> SQAConfig:
        return self._sqa_config

    @property
    def encoder(self) -> Encoder:
        return self._encoder

    @property
    def decoder(self) -> Decoder:
        return self._decoder
