# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
import time
import threading
from typing import Dict, Optional, Union
import warnings

import torch
from ..partition import PartitionBook

from ..channel import ShmChannel, QueueTimeoutError
from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig, RemoteSamplerInput

from .dist_context import get_context, _set_server_context
from .dist_dataset import DistDataset
from .dist_options import RemoteDistSamplingWorkerOptions
from .dist_sampling_producer import DistMpSamplingProducer
from .rpc import barrier, init_rpc, shutdown_rpc

# Sleep period used by ``DistServer.wait_for_exit`` between two consecutive
# checks of the server's exit flag.
SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0
r""" Interval (in seconds) to check exit status of server.
"""

class DistServer(object):
  r""" A server that supports launching remote sampling workers for
  training clients.

  Note that this server is enabled only when the distribution mode is a
  server-client framework, and the graph and feature store will be partitioned
  and managed by all server nodes.

  Args:
    dataset (DistDataset): The ``DistDataset`` object of a partition of graph
      data and feature data, along with distributed partition books.
  """
  def __init__(self, dataset: DistDataset):
    self.dataset = dataset
    # Re-entrant lock taken by the create/destroy/start-epoch methods while
    # they touch the producer bookkeeping below.
    self._lock = threading.RLock()
    self._exit = False
    # Next producer id to hand out; advanced once per created producer.
    self._cur_producer_idx = 0
    # The mapping from the key in worker options (such as 'train', 'test')
    # to producer id
    self._worker_key2producer_id: Dict[str, int] = {}
    self._producer_pool: Dict[int, DistMpSamplingProducer] = {}
    self._msg_buffer_pool: Dict[int, ShmChannel] = {}
    self._epoch: Dict[int, int] = {} # last epoch for the producer

  def shutdown(self):
    r""" Destroy every sampling producer that is still registered on this
    server.
    """
    # Iterate over a snapshot: ``destroy_sampling_producer`` mutates the pool.
    for producer_id in list(self._producer_pool):
      self.destroy_sampling_producer(producer_id)
    assert len(self._producer_pool) == 0
    assert len(self._msg_buffer_pool) == 0

  def wait_for_exit(self):
    r""" Block until the exit flag been set to ``True``.

    The flag is polled every ``SERVER_EXIT_STATUS_CHECK_INTERVAL`` seconds.
    """
    while not self._exit:
      time.sleep(SERVER_EXIT_STATUS_CHECK_INTERVAL)

  def exit(self):
    r""" Set the exit flag to ``True``.

    Returns:
      The value of the exit flag after it has been set (always ``True``).
    """
    self._exit = True
    return self._exit

  def get_dataset_meta(self):
    r""" Get the meta info of the distributed dataset managed by the current
    server, including partition info and graph types.

    Returns:
      A 4-tuple ``(num_partitions, partition_idx, node_types, edge_types)``
      read from ``self.dataset``.
    """
    return self.dataset.num_partitions, self.dataset.partition_idx, \
      self.dataset.get_node_types(), self.dataset.get_edge_types()

  def get_node_partition_id(self, node_type, index):
    r""" Look up ``index`` in the node partition book of the dataset.

    Args:
      node_type: Key used to select the partition book when the dataset
        stores one book per node type (a ``dict``); ignored when the dataset
        holds a single ``PartitionBook``.
      index: Index passed to the partition book's ``[]`` operator.

    Returns:
      The partition book entry for ``index``, or ``None`` when the dataset's
      ``node_pb`` is neither a ``PartitionBook`` nor a ``dict``.
    """
    node_pb = self.dataset.node_pb
    if isinstance(node_pb, PartitionBook):
      return node_pb[index]
    if isinstance(node_pb, dict):
      return node_pb[node_type][index]
    return None

  def get_node_feature(self, node_type, index):
    r""" Return ``feature[index].cpu()`` for the node feature of
    ``node_type``.
    """
    feature = self.dataset.get_node_feature(node_type)
    return feature[index].cpu()

  def get_tensor_size(self, node_type):
    r""" Return the ``shape`` of the node feature of ``node_type``.
    """
    feature = self.dataset.get_node_feature(node_type)
    return feature.shape

  def get_node_label(self, node_type, index):
    r""" Return ``label[index]`` for the node label of ``node_type``.
    """
    label = self.dataset.get_node_label(node_type)
    return label[index]

  def get_edge_index(self, edge_type, layout):
    r""" Return the edge index of the graph of ``edge_type``.

    Args:
      edge_type: Edge type passed to ``self.dataset.get_graph``.
      layout (str): Requested layout; only ``'coo'`` is supported.

    Returns:
      A ``(row, col)`` tuple taken from ``graph.topo.to_coo()``.

    Raises:
      ValueError: If ``layout`` is not ``'coo'``.
    """
    graph = self.dataset.get_graph(edge_type)
    if layout != 'coo':
      raise ValueError(f"Invalid layout {layout}")
    row, col, _, _ = graph.topo.to_coo()
    return (row, col)

  def get_edge_size(self, edge_type, layout):
    r""" Return the row and column counts of the graph of ``edge_type``.

    Args:
      edge_type: Edge type passed to ``self.dataset.get_graph``.
      layout (str): Requested layout; only ``'coo'`` is supported.

    Returns:
      A ``(row_count, col_count)`` tuple read from the graph.

    Raises:
      ValueError: If ``layout`` is not ``'coo'``.
    """
    graph = self.dataset.get_graph(edge_type)
    if layout != 'coo':
      raise ValueError(f"Invalid layout {layout}")
    return (graph.row_count, graph.col_count)

  def create_sampling_producer(
    self,
    sampler_input: Union[NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput],
    sampling_config: SamplingConfig,
    worker_options: RemoteDistSamplingWorkerOptions,
  ) -> int:
    r""" Create and initialize an instance of ``DistSamplingProducer`` with
    a group of subprocesses for distributed sampling.

    If a producer is already registered for ``worker_options.worker_key``,
    no new producer is created and the existing producer id is returned.

    Args:
      sampler_input (NodeSamplerInput or EdgeSamplerInput): The input data
        for sampling. A ``RemoteSamplerInput`` is first converted with
        ``to_local_sampler_input(dataset=self.dataset)``.
      sampling_config (SamplingConfig): Configuration of sampling meta info.
      worker_options (RemoteDistSamplingWorkerOptions): Options for launching
        remote sampling workers by this server.

    Returns:
      A unique id of created sampling producer on this server.
    """
    if isinstance(sampler_input, RemoteSamplerInput):
      sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset)

    worker_key = worker_options.worker_key
    with self._lock:
      producer_id = self._worker_key2producer_id.get(worker_key)
      if producer_id is None:
        buffer = ShmChannel(
          worker_options.buffer_capacity, worker_options.buffer_size
        )
        producer = DistMpSamplingProducer(
          self.dataset, sampler_input, sampling_config, worker_options, buffer
        )
        producer.init()
        # Register only after ``init()`` succeeded, so that a failed
        # initialization never leaves a worker key pointing at an id that has
        # no producer behind it.
        producer_id = self._cur_producer_idx
        self._cur_producer_idx += 1
        self._worker_key2producer_id[worker_key] = producer_id
        self._producer_pool[producer_id] = producer
        self._msg_buffer_pool[producer_id] = buffer
        self._epoch[producer_id] = -1
    return producer_id

  def destroy_sampling_producer(self, producer_id: int):
    r""" Shutdown and destroy a sampling producer managed by this server with
    its producer id.

    Unknown producer ids are ignored. Any worker key that maps to
    ``producer_id`` is forgotten as well, so a later
    ``create_sampling_producer`` call with the same key builds a fresh
    producer instead of returning the id of a destroyed one.
    """
    with self._lock:
      producer = self._producer_pool.get(producer_id, None)
      if producer is not None:
        producer.shutdown()
        self._producer_pool.pop(producer_id)
        self._msg_buffer_pool.pop(producer_id)
        self._epoch.pop(producer_id)
      # Drop stale worker-key entries for this id.
      self._worker_key2producer_id = {
        key: pid for key, pid in self._worker_key2producer_id.items()
        if pid != producer_id
      }

  def start_new_epoch_sampling(self, producer_id: int, epoch: int):
    r""" Start a new epoch sampling tasks for a specific sampling producer
    with its producer id.

    The producer is triggered only when ``epoch`` is greater than the last
    epoch recorded for it, so repeated requests for the same epoch are
    no-ops.

    Raises:
      KeyError: If no epoch is recorded for ``producer_id``.
    """
    with self._lock:
      cur_epoch = self._epoch[producer_id]
      if cur_epoch < epoch:
        self._epoch[producer_id] = epoch
        producer = self._producer_pool.get(producer_id, None)
        if producer is not None:
          producer.produce_all()

  def fetch_one_sampled_message(self, producer_id: int):
    r""" Fetch a sampled message from the buffer of a specific sampling
    producer with its producer id.

    Returns:
      A ``(message, finished)`` tuple:

      * ``(None, False)`` with a warning when ``producer_id`` is not
        registered on this server;
      * ``(None, True)`` when the producer reports that all sampling is
        completed (and, for the initial check, consumed);
      * ``(message, False)`` when a message was received from the buffer.
    """
    producer = self._producer_pool.get(producer_id, None)
    buffer = self._msg_buffer_pool.get(producer_id, None)
    # The pools are read without the lock, so also guard against a buffer
    # that has been removed by a concurrent destroy.
    if producer is None or buffer is None:
      warnings.warn(f'invalid producer_id {producer_id}')
      return None, False
    if producer.is_all_sampling_completed_and_consumed():
      return None, True
    while True:
      try:
        msg = buffer.recv(timeout_ms=500)
        return msg, False
      except QueueTimeoutError:
        # Nothing arrived within the timeout: stop once the producer reports
        # completion, otherwise keep polling.
        if producer.is_all_sampling_completed():
          return None, True


