# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import multiprocessing as mp
import os
import time
from multiprocessing.connection import Connection
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
import pyarrow as pa
from sklearn.cluster import KMeans

from tzrec.datasets.dataset import create_reader
from tzrec.tools.tdm.gen_tree import tree_builder
from tzrec.tools.tdm.gen_tree.tree_builder import TDMTreeNode
from tzrec.utils.env_util import use_hash_node_id
from tzrec.utils.logging_util import logger


class TreeCluster:
    """Cluster based on emb vec.

    Args:
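        item_input_path (str): The file path where the item information is stored.
        item_id_field (str): The column name representing item_id in the file.
        attr_fields (str, optional): Comma-separated column names representing the
            features in the file.
        raw_attr_fields (str, optional): Comma-separated column names whose values
            are stored as raw attributes on each leaf node.
        output_dir (str, optional): The output path passed to the tree builder.
        embedding_field (str): The column name holding the item embedding.
        parallel (int): The number of worker processes used for clustering.
        n_cluster (int): The branching factor of the nodes in the tree.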
    """

    def __init__(
        self,
        item_input_path: str,
        item_id_field: str,
        attr_fields: Optional[str] = None,
        raw_attr_fields: Optional[str] = None,
        output_dir: Optional[str] = None,
        embedding_field: str = "item_emb",
        parallel: int = 16,
        n_cluster: int = 2,
        **kwargs: Any,
    ) -> None:
        self.item_input_path = item_input_path
        self.mini_batch = 1024

        self.data = None
        self.leaf_nodes = None

        self.parallel = parallel
        self.queue = None
        self.timeout = 5
        self.codes = None
        self.output_dir = output_dir
        self.n_clusters = n_cluster

        self.item_id_field = item_id_field
        self.attr_fields = []
        self.raw_attr_fields = []
        if attr_fields:
            self.attr_fields = [x.strip() for x in attr_fields.split(",")]
        if raw_attr_fields:
            self.raw_attr_fields = [x.strip() for x in raw_attr_fields.split(",")]

        self.embedding_field = embedding_field

        self.dataset_kwargs = {}
        if "odps_data_quota_name" in kwargs:
            self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

    def _read(self) -> None:
        t1 = time.time()
        data = list()

        self.leaf_nodes = []

        selected_cols = (
            {self.item_id_field, self.embedding_field}
            | set(self.attr_fields)
            | set(self.raw_attr_fields)
        )
        reader = create_reader(
            self.item_input_path,
            4096,
            selected_cols=list(selected_cols),
            **self.dataset_kwargs,
        )

        for data_dict in reader.to_batches():
            if use_hash_node_id():
                ids = data_dict[self.item_id_field].cast(pa.string()).to_pylist()
            else:
                ids = data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
            data += data_dict[self.embedding_field].to_pylist()

            batch_tree_nodes = []
            for one_id in ids:
                batch_tree_nodes.append(TDMTreeNode(item_id=one_id))

            for attr in self.attr_fields:
                attr_data = data_dict[attr]
                for i in range(len(batch_tree_nodes)):
                    batch_tree_nodes[i].attrs.append(attr_data[i])

            for attr in self.raw_attr_fields:
                attr_data = data_dict[attr]
                for i in range(len(batch_tree_nodes)):
                    batch_tree_nodes[i].raw_attrs.append(attr_data[i])

            self.leaf_nodes.extend(batch_tree_nodes)

        if isinstance(data[0], str):
            data = [eval(i) for i in data]
        self.data = np.array(data)
        t2 = time.time()

        logger.info(
            "Read data done, {} records read, elapsed: {}".format(
                len(self.leaf_nodes), t2 - t1
            )
        )

    def train(self, save_tree: bool = False) -> TDMTreeNode:
        """Cluster the items hierarchically and build the tree.

        Reads the input, distributes clustering jobs to ``self.parallel``
        worker processes through a shared queue, merges the tree codes the
        workers send back and hands the coded leaf nodes to the tree builder.

        Args:
            save_tree (bool): Forwarded to ``TreeBuilder.build``.

        Returns:
            The root node returned by ``TreeBuilder.build``.
        """
        self._read()
        n_items = len(self.leaf_nodes)

        # Each job in the queue is a (code, index) pair: the tree code of a
        # node and the positions (into self.leaf_nodes) of the items under it.
        queue = mp.Queue()
        # Number of jobs that are queued or currently being processed. Workers
        # use it to detect that no further job can arrive, so a worker that
        # never received any job still terminates instead of polling forever.
        pending = mp.Value("i", 1)
        queue.put((0, np.arange(n_items)))

        processes = []
        pipes = []
        for _ in range(self.parallel):
            parent_conn, child_conn = mp.Pipe()
            p = mp.Process(target=self._train, args=(child_conn, queue, pending))
            processes.append(p)
            pipes.append(parent_conn)
            p.start()

        # NOTE(review): recv() blocks until the worker sends its result; a
        # worker that dies before sending would still stall this loop.
        self.codes = np.zeros((n_items,), dtype=np.int64)
        for pipe in pipes:
            codes = pipe.recv()
            # Entries a worker did not assign stay 0 and are skipped.
            assigned = np.flatnonzero(codes > 0)
            self.codes[assigned] = codes[assigned]
            for i in assigned:
                self.leaf_nodes[i].tree_code = codes[i]

        for p in processes:
            p.join()

        assert queue.empty()
        builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters)
        root = builder.build(self.leaf_nodes, save_tree)
        return root

    # pyre-ignore [24]
    def _train(self, pipe: Connection, queue: mp.Queue, pending: Any = None) -> None:
        """Worker loop: take jobs from the queue until no work is left.

        A job with at most ``self.mini_batch`` items is fully coded in-process
        by ``_mini_batch``. A larger job is split once with ``_cluster`` and
        its parts are put back on the queue. The resulting code array (0 for
        items this worker did not code) is sent through ``pipe`` on exit.

        Args:
            pipe: Connection used to send the code array to the parent.
            queue: Shared job queue of ``(parent_code, index)`` pairs.
            pending: Optional shared counter (``mp.Value``) of jobs that are
                queued or in progress. When given, the worker also exits as
                soon as the counter is zero; when omitted, only the
                timeout-based exit rule applies.
        """
        # Imported under an alias because ``queue`` is a parameter name here.
        from queue import Empty as QueueEmpty

        def all_done() -> bool:
            # True once every job ever queued has been completed.
            return pending is not None and pending.value == 0

        last_size = -1
        catch_time = 0
        processed = False
        code = np.zeros((len(self.leaf_nodes),), dtype=np.int64)
        parent_code = None
        index = None
        while True:
            for _ in range(3):
                try:
                    parent_code, index = queue.get(timeout=self.timeout)
                except QueueEmpty:
                    index = None
                except Exception:
                    # Keep the best-effort retry, but do not hide the cause.
                    logger.exception(
                        "Failed to fetch a job, pid: {}".format(os.getpid())
                    )
                    index = None
                if index is not None or all_done():
                    break

            if index is None:
                # Exit when all work is known to be finished, or (timeout
                # rule) when this worker has already processed something and
                # the queue stayed empty.
                if all_done() or (
                    processed and (last_size <= self.mini_batch or catch_time >= 3)
                ):
                    logger.info("Process {} exits".format(os.getpid()))
                    break
                logger.info(
                    "Got empty job, pid: {}, time: {}".format(os.getpid(), catch_time)
                )
                catch_time += 1
                continue

            processed = True
            catch_time = 0
            last_size = len(index)
            if last_size <= self.mini_batch:
                self._mini_batch(parent_code, index, code)
            else:
                start = time.time()
                sub_index = self._cluster(index)
                logger.info(
                    "Train iteration done, parent_code:{}, "
                    "data size: {}, elapsed time: {}".format(
                        parent_code, len(index), time.time() - start
                    )
                )
                # Adapt the polling timeout to how long a split takes, but
                # never go below 5 seconds.
                self.timeout = max(
                    5, int(0.4 * self.timeout + 0.6 * (time.time() - start))
                )

                children = []
                for i, child in enumerate(sub_index):
                    child_code = self.n_clusters * parent_code + i + 1
                    if len(child) > 1:
                        children.append((child_code, child))
                    elif len(child) == 1:
                        # A single item is a leaf: code it here (as
                        # _mini_batch does) instead of dropping it.
                        code[child[0]] = child_code

                if pending is not None:
                    # Count the children before publishing them so the counter
                    # cannot reach zero while work is still outstanding.
                    with pending.get_lock():
                        pending.value += len(children)
                for job in children:
                    queue.put(job)

            if pending is not None:
                # The job taken from the queue is now complete.
                with pending.get_lock():
                    pending.value -= 1

        process_count = int(np.count_nonzero(code > 0))
        logger.info("Process {} process {} items".format(os.getpid(), process_count))
        pipe.send(code)

    def _mini_batch(
        self, parent_code: int, index: npt.NDArray, code: npt.NDArray
    ) -> None:
        dq = collections.deque()
        dq.append((parent_code, index))
        batch_size = len(index)
        tstart = time.time()
        while dq:
            parent_code, index = dq.popleft()

            if len(index) <= self.n_clusters:
                for i in range(len(index)):
                    code[index[i]] = self.n_clusters * parent_code + i + 1
                continue

            sub_index = self._cluster(index)
            for i in range(self.n_clusters):
                if len(sub_index[i]) > 1:
                    dq.append((self.n_clusters * parent_code + i + 1, sub_index[i]))
                elif len(sub_index[i]) > 0:
                    for j in range(len(sub_index[i])):
                        code[sub_index[i][j]] = (
                            self.n_clusters * parent_code + i + j + 1
                        )
        logger.info(
            "Minibatch, batch size: {}, elapsed: {}".format(
                batch_size, time.time() - tstart
            )
        )

    def _cluster(self, index: npt.NDArray) -> List[npt.NDArray]:
        data = self.data[index]
        kmeans = KMeans(n_clusters=self.n_clusters, random_state=0).fit(data)
        labels = kmeans.labels_
        sub_indices = []
        remain_index = []
        ave_num = int(len(index) / self.n_clusters)

        for i in range(self.n_clusters):
            sub_i = np.where(labels == i)[0]
            sub_index = index[sub_i]
            if len(sub_index) <= ave_num:
                sub_indices.append(sub_index)
            else:
                distances = kmeans.transform(data[sub_i])[:, i]
                sorted_index = sub_index[np.argsort(distances)]
                sub_indices.append(sorted_index[:ave_num])
                remain_index.extend(list(sorted_index[ave_num:]))
        idx = 0
        remain_index = np.array(remain_index)

        # reblance index
        while idx < self.n_clusters and len(remain_index) > 0:
            if len(sub_indices[idx]) >= ave_num:
                idx += 1
            else:
                diff = min(len(remain_index), ave_num - len(sub_indices[idx]))
                remain_data = self.data[remain_index]
                distances = kmeans.transform(remain_data)[:, idx]
                sorted_index = remain_index[np.argsort(distances)]

                # Supplement the data by sorting the distances
                # to the current cluster centers in ascending order.
                sub_indices[idx] = np.append(
                    sub_indices[idx], np.array(sorted_index[0:diff])
                )
                remain_index = sorted_index[diff:]
                idx += 1
        if len(remain_index) > 0:
            sub_indices[0] = np.append(sub_indices[0], remain_index)

        return sub_indices
