# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""utils"""

import ast
import inspect
import socket
import subprocess
import textwrap
import time
import math
import copy
import concurrent.futures
from contextlib import closing
from types import SimpleNamespace

import pynvml
import numpy as np
import torch
import torch.nn.functional as F
from chatlearn.utils.logger import logger


def get_attributes(cls):
    """Get attributes from class."""
    return [(name, attr) for name, attr in inspect.getmembers(cls)
            if not (name.startswith('_')) and (not callable(attr))]


def parse_function_args(func):
    """Parse the positional argument names of a function from its source AST."""
    args = []

    def parse_func_args(node):
        for argument in node.args.args:
            args.append(argument.arg)

    node_iter = ast.NodeVisitor()
    node_iter.visit_FunctionDef = parse_func_args
    code = textwrap.dedent(inspect.getsource(func))
    node_iter.visit(ast.parse(code))
    return args
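
# Example (illustrative sketch; `score` is a hypothetical function):
#     def score(prompt, response):
#         return len(prompt) + len(response)
#     parse_function_args(score)  # -> ['prompt', 'response']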


def get_return_lines(node):
    """Find the first return statement in node's body, descending into if-branches."""
    for line in node.body:
        if isinstance(line, ast.Return):
            return line
    for line in node.body:
        if isinstance(line, ast.If):
            ret = get_return_lines(line)
            if ret is not None:
                return ret
    return None


def get_return_value_num(ret):
    """Count the number of values in a return statement."""
    if isinstance(ret.value, ast.Name):
        return 1
    elif isinstance(ret.value, ast.Tuple):
        return len(ret.value.elts)
    elif isinstance(ret.value, ast.Call):
        raise RuntimeError("nested calls in return statements are not supported")


def parse_function_return_num(func):
    """Parse the number of values returned by a function from its source AST."""
    results = []

    def parse_func_return(node):
        ret = get_return_lines(node)
        return_num = 0
        if ret is not None:
            return_num = get_return_value_num(ret)
        results.append(return_num)

    node_iter = ast.NodeVisitor()
    node_iter.visit_FunctionDef = parse_func_return
    code = textwrap.dedent(inspect.getsource(func))
    node_iter.visit(ast.parse(code))
    return results[0]
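
# Example (illustrative sketch; `make_pair` is a hypothetical function):
#     def make_pair(x):
#         y = x + 1
#         return x, y
#     parse_function_return_num(make_pair)  # -> 2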


def get_host_addr():
    """
    Get the IP address of the current node.
    """
    hostname = socket.gethostname()
    ip_addr = socket.gethostbyname(hostname)
    return ip_addr


def get_free_port():
    """
    Find a free TCP port on the current node.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as handle:
        handle.bind(('', 0))
        handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return handle.getsockname()[1]


def split_index(length, num_splits):
    """Split range(length) into num_splits contiguous (start, end) index pairs."""
    # Calculate the size of each split
    size = length // num_splits
    remainder = length % num_splits

    # Initialize an empty list for indices
    indices = []

    # Loop over the number of splits and append indices
    start = 0
    end = 0
    for _ in range(num_splits):
        end += size
        if remainder > 0:
            end += 1
            remainder -= 1
        indices.append((start, end))
        start = end

    # Return the list of indices
    return indices
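
# Example (illustrative): the remainder is spread one element at a time
# over the earliest splits:
#     split_index(10, 3)  # -> [(0, 4), (4, 7), (7, 10)]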


def to_device(device, args):
    """
    Convert args to device recursively

    Args:
        device: gpu/cpu
        args: args to be converted
    """
    if isinstance(args, (list, tuple)):
        args = type(args)(to_device(device, arg) for arg in args)
    elif isinstance(args, dict):
        for key, value in args.items():
            args[key] = to_device(device, value)
    elif isinstance(args, torch.Tensor):
        args = args.to(device)
    return args
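
# Example (illustrative; assumes a CUDA device is available):
#     batch = {"ids": torch.ones(2), "meta": [torch.zeros(1)]}
#     batch = to_device("cuda", batch)  # every nested tensor is now on the GPU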


def get_or_cache(cache, key, func, *args, **kwargs):
    """
    get results if cached
    otherwise call the func to get the results, and cache the results
    """
    if key in cache:
        return cache[key]
    res = func(*args, **kwargs)
    cache[key] = res
    return res
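
# Example (illustrative; `load_model` is a hypothetical expensive function):
#     cache = {}
#     model = get_or_cache(cache, "policy", load_model, "policy.ckpt")
#     model = get_or_cache(cache, "policy", load_model, "policy.ckpt")  # cache hit, no reload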


def flatten(nested_list):
    """Flatten an arbitrarily nested list into a flat list."""
    flat = []
    for elem in nested_list:
        if isinstance(elem, list):
            flat.extend(flatten(elem))
        else:
            flat.append(elem)
    return flat
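
# Example (illustrative):
#     flatten([1, [2, [3, 4]], 5])  # -> [1, 2, 3, 4, 5]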


def get_indent_count(string):
    """Count the number of leading spaces in a string."""
    count = 0
    for s in string:
        if s == ' ':
            count += 1
        else:
            return count
    return count


def detect_and_insert_code(lines, pattern, new_code, additional_indent=0, line_offset=0, replace=False):
    """
    Insert new_code above the first line matching pattern, shifted by line_offset.
    If replace is True, the line just above the insertion point is replaced.
    Return the new list of lines, or None if the pattern is not found.
    """
    detected_lines = [(line_number, line) for line_number, line in enumerate(lines) if pattern in line]
    if not detected_lines:
        return
    type_line_number, type_line = detected_lines[0]
    indent = get_indent_count(type_line) + additional_indent
    new_lines = [line for line in new_code.split('\n') if line.strip()]
    added_lines = []
    for line in new_lines:
        added_lines.append(" "*indent + line)
    lines = lines[:type_line_number+line_offset - replace] + added_lines + lines[type_line_number+line_offset:]
    return lines

def detect_and_insert_code_to_func(source_code, pattern, new_code, additional_indent=0, line_offset=0, replace=False):
    lines = source_code.split('\n')
    lines = detect_and_insert_code(lines, pattern, new_code, additional_indent, line_offset, replace)
    if lines is None:
        return
    indent = get_indent_count(lines[0])
    lines = [line[indent:] for line in lines]
    return '\n'.join(lines)
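
# Example (illustrative sketch): insert a line above the first match of a pattern.
#     src = "def f(x):\n    y = x + 1\n    return y"
#     detect_and_insert_code_to_func(src, "return", "y *= 2")
#     # -> "def f(x):\n    y = x + 1\n    y *= 2\n    return y"
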

def execute(cmd, check=False, retry=1):
    """
    Execute cmd in a shell, retrying on failure up to retry times.

    Args:
        check: if the return code is non-zero, raise an error
        retry: number of attempts before giving up

    Returns:
        (state, msg): state is True on success; msg is stdout on success, stderr otherwise
    """
    ret = subprocess.run(cmd, shell=True, capture_output=True, text=True, check=check)
    state = ret.returncode == 0
    msg = ret.stdout if state else ret.stderr
    if not state:
        logger.warning(f"execute {cmd} got error {msg}")
        if retry > 1:
            logger.warning(f"retry {cmd} ...")
            time.sleep(1)
            return execute(cmd, check, retry-1)
    return state, msg
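
# Example (illustrative):
#     state, msg = execute("echo hello")  # -> state is True, msg is "hello\n"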


def is_connection_refused(msg):
    """Check whether an error message indicates a refused connection."""
    keywords = ["StatusCode.UNAVAILABLE", "Connection refused", "failed to connect to all addresses"]
    return any(keyword in msg for keyword in keywords)


def get_ray_status():
    """Check whether the Ray cluster is reachable; return (cluster_state, msg)."""
    cluster_state, msg = execute("ray status", retry=3)
    if cluster_state:
        return True, None
    elif is_connection_refused(msg):
        return False, msg
    # unknown msg
    return True, msg


def get_full_proc_memory_info(prefix):
    """Format current GPU memory usage (allocated, reserved and per-process) with a prefix."""
    torch.cuda.synchronize()
    s = prefix + ': '
    s += f'memory allocated: {torch.cuda.memory_allocated() / (1 << 30):.2f} GiB, ' \
         f'memory reserved: {torch.cuda.memory_reserved() / (1 << 30):.2f} GiB, ' \
         f'proc memory usage: {nvml_proc_memory_info()}'
    return s


def nvml_proc_memory_info():
    """Report per-process GPU memory usage of the first visible device via NVML."""
    pynvml.nvmlInit()
    s = ''
    for dev_id in range(pynvml.nvmlDeviceGetCount()):
        handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id)
        mem_str = ' | '.join([f'(pid {proc.pid}: {proc.usedGpuMemory / (1 << 30):.2f} GiB)' \
                  for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle)])
        s += mem_str
        break  # only the first device is inspected
    return s


def dict_to_simplenamespace(d):
    """Recursively convert a (possibly nested) dict into a SimpleNamespace."""
    for key, value in d.items():
        if isinstance(value, dict):
            d[key] = dict_to_simplenamespace(value)
    return SimpleNamespace(**d)
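
# Example (illustrative): nested dicts become attribute access.
#     ns = dict_to_simplenamespace({"lr": 0.1, "optim": {"name": "adam"}})
#     ns.optim.name  # -> 'adam'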


def get_use_legacy_models(model_args):
    if isinstance(model_args, dict):
        use_legacy_models = model_args.get("use_legacy_models", None)
    else:
        use_legacy_models = getattr(model_args, "use_legacy_models", None)

    if use_legacy_models is None:
        raise RuntimeError("Please specify use_legacy_models (True or False), but not None.")
    return use_legacy_models


def execute_in_parallel(function, arguments):
    """
    Run function over a list of argument tuples in a thread pool.
    With a single argument tuple, the function is called directly and its result
    returned; otherwise a list of results is returned in completion order.
    """
    if len(arguments) == 1:
        return function(*arguments[0])
    results = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Submit one task per argument tuple and collect results as they complete
        futures = [executor.submit(function, *args) for args in arguments]
        for _future in concurrent.futures.as_completed(futures):
            results.append(_future.result())
    return results
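
# Example (illustrative): results arrive in completion order, not submission order.
#     execute_in_parallel(lambda x, y: x + y, [(1, 2), (3, 4)])  # -> [3, 7] or [7, 3]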


def multi_thread_data_processing(num_threads: int, all_data: list, process_one_data, fn_args: list):
    """
    Process all_data with process_one_data(data, *fn_args) using num_threads threads.
    The returned list preserves the order of all_data.
    """
    num_data = len(all_data)

    if num_data == 0:
        return []
    # reduce num_threads if the amount of data is small
    if num_data < num_threads:
        num_threads = num_data
    assert num_threads > 0, "Got num_threads <= 0. Expect a positive number."

    result = list(range(num_data))  # placeholder list, overwritten in place by worker threads
    data_size_per_thread = math.ceil(num_data / num_threads)
    thread_args = [(i, all_data[i * data_size_per_thread : min((i+1) * data_size_per_thread, num_data)]) for i in range(num_threads)]

    def thread_fn(thread_id: int, pending_data):
        offset = thread_id * data_size_per_thread
        for i, data in enumerate(pending_data):
            result[offset + i] = process_one_data(data, *fn_args)

    execute_in_parallel(thread_fn, thread_args)

    return result
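
# Example (illustrative): output order matches input order even with multiple threads.
#     multi_thread_data_processing(4, list(range(8)), lambda x: x * x, [])
#     # -> [0, 1, 4, 9, 16, 25, 36, 49]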


def regroup_by_concat_along_batch(tensors):
    """
    Concatenate a list of per-batch dicts into one dict along the batch dimension.
    2-D tensors are right-padded to a common second dimension before stacking.
    """
    batched = {}
    if tensors[0] is None:
        return batched
    for key in tensors[0].keys():
        to_batch = [results[key] for results in tensors]
        if isinstance(to_batch[0], torch.Tensor):
            if len(to_batch[0].shape) == 2:
                max_dim_1 = max([ele.shape[1] for ele in to_batch]) # pylint: disable=consider-using-generator
                pad_value = 0.0 if to_batch[0].dtype in [torch.float32, torch.float16, torch.bfloat16] else 0
                value = [
                    F.pad(
                        ele,
                        (0, max_dim_1 - ele.shape[1]),
                        value=pad_value,
                    )
                    for ele in to_batch
                ]
                batched[key] = torch.vstack(value)
            elif len(to_batch[0].shape) == 1:
                batched[key] = torch.concat(to_batch)
            else:
                raise RuntimeError(f"unsupported shape for in_queue rebatching. expect 1 or 2. while {to_batch[0].shape}")
        elif isinstance(to_batch[0], list):
            batched[key] = []
            for seq in to_batch:
                batched[key].extend(seq)
        else:
            raise Exception(f"unknown types key: {key} and {type(to_batch[0])} to concat : {to_batch[0]}")

    return batched
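
# Example (illustrative): 2-D tensors are right-padded to the widest second dim.
#     regroup_by_concat_along_batch([{"ids": torch.ones(2, 3)}, {"ids": torch.ones(1, 5)}])
#     # -> {"ids": tensor of shape (3, 5) with zero padding on the right}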


def slice_by_index_along_batch(batched_input, index):
    """Take every offset-th element starting at start along the batch dimension, where index is (start, offset)."""
    start = index[0]
    offset = index[1]
    batched = {}
    for key in batched_input.keys():
        if isinstance(batched_input[key], torch.Tensor):
            batched[key] = batched_input[key][start::offset,...]
        elif isinstance(batched_input[key], list):
            batched[key] = batched_input[key][start::offset]
    return batched
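
# Example (illustrative): index = (start, offset) takes a strided, round-robin slice.
#     batch = {"ids": torch.arange(8)}
#     slice_by_index_along_batch(batch, (1, 2))  # -> {"ids": tensor([1, 3, 5, 7])}
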
def listdict_to_dictlist(ld, list_extend=True):
    """
    Convert a list of dicts into a dict of lists:
    [{k1: v11, k2: v21}, {k1: v12, k2: v22}, ...] => {k1: [v11, v12, ...], k2: [v21, v22, ...]}
    If list_extend is True and values are lists, they are extended (k1: v11 + v12)
    instead of appended.

    :param ld: list of dicts with identical keys
    :return: dict mapping each key to the list of values across ld
    """
    res = copy.deepcopy(ld[0])
    for res_key, v in res.items():
        if list_extend and isinstance(res[res_key], list):
            continue

        res[res_key] = [v]

    for d in ld[1:]:
        for key, v in d.items():
            if list_extend and isinstance(d[key], list):
                res[key].extend(v)
            else:
                res[key].append(v)

    return res
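
# Example (illustrative): list values are extended, scalar values are appended.
#     listdict_to_dictlist([{"a": 1, "b": [2]}, {"a": 3, "b": [4]}])
#     # -> {"a": [1, 3], "b": [2, 4]}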


def map_metrics(metric_list):
    """Merge a list of metric dicts into one dict mapping each key to a list of values."""
    mapped_metrics = {}
    for metrics in metric_list:
        for key, value in metrics.items():
            if key in mapped_metrics:
                mapped_metrics[key].append(value)
            else:
                mapped_metrics[key] = [value]
    return mapped_metrics


def reduce_metrics(merged_metrics):
    """Reduce each key's list of values to its mean."""
    # [TODO:baodong.lh] support custom_op like min, max to reduce metrics
    reduced_metrics = {}
    for key, value_list in merged_metrics.items():
        if isinstance(value_list[0], torch.Tensor):
            value = torch.mean(torch.Tensor(value_list))
        else:
            value = np.mean(value_list)
        reduced_metrics[key] = value
    return reduced_metrics


def map_reduce_metrics(metric_list):
    # [TODO:baodong.lh] improve performance by distributing the task to per-replica
    # sanity check
    assert isinstance(metric_list, list)

    if len(metric_list) == 0:
        return {}

    first_metric_len = len(metric_list[0])
    for i, metric in enumerate(metric_list):
        if len(metric) != first_metric_len:
            logger.warning(
                f"Length of the {i}-th metric ({len(metric)}) differs from "
                f"the first one ({first_metric_len}); please check!"
            )

    mapped_metrics = map_metrics(metric_list)
    reduced_metrics = reduce_metrics(mapped_metrics)
    return reduced_metrics
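
# Example (illustrative): per-replica metrics are merged and mean-reduced.
#     map_reduce_metrics([{"loss": 1.0}, {"loss": 3.0}])  # -> {"loss": 2.0}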
