# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, eithPer express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from multiprocessing.reduction import ForkingPickler
from typing import Dict, List, Optional, Union, Literal, Tuple, Callable

import torch

from ..data import Dataset, Graph, Feature, DeviceGroup, vineyard_utils
from ..partition import (
  load_partition, cat_feature_cache,
  PartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)
from ..typing import (NodeType, EdgeType, NodeLabel, NodeIndex)
from ..utils import share_memory, default_id_filter, default_id_select


class DistDataset(Dataset):
  r""" Graph and feature dataset with distributed partition info.

  """
  def __init__(
    self,
    num_partitions: int = 1,
    partition_idx: int = 0,
    graph_partition: Union[Graph, Dict[EdgeType, Graph]] = None,
    node_feature_partition: Union[Feature, Dict[NodeType, Feature]] = None,
    edge_feature_partition: Union[Feature, Dict[EdgeType, Feature]] = None,
    whole_node_labels: NodeLabel = None,
    node_pb: Union[PartitionBook, HeteroNodePartitionDict] = None,
    edge_pb: Union[PartitionBook, HeteroEdgePartitionDict] = None,
    node_feat_pb: Union[PartitionBook, HeteroNodePartitionDict] = None,
    edge_feat_pb: Union[PartitionBook, HeteroEdgePartitionDict] = None,
    edge_dir: Literal['in', 'out'] = 'out',
    graph_caching: bool = False,
    node_split: Tuple[NodeIndex, NodeIndex, NodeIndex] = None,
    id_filter: Callable = default_id_filter,
    id_select: Callable = default_id_select
  ):
    super().__init__(
      graph_partition,
      node_feature_partition,
      edge_feature_partition,
      whole_node_labels,
      edge_dir,
      node_split,
    )
    self.id_filter = id_filter
    self.id_select = id_select
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx
    self.graph_caching = graph_caching

    self.node_pb = node_pb
    self.edge_pb = edge_pb

    # As the loaded feature partition may be concatenated with its cached
    # features and the partition book for features will be modified, thus we
    # need to distinguish them with the original graph partition books.
    #
    # If the `node_feat_pb` or `edge_feat_pb` is not provided, the `node_pb`
    # or `edge_pb` will be used instead for feature lookups.
    self._node_feat_pb = node_feat_pb
    self._edge_feat_pb = edge_feat_pb

    if self.graph is not None:
      assert self.node_pb is not None
    if self.node_features is not None:
      assert self.node_pb is not None or self._node_feat_pb is not None
    if self.edge_features is not None:
      assert self.edge_pb is not None or self._edge_feat_pb is not None

  def load(
    self,
    root_dir: str,
    partition_idx: int,
    graph_mode: str = 'ZERO_COPY',
    input_layout: Literal['COO', 'CSR', 'CSC'] = 'COO',
    feature_with_gpu: bool = True,
    graph_caching: bool = False,
    device_group_list: Optional[List[DeviceGroup]] = None,
    whole_node_label_file: Union[str, Dict[NodeType, str]] = None,
    device: Optional[int] = None
  ):
    r""" Load a certain dataset partition from partitioned files and create
    in-memory objects (``Graph``, ``Feature`` or ``torch.Tensor``).

    Args:
      root_dir (str): The directory path to load the graph and feature
        partition data.
      partition_idx (int): Partition idx to load.
      graph_mode (str): Mode for creating graphlearn_torch's ``Graph``, including
        ``CPU``, ``ZERO_COPY`` or ``CUDA``. (default: ``ZERO_COPY``)
      input_layout (str): layout of the input graph, including ``CSR``, ``CSC`` 
        or ``COO``. (default: ``COO``)
      feature_with_gpu (bool): A Boolean value indicating whether the created
        ``Feature`` objects of node/edge features use ``UnifiedTensor``.
        If True, it means ``Feature`` consists of ``UnifiedTensor``, otherwise
        ``Feature`` is a PyTorch CPU Tensor, the ``device_group_list`` and
        ``device`` will be invliad. (default: ``True``)
      graph_caching (bool): A Boolean value indicating whether to load the full
        graph totoploy instead of partitioned one.
      device_group_list (List[DeviceGroup], optional): A list of device groups
        used for feature lookups, the GPU part of feature data will be
        replicated on each device group in this list during the initialization.
        GPUs with peer-to-peer access to each other should be set in the same
        device group properly.  (default: ``None``)
      whole_node_label_file (str): The path to the whole node labels which are
        not partitioned. (default: ``None``)
      device: The target cuda device rank used for graph operations when graph
        mode is not "CPU" and feature lookups when the GPU part is not None.
        (default: ``None``)
    """
    (
      self.num_partitions,
      self.partition_idx,
      graph_data,
      node_feat_data,
      edge_feat_data,
      self.node_pb,
      self.edge_pb
    ) = load_partition(root_dir, partition_idx, graph_caching)

    # init graph partition
    if isinstance(graph_data, dict):
      # heterogeneous.
      edge_index, edge_ids, edge_weights = {}, {}, {}
      for k, v in graph_data.items():
        edge_index[k] = v.edge_index
        edge_ids[k] = v.eids
        edge_weights[k] = v.weights
    else:
      # homogeneous.
      edge_index = graph_data.edge_index
      edge_ids = graph_data.eids
      edge_weights = graph_data.weights
    self.init_graph(edge_index, edge_ids, edge_weights, layout=input_layout,
                    graph_mode=graph_mode, device=device)
    self.graph_caching = graph_caching

    # load node feature partition
    if node_feat_data is not None:
      node_cache_ratio, node_feat, node_feat_id2idx, node_feat_pb = \
        _cat_feature_cache(partition_idx, node_feat_data, self.node_pb)
      self.init_node_features(
        node_feat, node_feat_id2idx, None, node_cache_ratio,
        device_group_list, device, feature_with_gpu, dtype=None
      )
      self._node_feat_pb = node_feat_pb

    # load edge feature partition
    if edge_feat_data is not None:
      edge_cache_ratio, edge_feat, edge_feat_id2idx, edge_feat_pb = \
        _cat_feature_cache(partition_idx, edge_feat_data, self.edge_pb)
      self.init_edge_features(
        edge_feat, edge_feat_id2idx, edge_cache_ratio,
        device_group_list, device, feature_with_gpu, dtype=None
      )
      self._edge_feat_pb = edge_feat_pb

    # load whole node labels
    if whole_node_label_file is not None:
      if isinstance(whole_node_label_file, dict):
        whole_node_labels = {}
        for ntype, file in whole_node_label_file.items():
          whole_node_labels[ntype] = torch.load(file)
      else:
        whole_node_labels = torch.load(whole_node_label_file)
      self.init_node_labels(whole_node_labels)

  def random_node_split(
    self,
    num_val: Union[float, int],
    num_test: Union[float, int],
  ):
    r"""Performs a node-level random split by adding :obj:`train_idx`,
    :obj:`val_idx` and :obj:`test_idx` attributes to the
    :class:`~graphlearn_torch.distributed.DistDataset` object. All nodes except 
    those in the validation and test sets will be used for training.
    
    Args:
      num_val (int or float): The number of validation samples.
        If float, it represents the ratio of samples to include in the
        validation set.
      num_test (int or float): The number of test samples in case
        of :obj:`"train_rest"` and :obj:`"random"` split. If float, it
        represents the ratio of samples to include in the test set.
    """

    if isinstance(self.node_labels, dict):
      train_idx = {}
      val_idx = {}
      test_idx = {}
  
      for node_type, _ in self.node_labels.items():
        indices = self.id_filter(self.node_pb[node_type], self.partition_idx)
        train_idx[node_type], val_idx[node_type], test_idx[node_type] = random_split(indices, num_val, num_test)
    else:
      indices = self.id_filter(self.node_pb, self.partition_idx)
      train_idx, val_idx, test_idx = random_split(indices, num_val, num_test)
    self.init_node_split((train_idx, val_idx, test_idx))

  def load_vineyard(
    self,
    vineyard_id: str,
    vineyard_socket: str,
    edges: List[EdgeType],
    edge_weights: Dict[EdgeType, str] = None,
    node_features: Dict[NodeType, List[str]] = None,
    edge_features: Dict[EdgeType, List[str]] = None,
    node_labels: Dict[NodeType, str] = None,
  ):
    super().load_vineyard(vineyard_id=vineyard_id, vineyard_socket=vineyard_socket, 
                          edges=edges, edge_weights=edge_weights, node_features=node_features, 
                          edge_features=edge_features, node_labels=node_labels)
    if isinstance(self.graph, dict):
      # hetero
      self._node_feat_pb = {}
      if node_features:
        for ntype, _ in self.node_features.items():
          if self.node_pb is not None:
            self._node_feat_pb[ntype] = self.node_pb[ntype]
          else:
            self._node_feat_pb[ntype] = None
    else:
      # homo
      if node_features:
        self._node_feat_pb = self.node_pb
    
    self.id_select = vineyard_utils.v6d_id_select
    self.id_filter = vineyard_utils.v6d_id_filter

  def share_ipc(self):
    super().share_ipc()
    self.node_pb = share_memory(self.node_pb)
    self.edge_pb = share_memory(self.edge_pb)
    self._node_feat_pb = share_memory(self._node_feat_pb)
    self._edge_feat_pb = share_memory(self._edge_feat_pb)
    ipc_hanlde = (
      self.num_partitions, self.partition_idx,
      self.graph, self.node_features, self.edge_features, self.node_labels,
      self.node_pb, self.edge_pb, self._node_feat_pb, self._edge_feat_pb, 
      self.edge_dir, self.graph_caching,
      (self.train_idx, self.val_idx, self.test_idx)
    )
    return ipc_hanlde

  @classmethod
  def from_ipc_handle(cls, ipc_handle):
    return cls(*ipc_handle)

  @property
  def node_feat_pb(self):
    if self._node_feat_pb is None:
      return self.node_pb
    return self._node_feat_pb

  @property
  def edge_feat_pb(self):
    if self._edge_feat_pb is None:
      return self.edge_pb
    return self._edge_feat_pb


