# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os

import tensorflow as tf

from easy_rec.python.builders import loss_builder
from easy_rec.python.model.easy_rec_model import EasyRecModel
from easy_rec.python.protos.loss_pb2 import LossType
from easy_rec.python.protos.simi_pb2 import Similarity

# On TensorFlow 2.x, rebind `tf` to the v1-compatible API so the graph-mode
# symbols used in this module (e.g. tf.get_variable, tf.to_float, tf.diag)
# resolve the same way as on TensorFlow 1.x.
# NOTE(review): this is a lexicographic string comparison of the version, not
# a numeric one; it is correct for 1.x/2.x version strings.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
# NOTE(review): `losses` is not referenced elsewhere in this module; it may be
# imported by other modules — confirm before removing.
losses = tf.losses


class MatchModel(EasyRecModel):
  """Base class for two-tower (user tower / item tower) matching models.

  The configured loss type selects the similarity layout:

  * point-wise (CLASSIFICATION, L2_LOSS): one score per (user, item) row;
  * list-wise (any other loss type; the loss graph only accepts
    SOFTMAX_CROSS_ENTROPY): outside PREDICT mode each user row is scored
    against every item embedding. Columns `[:batch_size]` are the in-batch
    items (the diagonal being the row's own positive item) and any further
    columns are extra negatives.
  """

  def __init__(self,
               model_config,
               feature_configs,
               features,
               labels=None,
               is_training=False):
    """Read the loss setup and the optional item id feature.

    Args:
      model_config: model config; `loss_type`, `num_class` and the `model`
        oneof are read here.
      feature_configs: feature configs, forwarded to `EasyRecModel`.
      features: dict of input feature tensors; `features[item_id]` is looked
        up when the sub model config sets a non-empty `item_id`.
      labels: optional labels, forwarded to `EasyRecModel`.
      is_training: forwarded to `EasyRecModel`.
    """
    super(MatchModel, self).__init__(model_config, feature_configs, features,
                                     labels, is_training)
    self._loss_type = self._model_config.loss_type
    self._num_class = self._model_config.num_class

    if self._loss_type == LossType.CLASSIFICATION:
      # Point-wise classification produces a single logit per row.
      assert self._num_class == 1, (
          'num_class must be 1 for CLASSIFICATION loss, got %s' %
          self._num_class)

    self._is_point_wise = self._loss_type in [
        LossType.CLASSIFICATION, LossType.L2_LOSS
    ]
    if self._is_point_wise:
      logging.info('Use point wise dssm.')
    else:
      logging.info('Use list wise dssm.')

    self._item_ids = None
    cls_mem = self._model_config.WhichOneof('model')
    # Check the oneof before resolving it: getattr(config, None) would fail
    # with an opaque TypeError instead of this message.
    assert cls_mem is not None, 'sub_model_config undefined: model_cls = %s' % cls_mem
    sub_model_config = getattr(self._model_config, cls_mem)
    if getattr(sub_model_config, 'item_id', '') != '':
      logging.info('item_id feature is: %s', sub_model_config.item_id)
      # Used by _mask_in_batch so that in-batch items sharing the positive's
      # id are not treated as negatives.
      self._item_ids = features[sub_model_config.item_id]

  @staticmethod
  def _is_predict_mode():
    """Return True when the graph is being built in PREDICT mode.

    The mode is read from the 'tf.estimator.mode' environment variable. If
    it is unset, the graph is treated as a train/eval build instead of
    raising KeyError. NOTE(review): the variable is presumably set by the
    estimator's model_fn; confirm against the caller.
    """
    return os.environ.get('tf.estimator.mode') == tf.estimator.ModeKeys.PREDICT

  def _mask_in_batch(self, logits):
    """Suppress unwanted in-batch negatives in list-wise logits.

    `logits[:, :batch_size]` holds user-vs-in-batch-item scores whose
    diagonal is each row's positive item; any remaining columns are left
    untouched. Masked entries are shifted by -1e32 so that a softmax assigns
    them ~zero probability.

    * If `ignore_in_batch_neg_sam` is set, every off-diagonal in-batch entry
      is masked, so only the extra columns act as negatives.
    * Otherwise, if an item id feature is configured, off-diagonal entries
      whose item id equals the row's positive item id are masked.
    * Otherwise the logits are returned unchanged.

    In PREDICT mode `_list_wise_sim` yields a single column per row, so
    there are no in-batch negatives. Applying the [batch, batch] mask would
    broadcast the logits to the wrong shape, so they are returned unchanged.

    Args:
      logits: list-wise similarity logits.

    Returns:
      logits with the masked entries shifted by -1e32.
    """
    if self._is_predict_mode():
      return logits
    ignore_in_batch = getattr(self._model_config, 'ignore_in_batch_neg_sam',
                              False)
    if not ignore_in_batch and self._item_ids is None:
      return logits

    batch_size = tf.shape(logits)[0]
    eye = tf.diag(tf.ones([batch_size], dtype=tf.float32))
    if ignore_in_batch:
      in_batch = logits[:, :batch_size] - (1 - eye) * 1e32
      return tf.concat([in_batch, logits[:, batch_size:]], axis=1)

    # 1 where another in-batch item has the same id as the row's positive
    # (the diagonal itself is removed by subtracting the identity).
    mask_in_batch_neg = tf.to_float(
        tf.equal(self._item_ids[None, :batch_size],
                 self._item_ids[:batch_size, None])) - eye
    tf.summary.scalar('in_batch_neg_conflict',
                      tf.reduce_sum(mask_in_batch_neg))
    return tf.concat([
        logits[:, :batch_size] - mask_in_batch_neg * 1e32,
        logits[:, batch_size:]],
        axis=1)  # yapf: disable

  def _list_wise_sim(self, user_emb, item_emb):
    """Compute list-wise user/item similarity logits.

    Outside PREDICT mode every user row is scored against every (non hard
    negative) item embedding via `user_emb @ item_emb^T`. In PREDICT mode
    each user is scored only against the item in the same row, which gives
    a single column.

    If the features contain `hard_neg_indices`, the last
    `tf.shape(hard_neg_indices)[0]` rows of `item_emb` are hard negatives.
    Each index row is used as (user row, negative slot). Their scores are
    scattered into a [batch_size, max_slot + 1] matrix, and empty slots are
    padded with -1e32. That matrix is appended after the simple similarity
    columns.

    Args:
      user_emb: user tower embeddings.
      item_emb: item tower embeddings (positives, then any extra negatives,
        then any hard negatives).

    Returns:
      the similarity logits tensor.
    """
    hard_neg_indices = self._feature_dict.get('hard_neg_indices', None)

    if hard_neg_indices is not None:
      logging.info('With hard negative examples')
      noclk_size = tf.shape(hard_neg_indices)[0]
      simple_item_emb, hard_neg_item_emb = tf.split(
          item_emb, [-1, noclk_size], axis=0)
    else:
      simple_item_emb = item_emb

    if self._is_predict_mode():
      simple_user_item_sim = tf.reduce_sum(
          tf.multiply(user_emb, simple_item_emb), axis=1, keepdims=True)
    else:
      simple_user_item_sim = tf.matmul(user_emb, tf.transpose(simple_item_emb))

    if hard_neg_indices is None:
      return simple_user_item_sim

    batch_size = tf.shape(user_emb)[0]
    user_emb_expand = tf.gather(user_emb, hard_neg_indices[:, 0])
    hard_neg_user_item_sim = tf.reduce_sum(
        tf.multiply(user_emb_expand, hard_neg_item_emb), axis=1)
    max_num_neg = tf.reduce_max(hard_neg_indices[:, 1]) + 1
    hard_neg_shape = tf.stack([tf.to_int64(batch_size), max_num_neg])
    hard_neg_sim = tf.scatter_nd(hard_neg_indices, hard_neg_user_item_sim,
                                 hard_neg_shape)
    hard_neg_mask = tf.scatter_nd(
        hard_neg_indices,
        tf.ones_like(hard_neg_user_item_sim, dtype=tf.float32),
        shape=hard_neg_shape)
    # set tail positions to -1e32, so that after exp(x), will be zero
    hard_neg_user_item_sim = hard_neg_sim - (1 - hard_neg_mask) * 1e32
    return tf.concat([simple_user_item_sim, hard_neg_user_item_sim], axis=1)

  def _point_wise_sim(self, user_emb, item_emb):
    """Return the row-wise dot product of user and item embeddings.

    The result keeps a trailing dimension of size 1.
    """
    user_item_sim = tf.reduce_sum(
        tf.multiply(user_emb, item_emb), axis=1, keepdims=True)
    return user_item_sim

  def sim(self, user_emb, item_emb):
    """Name the tower outputs and compute point-wise or list-wise similarity.

    Explicit names of these nodes are necessary for some online recall
    systems like BasicEngine to split up the predicting graph into different
    clusters.
    """
    user_emb = tf.identity(user_emb, 'user_tower_emb')
    item_emb = tf.identity(item_emb, 'item_tower_emb')

    if self._is_point_wise:
      return self._point_wise_sim(user_emb, item_emb)
    else:
      return self._list_wise_sim(user_emb, item_emb)

  def norm(self, fea):
    """L2-normalize `fea` along its last axis."""
    fea_norm = tf.nn.l2_normalize(fea, axis=-1)
    return fea_norm

  def build_predict_graph(self):
    """Build the backbone-based two-tower prediction graph.

    The user and item tower embeddings are read from `self.backbone` at the
    indices configured in `model_params`. For COSINE similarity they are
    L2-normalized and the score is divided by `temperature`. The score is
    optionally rescaled by learnable `sim_w` / `sim_b`. Finally the
    prediction dict is filled according to the loss type.

    Returns:
      dict of prediction tensors: 'logits' / 'probs' (CLASSIFICATION and
      SOFTMAX_CROSS_ENTROPY) or 'y' (other losses). It also contains
      'user_tower_emb', 'item_tower_emb' and their comma-joined string forms
      'user_emb' / 'item_emb'.

    Raises:
      NotImplementedError: if the model has no backbone network.
    """
    if not self.has_backbone:
      raise NotImplementedError(
          'method `build_predict_graph` must be implemented when you donot use backbone network'
      )
    model = self._model_config.WhichOneof('model')
    assert model == 'model_params', '`model_params` must be configured'
    model_params = self._model_config.model_params
    for out in model_params.outputs:
      self._outputs.append(out)

    output = self.backbone

    user_tower_emb = output[model_params.user_tower_idx_in_output]
    item_tower_emb = output[model_params.item_tower_idx_in_output]

    if model_params.simi_func == Similarity.COSINE:
      user_tower_emb = self.norm(user_tower_emb)
      item_tower_emb = self.norm(item_tower_emb)
      temperature = model_params.temperature
    else:
      temperature = 1.0

    user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature

    if model_params.scale_simi:
      sim_w = tf.get_variable(
          'sim_w',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.ones_initializer())
      sim_b = tf.get_variable(
          'sim_b',
          dtype=tf.float32,
          shape=(1,),
          initializer=tf.zeros_initializer())
      # abs() keeps the scale non-negative so the similarity order is kept.
      y_pred = user_item_sim * tf.abs(sim_w) + sim_b
    else:
      y_pred = user_item_sim

    if self._is_point_wise:
      y_pred = tf.reshape(y_pred, [-1])

    if self._loss_type == LossType.CLASSIFICATION:
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.sigmoid(y_pred)
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      y_pred = self._mask_in_batch(y_pred)
      self._prediction_dict['logits'] = y_pred
      self._prediction_dict['probs'] = tf.nn.softmax(y_pred)
    else:
      self._prediction_dict['y'] = y_pred

    self._prediction_dict['user_tower_emb'] = user_tower_emb
    self._prediction_dict['item_tower_emb'] = item_tower_emb
    self._prediction_dict['user_emb'] = tf.reduce_join(
        tf.as_string(user_tower_emb), axis=-1, separator=',')
    self._prediction_dict['item_emb'] = tf.reduce_join(
        tf.as_string(item_tower_emb), axis=-1, separator=',')

    return self._prediction_dict

  def build_loss_graph(self):
    """Dispatch to the point-wise or list-wise loss graph."""
    if self._is_point_wise:
      return self._build_point_wise_loss_graph()
    else:
      return self._build_list_wise_loss_graph()

  def _build_list_wise_loss_graph(self):
    """Build the softmax losses for the list-wise setup.

    The losses are:

    * 'cross_entropy_loss': sample-weighted mean of -log(p), where p is the
      softmax probability at the diagonal (each row's own positive item),
      normalized by the mean sample weight.
    * 'reg_pos_loss': weighted mean of relu(-<user, positive item>), which
      penalizes negative user/positive similarity.
    * 'amm_loss_u' / 'amm_loss_i': the AMM losses for the DAT model. These
      are added only when all four augmented vectors are in the prediction
      dict.

    Raises:
      ValueError: if the loss type is not SOFTMAX_CROSS_ENTROPY.
    """
    if self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      batch_size = tf.shape(self._prediction_dict['probs'])[0]
      indices = tf.range(batch_size)
      indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
      hit_prob = tf.gather_nd(
          self._prediction_dict['probs'][:batch_size, :batch_size], indices)

      sample_weights = tf.cast(tf.squeeze(self._sample_weight), tf.float32)
      self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(
          tf.log(hit_prob + 1e-12) *
          sample_weights) / tf.reduce_mean(sample_weights)

      logging.info('softmax cross entropy loss is used')

      user_features = self._prediction_dict['user_tower_emb']
      pos_item_emb = self._prediction_dict['item_tower_emb'][:batch_size]
      pos_simi = tf.reduce_sum(user_features * pos_item_emb, axis=1)
      # if pos_simi < 0, produce loss
      reg_pos_loss = tf.nn.relu(-pos_simi)
      self._loss_dict['reg_pos_loss'] = tf.reduce_mean(
          reg_pos_loss * sample_weights) / tf.reduce_mean(sample_weights)

      # the AMM loss for DAT model
      amm_keys = ('augmented_p_u', 'augmented_p_i', 'augmented_a_u',
                  'augmented_a_i')
      if all(k in self._prediction_dict for k in amm_keys):
        self._loss_dict[
            'amm_loss_u'] = self._model_config.amm_u_weight * tf.reduce_mean(
                tf.square(self._prediction_dict['augmented_a_u'] -
                          self._prediction_dict['augmented_p_i'][:batch_size]) *
                sample_weights) / tf.reduce_mean(sample_weights)
        self._loss_dict[
            'amm_loss_i'] = self._model_config.amm_i_weight * tf.reduce_mean(
                tf.square(self._prediction_dict['augmented_a_i'][:batch_size] -
                          self._prediction_dict['augmented_p_u']) *
                sample_weights) / tf.reduce_mean(sample_weights)

    else:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))
    return self._loss_dict

  def _build_point_wise_loss_graph(self):
    """Build the point-wise loss (classification or L2) plus any KD losses.

    The first label in `self._labels` is used as the target.

    Raises:
      ValueError: if the loss type is neither CLASSIFICATION nor L2_LOSS.
    """
    label = list(self._labels.values())[0]
    if self._loss_type == LossType.CLASSIFICATION:
      pred = self._prediction_dict['logits']
      loss_name = 'cross_entropy_loss'
    elif self._loss_type == LossType.L2_LOSS:
      pred = self._prediction_dict['y']
      loss_name = 'l2_loss'
    else:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))

    kwargs = {'loss_name': loss_name}
    self._loss_dict[loss_name] = loss_builder.build(
        self._loss_type,
        label=label,
        pred=pred,
        loss_weight=self._sample_weight,
        **kwargs)

    # build kd loss
    kd_loss_dict = loss_builder.build_kd_loss(self.kd, self._prediction_dict,
                                              self._labels, self._feature_dict)
    self._loss_dict.update(kd_loss_dict)
    return self._loss_dict

  def build_metric_graph(self, eval_config):
    """Dispatch to the point-wise or list-wise metric graph."""
    if self._is_point_wise:
      return self._build_point_wise_metric_graph(eval_config)
    else:
      return self._build_list_wise_metric_graph(eval_config)

  def _build_list_wise_metric_graph(self, eval_config):
    """Build the recall@k metrics for the list-wise setup.

    For every `recall_at_topk` metric three variants are produced:

    * 'recall@k': over all columns, with row i's label being column i;
    * 'recall_neg_sam@k': the positive score against the extra
      (non in-batch) columns only;
    * 'recall_in_batch@k': over the in-batch columns only.

    Raises:
      ValueError: for any metric other than `recall_at_topk`.
    """
    from easy_rec.python.core.easyrec_metrics import metrics_tf
    logits = self._prediction_dict['logits']
    batch_size = tf.shape(logits)[0]
    label = tf.cast(tf.range(batch_size), tf.int64)

    indices = tf.range(batch_size)
    indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
    pos_item_sim = tf.gather_nd(logits[:batch_size, :batch_size], indices)
    metric_dict = {}
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'recall_at_topk':
        topk = metric.recall_at_topk.topk
        metric_dict['recall@%d' % topk] = metrics_tf.recall_at_k(
            label, logits, topk)

        # Positive score in column 0 followed by the extra negative columns.
        logits_v2 = tf.concat([pos_item_sim[:, None], logits[:, batch_size:]],
                              axis=1)
        labels_v2 = tf.zeros_like(logits_v2[:, :1], dtype=tf.int64)
        metric_dict['recall_neg_sam@%d' % topk] = metrics_tf.recall_at_k(
            labels_v2, logits_v2, topk)

        metric_dict['recall_in_batch@%d' % topk] = metrics_tf.recall_at_k(
            label, logits[:, :batch_size], topk)
      else:
        raise ValueError('invalid metric type: %s' % str(metric))
    return metric_dict

  def _build_point_wise_metric_graph(self, eval_config):
    """Build AUC (classification) or MAE (L2) metrics for the point-wise setup.

    Raises:
      ValueError: for any metric other than `auc` or `mean_absolute_error`.
    """
    from easy_rec.python.core.easyrec_metrics import metrics_tf
    metric_dict = {}
    label = list(self._labels.values())[0]
    for metric in eval_config.metrics_set:
      if metric.WhichOneof('metric') == 'auc':
        assert self._loss_type == LossType.CLASSIFICATION
        metric_dict['auc'] = metrics_tf.auc(label,
                                            self._prediction_dict['probs'])
      elif metric.WhichOneof('metric') == 'mean_absolute_error':
        assert self._loss_type == LossType.L2_LOSS
        metric_dict['mean_absolute_error'] = metrics_tf.mean_absolute_error(
            tf.to_float(label), self._prediction_dict['y'])
      else:
        raise ValueError('invalid metric type: %s' % str(metric))
    return metric_dict

  def get_outputs(self):
    """Return the names of the prediction-dict entries to export.

    For SOFTMAX_CROSS_ENTROPY this also rewrites the prediction dict:
    'logits' loses its trailing axis and 'probs' becomes sigmoid(logits).
    This assumes the PREDICT-mode layout of a single similarity column per
    row.

    NOTE(review): the rewrite is not idempotent, so calling this twice
    squeezes the logits twice. It presumably runs once at export time;
    confirm against the caller.

    Raises:
      NotImplementedError: if the model has no backbone network.
      ValueError: for an unsupported loss type.
    """
    if not self.has_backbone:
      raise NotImplementedError(
          'could not call get_outputs on abstract class MatchModel')
    if self._loss_type == LossType.CLASSIFICATION:
      return [
          'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
          'item_tower_emb'
      ]
    elif self._loss_type == LossType.SOFTMAX_CROSS_ENTROPY:
      self._prediction_dict['logits'] = tf.squeeze(
          self._prediction_dict['logits'], axis=-1)
      self._prediction_dict['probs'] = tf.nn.sigmoid(
          self._prediction_dict['logits'])
      return [
          'logits', 'probs', 'user_emb', 'item_emb', 'user_tower_emb',
          'item_tower_emb'
      ]
    elif self._loss_type == LossType.L2_LOSS:
      return ['y', 'user_emb', 'item_emb', 'user_tower_emb', 'item_tower_emb']
    else:
      raise ValueError('invalid loss type: %s' % str(self._loss_type))