# Module-level holder: starts as ``None``, is assigned a ``DistServer`` by
# ``init_server`` and is reset to ``None`` by ``wait_and_shutdown_server``.
_dist_server: Optional[DistServer] = None
r""" ``DistServer`` instance of the current process.
"""


def get_server() -> Optional[DistServer]:
  r""" Get the ``DistServer`` instance on the current process.

  Returns:
    The module-level ``DistServer`` instance, or ``None`` when no instance is
    currently set (the module-level default is ``None``).
  """
  return _dist_server


def init_server(num_servers: int, server_rank: int, dataset: DistDataset,
                master_addr: str, master_port: int, num_clients: int = 0,
                num_rpc_threads: int = 16, request_timeout: int = 180,
                server_group_name: Optional[str] = None, is_dynamic: bool = False):
  r""" Turn the current process into a server: set the server context, create
  the process-wide ``DistServer`` and initialize RPC. This method should be
  called only in the server-client distribution mode.

  Args:
    num_servers (int): Number of processes participating in the server group.
    server_rank (int): Rank of the current process within the server group (it
      should be a number between 0 and ``num_servers``-1).
    dataset (DistDataset): The ``DistDataset`` object of a partition of graph
      data and feature data, along with distributed partition book info.
    master_addr (str): The master TCP address for RPC connection between all
      servers and clients, the value of this parameter should be same for all
      servers and clients.
    master_port (int): The master TCP port for RPC connection between all
      servers and clients, the value of this parameter should be same for all
      servers and clients.
    num_clients (int): Number of processes participating in the client group.
      If ``is_dynamic`` is ``True``, this parameter will be ignored.
      (Default: ``0``).
    num_rpc_threads (int): The number of RPC worker threads used for the
      current server to respond remote requests. (Default: ``16``).
    request_timeout (int): The max timeout seconds for remote requests,
      otherwise an exception will be raised. (Default: ``180``).
    server_group_name (str): A unique name of the server group that current
      process belongs to. Every ``'-'`` in a non-empty name is replaced with
      ``'_'``. If set to ``None``, a default name will be used.
      (Default: ``None``).
    is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``).
  """
  global _dist_server

  # Normalize the group name before it is handed to the server context.
  group_name = server_group_name
  if group_name:
    group_name = group_name.replace('-', '_')

  _set_server_context(num_servers, server_rank, group_name, num_clients)
  _dist_server = DistServer(dataset=dataset)
  init_rpc(master_addr, master_port, num_rpc_threads, request_timeout,
           is_dynamic=is_dynamic)


