# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
from typing import List, Optional, Union, Literal

import torch

from ..utils import assign_device

from .dist_context import DistContext, assign_server_by_order


class _BasicDistSamplingWorkerOptions(object):
  r""" Basic options to launch distributed sampling workers.

  Args:
    num_workers (int): How many workers to use for distributed neighbor
      sampling of the current process, must be same for each process of
      the current context group. (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): Devices
      assigned to workers of this group. A single device is replicated for
      every worker; a list/tuple must contain exactly ``num_workers``
      devices. If set to ``None``, the devices to use will be automatically
      assigned (the cuda device will be preferred if available).
      (default: ``None``).
    worker_concurrency (int): The max sampling concurrency with different
      seeds batches for each sampling worker. The value is clamped into the
      range ``[1, 32]``. (default: ``1``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. The environment variable ``MASTER_ADDR`` will be
      used if set to ``None``. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. If set to ``None``, in order to avoid
      conflicts with master port already used by other modules (e.g., the
      method ``init_process_group`` of ``torch.distributed``), the value of
      environment variable ``MASTER_PORT`` will be increased by one as the
      real rpc port for sampling workers. Otherwise, the provided port should
      be guaranteed to avoid such conflicts. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used for rpc agent on
      each sampling worker, must be positive if provided. If set to ``None``,
      the number of rpc threads to use will be specified according to the
      actual workload, but will not exceed 16. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for all rpc requests during
      distributed sampling and feature collection. (default: ``180``).

  Raises:
    ValueError: If the master address (or port) is neither passed in nor
      available through the ``MASTER_ADDR`` (or ``MASTER_PORT``)
      environment variable.
    AssertionError: If ``worker_devices`` is a list/tuple whose length does
      not match ``num_workers``, or if ``num_rpc_threads`` is not positive.
  """
  def __init__(self,
               num_workers: int = 1,
               worker_devices: Optional[
                 Union[torch.device, List[torch.device]]] = None,
               worker_concurrency: int = 1,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180):
    self.num_workers = num_workers

    # Unknown until the distributed context is available; filled in by
    # ``_set_worker_ranks``.
    self.worker_world_size = None
    self.worker_ranks = None

    if worker_devices is None:
      # Resolved later by ``_assign_worker_devices``.
      self.worker_devices = None
    elif isinstance(worker_devices, (list, tuple)):
      assert len(worker_devices) == self.num_workers, (
        f"'{self.__class__.__name__}': expected {self.num_workers} worker "
        f"devices, got {len(worker_devices)}")
      self.worker_devices = list(worker_devices)
    else:
      # A single device is shared by all workers of this group.
      self.worker_devices = [worker_devices] * self.num_workers

    # Clamp the per-worker concurrency into [1, 32].
    self.worker_concurrency = min(max(worker_concurrency, 1), 32)

    if master_addr is not None:
      self.master_addr = str(master_addr)
    elif os.environ.get('MASTER_ADDR') is not None:
      self.master_addr = os.environ['MASTER_ADDR']
    else:
      raise ValueError(f"'{self.__class__.__name__}': missing master address "
                       "for rpc communication, try to provide it or set it "
                       "with environment variable 'MASTER_ADDR'")

    if master_port is not None:
      self.master_port = int(master_port)
    elif os.environ.get('MASTER_PORT') is not None:
      # Shift by one so the rpc port does not collide with the port already
      # used by e.g. ``torch.distributed.init_process_group``.
      self.master_port = int(os.environ['MASTER_PORT']) + 1
    else:
      raise ValueError(f"'{self.__class__.__name__}': missing master port "
                       "for rpc communication, try to provide it or set it "
                       "with environment variable 'MASTER_PORT'")

    self.num_rpc_threads = num_rpc_threads
    if self.num_rpc_threads is not None:
      assert self.num_rpc_threads > 0, (
        f"'{self.__class__.__name__}': num_rpc_threads must be positive, "
        f"got {self.num_rpc_threads}")
    self.rpc_timeout = rpc_timeout

  def _set_worker_ranks(self, current_ctx: DistContext):
    r""" Compute the global worker world size and this process's worker ranks.

    Every process of the context group owns ``num_workers`` workers, so the
    ranks of this process's workers are the contiguous range starting at
    ``current_ctx.rank * num_workers``.
    """
    self.worker_world_size = current_ctx.world_size * self.num_workers
    self.worker_ranks = [
      current_ctx.rank * self.num_workers + i
      for i in range(self.num_workers)
    ]

  def _assign_worker_devices(self):
    r""" Fill ``worker_devices`` with one ``assign_device()`` result per worker,
    unless devices were already provided.
    """
    if self.worker_devices is not None:
      return
    self.worker_devices = [assign_device() for _ in range(self.num_workers)]


class CollocatedDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
  r""" Options for running exactly one distributed sampling worker collocated
  with (i.e. inside) the current process.

  The worker count and the worker concurrency are both fixed to ``1``, and
  ``worker_devices`` is left as ``None`` so the device is assigned
  automatically.

  Args:
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used for rpc agent on
      each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    use_all2all (bool): Whether to collect distributed node/edge features
      with all2all instead of point-to-point rpc. (default: ``False``).

  See ``_BasicDistSamplingWorkerOptions`` for a detailed description of the
  shared arguments.
  """
  def __init__(self,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180,
               use_all2all: bool = False):
    super().__init__(
      num_workers=1,
      worker_devices=None,
      worker_concurrency=1,
      master_addr=master_addr,
      master_port=master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=rpc_timeout,
    )
    self.use_all2all = use_all2all


class MpDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
  r""" Options for launching distributed sampling workers with multiprocessing.

  When ``MpDistSamplingWorkerOptions`` is used, every sampling worker is
  launched in a subprocess spawned by ``torch.multiprocessing``. A
  shared-memory based channel is therefore needed to pass the sampled results
  produced by those worker subprocesses to the current (consuming) process.

  Args:
    num_workers (int): How many workers (i.e. subprocesses to spawn) to use
      for distributed neighbor sampling of the current process.
      (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): Devices
      assigned to workers of this group. (default: ``None``).
    worker_concurrency (int): The max sampling concurrency for each sampling
      worker. (default: ``4``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used for rpc agent on
      each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    channel_size (int or str, optional): The shared-memory buffer size
      (bytes) allocated for the channel. ``'<num_workers * 64>MB'`` is used
      when set to ``None``. (default: ``None``).
    pin_memory (bool): Set to ``True`` to register the underlying shared memory
      for cuda, which speeds up copying loaded data from the channel to a
      cuda device. (default: ``False``).
    use_all2all (bool): Whether to collect distributed node/edge features
      with all2all instead of point-to-point rpc. (default: ``False``).

  See ``_BasicDistSamplingWorkerOptions`` for a detailed description of the
  shared arguments.
  """
  def __init__(self,
               num_workers: int = 1,
               worker_devices: Optional[List[torch.device]] = None,
               worker_concurrency: int = 4,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180,
               channel_size: Optional[Union[int, str]] = None,
               pin_memory: bool = False,
               use_all2all: bool = False):
    super().__init__(
      num_workers=num_workers,
      worker_devices=worker_devices,
      worker_concurrency=worker_concurrency,
      master_addr=master_addr,
      master_port=master_port,
      num_rpc_threads=num_rpc_threads,
      rpc_timeout=rpc_timeout,
    )

    # Total number of in-flight sampling tasks across all workers, using the
    # (already clamped) per-worker concurrency.
    self.channel_capacity = self.num_workers * self.worker_concurrency

    # Default to 64MB of shared memory per worker.
    self.channel_size = (
      channel_size if channel_size is not None
      else f'{self.num_workers * 64}MB'
    )

    self.pin_memory = pin_memory
    self.use_all2all = use_all2all


class RemoteDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
  r""" Options for launching distributed sampling workers on remote servers.

  Note that if ``RemoteDistSamplingWorkerOptions`` is used, all sampling
  workers will be launched on remote servers. Thus, a cross-machine based
  channel will be created for message passing of sampled results, which are
  produced by those remote sampling workers and consumed by the current process.

  Args:
    server_rank (int or List[int], optional): The rank of server to launch
      sampling workers, can be multiple. If set to ``None``, it will be
      automatically assigned. (default: ``None``).
    num_workers (int): How many workers to launch on the remote server for
      distributed neighbor sampling of the current process. (default: ``1``).
    worker_devices (torch.device or List[torch.device], optional): Devices
      assigned to workers of this group. (default: ``None``).
    worker_concurrency (int): The max sampling concurrency for each sampling
      worker. (default: ``4``).
    master_addr (str, optional): Master address for rpc initialization across
      all sampling workers. Ignored when ``glt_graph`` is given.
      (default: ``None``).
    master_port (str or int, optional): Master port for rpc initialization
      across all sampling workers. Ignored when ``glt_graph`` is given.
      (default: ``None``).
    num_rpc_threads (int, optional): Number of threads used for rpc agent on
      each sampling worker. (default: ``None``).
    rpc_timeout (float): The timeout in seconds for rpc requests.
      (default: ``180``).
    buffer_size (int or str, optional): The size (bytes) allocated for the
      server-side buffer. ``'<num_workers * 64>MB'`` is used if set to
      ``None``. (default: ``None``).
    prefetch_size (int): The max prefetched sampled messages for consuming on
      the client side, must not exceed ``num_workers * worker_concurrency``.
      (default: ``4``).
    worker_key (str, optional): An identifier stored on the options. When
      ``glt_graph`` is given it is overridden with the string form of the
      selected master port. (default: ``None``).
    glt_graph: Used on the GraphScope side to get parameters. When given,
      ``master_addr``, ``master_port`` and ``worker_key`` are taken from it.
      (default: ``None``).
    workload_type (str, optional): Used on the GraphScope side, one of
      ``'train'``, ``'validate'`` or ``'test'``; selects which loader master
      port of ``glt_graph`` to use. This field must be set when ``glt_graph``
      is not ``None``. (default: ``None``).
    use_all2all (bool): Whether to collect distributed node/edge features
      with all2all instead of point-to-point rpc. (default: ``False``).

  Raises:
    ValueError: If ``glt_graph`` is given but ``workload_type`` is missing or
      unsupported, or if ``prefetch_size`` exceeds the buffer capacity.
  """
  def __init__(self,
               server_rank: Optional[Union[int, List[int]]] = None,
               num_workers: int = 1,
               worker_devices: Optional[List[torch.device]] = None,
               worker_concurrency: int = 4,
               master_addr: Optional[str] = None,
               master_port: Optional[Union[str, int]] = None,
               num_rpc_threads: Optional[int] = None,
               rpc_timeout: float = 180,
               buffer_size: Optional[Union[int, str]] = None,
               prefetch_size: int = 4,
               worker_key: Optional[str] = None,
               glt_graph=None,
               workload_type: Optional[Literal['train', 'validate', 'test']] = None,
               use_all2all: bool = False):
    # glt_graph is used on the GraphScope side to get parameters.
    if glt_graph is not None:
      if not workload_type:
        raise ValueError(f"'{self.__class__.__name__}': missing workload_type ")
      # Attribute of ``glt_graph`` holding the master port for each workload.
      port_attrs = {
        'train': 'train_loader_master_port',
        'validate': 'val_loader_master_port',
        'test': 'test_loader_master_port',
      }
      port_attr = port_attrs.get(workload_type)
      if port_attr is None:
        # Previously an unknown workload silently kept the caller's port
        # (possibly None) and produced worker_key 'None'.
        raise ValueError(f"'{self.__class__.__name__}': unsupported "
                         f"workload_type '{workload_type}', expected one of "
                         f"{list(port_attrs)}")
      master_addr = glt_graph.master_addr
      master_port = getattr(glt_graph, port_attr)
      worker_key = str(master_port)

    super().__init__(num_workers, worker_devices, worker_concurrency,
                     master_addr, master_port, num_rpc_threads, rpc_timeout)
    if server_rank is not None:
      self.server_rank = server_rank
    else:
      self.server_rank = assign_server_by_order()

    # Total number of in-flight sampling tasks across all remote workers.
    self.buffer_capacity = self.num_workers * self.worker_concurrency
    if buffer_size is None:
      # Default to 64MB of server-side buffer per worker.
      self.buffer_size = f'{self.num_workers * 64}MB'
    else:
      self.buffer_size = buffer_size

    self.prefetch_size = prefetch_size
    if self.prefetch_size > self.buffer_capacity:
      raise ValueError(f"'{self.__class__.__name__}': the prefetch count "
                       f"{self.prefetch_size} exceeds the buffer capacity "
                       f"{self.buffer_capacity}")
    self.worker_key = worker_key
    self.use_all2all = use_all2all


# Any of the concrete sampling-worker option classes defined in this module.
AllDistSamplingWorkerOptions = Union[
  CollocatedDistSamplingWorkerOptions,
  MpDistSamplingWorkerOptions,
  RemoteDistSamplingWorkerOptions,
]
