# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from typing import Tuple, Union, Optional, Literal

import torch

from .transform import to_data, to_hetero_data

from ..utils import convert_to_tensor
from ..data import Dataset
from ..sampler import (
  BaseSampler,
  EdgeSamplerInput,
  NegativeSampling,
  SamplerOutput,
  HeteroSamplerOutput
)
from ..typing import InputEdges, reverse_edge_type


class LinkLoader(object):
  r"""A data loader that performs mini-batch sampling from link information,
  using a generic :class:`~graphlearn_torch.sampler.BaseSampler`
  implementation that defines a
  :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges` function and
  is supported on the provided input :obj:`data` object.

  .. note::
    Negative sampling for triplet case is currently implemented in an
    approximate way, *i.e.* negative edges may contain false negatives.

  Args:
    data (Dataset): The `graphlearn_torch.data.Dataset` object.
    link_sampler (graphlearn_torch.sampler.BaseSampler): The sampler
      implementation to be used with this loader.
      Needs to implement
      :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.
      The sampler implementation must be compatible with the input
      :obj:`data` object.
    edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
      The edge indices, holding source and destination nodes to start
      sampling from.
      If set to :obj:`None`, all edges will be considered.
      In heterogeneous graphs, needs to be passed as a tuple that holds
      the edge type and corresponding edge indices.
      (default: :obj:`None`)
    edge_label (Tensor, optional): The labels of edge indices from which to
      start sampling from. Must be the same length as
      the :obj:`edge_label_index`. (default: :obj:`None`)
    neg_sampling (NegativeSampling, optional): The negative sampling
      configuration.
      For negative sampling mode :obj:`"binary"`, samples can be accessed
      via the attributes :obj:`edge_label_index` and :obj:`edge_label` in
      the respective edge type of the returned mini-batch.
      In case :obj:`edge_label` does not exist, it will be automatically
      created and represents a binary classification task (:obj:`0` =
      negative edge, :obj:`1` = positive edge).
      In case :obj:`edge_label` does exist, it has to be a categorical
      label from :obj:`0` to :obj:`num_classes - 1`.
      After negative sampling, label :obj:`0` represents negative edges,
      and labels :obj:`1` to :obj:`num_classes` represent the labels of
      positive edges.
      Note that returned labels are of type :obj:`torch.float` for binary
      classification (to facilitate the ease-of-use of
      :meth:`F.binary_cross_entropy`) and of type
      :obj:`torch.long` for multi-class classification (to facilitate the
      ease-of-use of :meth:`F.cross_entropy`).
      For negative sampling mode :obj:`"triplet"`, samples can be
      accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index`
      and :obj:`dst_neg_index` in the respective node types of the
      returned mini-batch.
      :obj:`edge_label` needs to be :obj:`None` for :obj:`"triplet"`
      negative sampling mode.
      If set to :obj:`None`, no negative sampling strategy is applied.
      (default: :obj:`None`)
    device (torch.device, optional): The device the seed indices of each
      mini-batch are moved to before sampling.
      If set to :obj:`None`, the CPU is used.
      (default: :obj:`torch.device('cuda:0')`)
    edge_dir (str): The edge direction for sampling.
      Can be either :obj:`"out"` or :obj:`"in"`.
      (default: :obj:`"out"`)
    **kwargs (optional): Additional arguments of
      :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
      :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.

  Raises:
    ValueError: If :obj:`edge_label` is given together with
      :obj:`"triplet"` negative sampling.
  """
  def __init__(
    self,
    data: Dataset,
    link_sampler: BaseSampler,
    edge_label_index: InputEdges = None,
    edge_label: Optional[torch.Tensor] = None,
    neg_sampling: Optional[NegativeSampling] = None,
    device: Optional[torch.device] = torch.device('cuda:0'),
    edge_dir: Literal['out', 'in'] = 'out',
    **kwargs,
  ):
    # Get edge type (or `None` for homogeneous graphs):
    input_type, edge_label_index = get_edge_label_index(
        data, edge_label_index)
    self.data = data
    self.link_sampler = link_sampler
    self.neg_sampling = NegativeSampling.cast(neg_sampling)
    # Honour the documented contract: `None` means "use the CPU".
    self.device = torch.device('cpu') if device is None else device
    self.edge_dir = edge_dir

    if (self.neg_sampling is not None and self.neg_sampling.is_binary()
        and edge_label is not None and edge_label.min() == 0):
      # Increment labels such that `zero` now denotes "negative".
      edge_label = edge_label + 1

    if (self.neg_sampling is not None and self.neg_sampling.is_triplet()
        and edge_label is not None):
      raise ValueError("'edge_label' needs to be undefined for "
                        "'triplet'-based negative sampling. Please use "
                        "`src_index`, `dst_pos_index` and "
                        "`dst_neg_index` of the returned mini-batch "
                        "instead to differentiate between positive and "
                        "negative samples.")

    # Clone the endpoints so the sampler input does not alias the tensors
    # handed in by the caller (or the ones returned by the graph topology).
    self.input_data = EdgeSamplerInput(
      row=edge_label_index[0].clone(),
      col=edge_label_index[1].clone(),
      label=edge_label,
      input_type=input_type,
      neg_sampling=self.neg_sampling,
    )

    # The inner DataLoader only batches (and optionally shuffles) positions
    # into `input_data`; the seed edges are gathered in `__next__`.
    input_index = range(len(edge_label_index[0]))
    self._seed_loader = torch.utils.data.DataLoader(input_index, **kwargs)
    # Created by `__iter__`, or lazily by the first `__next__` call.
    self._seeds_iter = None

  def __iter__(self):
    r"""Restart iteration over the seed edges and return the loader itself."""
    self._seeds_iter = iter(self._seed_loader)
    return self

  def __next__(self):
    r"""Sample one mini-batch around the next batch of seed edges.

    Raises:
      StopIteration: When the underlying seed loader is exhausted.
    """
    if self._seeds_iter is None:
      # Allow `next(loader)` without a preceding `iter(loader)`.
      self._seeds_iter = iter(self._seed_loader)
    # Use the public iterator protocol rather than the private
    # `_next_data()` hook of the DataLoader iterator.
    seeds = next(self._seeds_iter).to(self.device)
    sampler_out = self.link_sampler.sample_from_edges(self.input_data[seeds])
    return self._collate_fn(sampler_out)

  def _collate_fn(self, sampler_out: Union[SamplerOutput, HeteroSamplerOutput]):
    r"""Format the sampler output into a Data/HeteroData object.

    For the out-edge sampling scheme (i.e. the direction of edges in
    the output is inverse to the original graph), we put the reversed
    edge_label_index into the (dst, rev_to, src) subgraph for
    HeteroSamplerOutput and (dst, to, src) for SamplerOutput.
    However, for the in-edge sampling scheme (i.e. the direction of edges
    in the output is the same as the original graph), we do not need to
    reverse the edge type of the sampler_out.

    Args:
      sampler_out: The result of `sample_from_edges`; a `SamplerOutput` is
        treated as homogeneous, anything else as heterogeneous.

    Returns:
      The object built by `to_data` (homogeneous) or `to_hetero_data`
      (heterogeneous) with node and edge features attached.

    Raises:
      ValueError: If edge features have to be looked up for a heterogeneous
        output and `edge_dir` is neither 'out' nor 'in'.
    """
    if isinstance(sampler_out, SamplerOutput):
      x = self.data.node_features[sampler_out.node]
      if self.data.edge_features is not None and sampler_out.edge is not None:
        edge_attr = self.data.edge_features[sampler_out.edge]
      else:
        edge_attr = None
      return to_data(sampler_out,
                     node_feats=x,
                     edge_feats=edge_attr,
                    )

    # Heterogeneous output: gather features per node type / edge type.
    x_dict = {
      ntype: self.data.get_node_feature(ntype)[ids.to(torch.int64)]
      for ntype, ids in sampler_out.node.items()
    }
    edge_attr_dict = {}
    if sampler_out.edge is not None:
      for etype, eids in sampler_out.edge.items():
        if self.edge_dir == 'out':
          # Sampled edge types are reversed w.r.t. the stored graph, so the
          # features live under the reversed edge type.
          feat_etype = reverse_edge_type(etype)
        elif self.edge_dir == 'in':
          feat_etype = etype
        else:
          # Previously this fell through and failed on an unbound local.
          raise ValueError(
            f"Unsupported 'edge_dir': {self.edge_dir!r} "
            "(expected 'out' or 'in')."
          )
        efeat = self.data.get_edge_feature(feat_etype)
        if efeat is not None:
          edge_attr_dict[etype] = efeat[eids.to(torch.int64)]

    return to_hetero_data(sampler_out,
                          node_feat_dict=x_dict,
                          edge_feat_dict=edge_attr_dict,
                          edge_dir=self.edge_dir,
                         )

  def __repr__(self) -> str:
    return f'{self.__class__.__name__}()'