def wait_and_shutdown_server():
  r""" Block until the server's exit flag has been set, then shut down the
  server on the current process and destroy all RPC connections.

  Logs a warning and returns when the current process has no distributed
  context.

  Raises:
    RuntimeError: If the current process context is not a server.
  """
  global _dist_server

  ctx = get_context()
  if ctx is None:
    not_initialized_msg = (
      "'wait_and_shutdown_server': try to shutdown server when "
      "the current process has not been initialized as a server."
    )
    logging.warning(not_initialized_msg)
    return

  if not ctx.is_server():
    wrong_role_msg = (
      f"'wait_and_shutdown_server': role type of "
      f"the current process context is not a server, "
      f"got {ctx.role}."
    )
    raise RuntimeError(wrong_role_msg)

  # Sequence: wait for the exit flag, tear down the producers, drop the
  # module-level instance, then synchronize and close RPC.
  _dist_server.wait_for_exit()
  _dist_server.shutdown()
  _dist_server = None
  barrier()
  shutdown_rpc()


def _call_func_on_server(func, *args, **kwargs):
  r""" A callee entry for remote requests on the server side.

  Args:
    func: The target to invoke. If it is not callable, a warning is logged
      and ``None`` is returned.
    *args: Positional arguments forwarded to ``func``.
    **kwargs: Keyword arguments forwarded to ``func``.

  Returns:
    * ``func(server, *args, **kwargs)`` when ``func`` has a ``__name__`` and
      the current server (from ``get_server()``) has an attribute of that
      name;
    * ``func(*args, **kwargs)`` otherwise, including for callables that carry
      no ``__name__`` (such as ``functools.partial`` objects).
  """
  if not callable(func):
    logging.warning(f"'_call_func_on_server': receive a non-callable "
                    f"function target {func}")
    return None

  # Not every callable has ``__name__`` (e.g. ``functools.partial`` objects or
  # callable instances); reading it unconditionally raised AttributeError.
  target_name = getattr(func, '__name__', None)
  if target_name is not None:
    server = get_server()
    if hasattr(server, target_name):
      return func(server, *args, **kwargs)

  return func(*args, **kwargs)