def _cat_feature_cache(partition_idx, raw_feat_data, raw_feat_pb):
  r""" Cat a feature partition with its cached features.
  """
  if isinstance(raw_feat_data, dict):
    # heterogeneous.
    cache_ratio, feat_data, feat_id2idx, feat_pb = {}, {}, {}, {}
    for graph_type, raw_feat in raw_feat_data.items():
      cache_ratio[graph_type], feat_data[graph_type], \
      feat_id2idx[graph_type], feat_pb[graph_type] = \
        cat_feature_cache(partition_idx, raw_feat, raw_feat_pb[graph_type])
  else:
    # homogeneous.
    cache_ratio, feat_data, feat_id2idx, feat_pb = \
      cat_feature_cache(partition_idx, raw_feat_data, raw_feat_pb)
  return cache_ratio, feat_data, feat_id2idx, feat_pb


## Pickling Registration

def rebuild_dist_dataset(ipc_handle):
  ds = DistDataset.from_ipc_handle(ipc_handle)
  return ds

def reduce_dist_dataset(dataset: DistDataset):
  ipc_handle = dataset.share_ipc()
  return (rebuild_dist_dataset, (ipc_handle, ))

ForkingPickler.register(DistDataset, reduce_dist_dataset)

def random_split(
  indices: torch.Tensor,
  num_val: Union[float, int],
  num_test: Union[float, int],
):  
  num_total = indices.shape[0]
  num_val = round(num_total * num_val) if isinstance(num_val, float) else num_val
  num_test = round(num_total * num_test) if isinstance(num_test, float) else num_test
  perm = torch.randperm(num_total)
  val_idx = indices[perm[:num_val]].clone()
  test_idx = indices[perm[num_val:num_val + num_test]].clone()
  train_idx = indices[perm[num_val + num_test:]].clone()
  return train_idx, val_idx, test_idx