def get_edge_label_index(
  data: Dataset,
  edge_label_index: InputEdges
) -> Tuple[Optional[tuple], Union[torch.Tensor, tuple]]:
  r"""Resolve the user-provided seed edges into ``(edge_type, edge_index)``.

  Accepted forms of :obj:`edge_label_index`:

  * :obj:`None`: all edges of the graph returned by
    ``data.get_graph(None)`` are used; the edge type is :obj:`None`.
  * A non-tuple value (e.g. a tensor): passed through
    ``convert_to_tensor``; the edge type is :obj:`None`.
  * A tuple whose first element is a :obj:`str`: interpreted as an edge
    type; all edges of that type are used.
  * A 2-tuple ``(edge_type, edge_index)``: passed through
    ``convert_to_tensor``; if the resulting edge index is :obj:`None`, all
    edges of ``edge_type`` are used.

  Args:
    data (Dataset): The dataset whose graph topology supplies the edges
      when no explicit edge index is given.
    edge_label_index: The seed edges in one of the forms listed above.

  Returns:
    A pair ``(edge_type, edge_index)``. ``edge_index`` is a ``(row, col)``
    tuple when it was read from the graph topology, otherwise whatever
    ``convert_to_tensor`` produced for the explicit input.

  Raises:
    AssertionError: If a tuple that is not an edge type does not have
      exactly two elements.
  """
  def _full_edge_index(edge_type):
    # Only the two endpoint arrays of the COO result are needed here; the
    # remaining two returned values are ignored.
    row, col, _, _ = data.get_graph(edge_type).topo.to_coo()
    return (row, col)

  if edge_label_index is None:
    return None, _full_edge_index(None)

  if not isinstance(edge_label_index, tuple):
    return None, convert_to_tensor(edge_label_index)

  # The truthiness guard lets an empty tuple reach the arity check below
  # instead of failing with an IndexError on `[0]`.
  if edge_label_index and isinstance(edge_label_index[0], str):
    return edge_label_index, _full_edge_index(edge_label_index)

  assert len(edge_label_index) == 2, (
    "'edge_label_index' must be an edge type or a "
    "(edge_type, edge_index) pair."
  )
  edge_type, edge_label_index = convert_to_tensor(edge_label_index)

  if edge_label_index is None:
    return edge_type, _full_edge_index(edge_type)

  return edge_type, edge_label_index