# Copyright 2021-present Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Thrift SAI interface basic utils.
"""

import time
import struct
import socket
import json

from functools import wraps

from ptf.packet import *
from ptf.testutils import *

from sai_thrift.sai_adapter import *
import sai_thrift.sai_adapter as adapter

from typing import List, Dict
from typing import TYPE_CHECKING


def sai_thrift_query_attribute_enum_values_capability(client,
                                                      obj_type,
                                                      attr_id=None,
                                                      max_cap_no=20):
    """
    Call the sai_thrift_query_attribute_enum_values_capability() RPC
    and return the enum values it reports as supported for the attribute

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type
        attr_id (attr): SAI attribute id
        max_cap_no (int): capacity passed to the RPC; presumably the maximum
            number of enum values returned, so raise it for attributes with
            many possible values (defaults to the previous fixed value, 20)

    Returns:
        list: supported enum values, as returned by the RPC
    """
    return client.sai_thrift_query_attribute_enum_values_capability(
        obj_type, attr_id, max_cap_no)


def sai_thrift_object_type_get_availability(client,
                                            obj_type,
                                            attr_id=None,
                                            attr_type=None):
    """
    Query the number of available resources for a SAI object type.

    Thin pass-through to client.sai_thrift_object_type_get_availability();
    the arguments are forwarded positionally and unchanged.

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type
        attr_id (attr): SAI attribute name qualifying the query
        attr_type (type): SAI attribute type qualifying the query

    Returns:
        uint: number of available resources with given parameters
    """
    return client.sai_thrift_object_type_get_availability(
        obj_type, attr_id, attr_type)


def sai_thrift_object_type_query(client,
                                 obj_id=None):
    """
    Return the SAI object type of an object id.

    Thin pass-through to client.sai_object_type_query().

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id

    Returns:
        uint: object type
    """
    return client.sai_object_type_query(obj_id)


def sai_thrift_switch_id_query(client,
                               obj_id=None):
    """
    Query the switch object id associated with a SAI object id.

    Thin pass-through to client.sai_switch_id_query().

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id

    Returns:
        uint: switch object id, as returned by the RPC
    """
    return client.sai_switch_id_query(obj_id)


def sai_thrift_api_uninitialize(client):
    """
    Ask the SAI RPC server to uninitialize the SAI API.

    Thin pass-through to client.sai_thrift_api_uninitialize().

    Args:
        client (Client): SAI RPC client

    Returns:
        whatever the RPC returns (presumably a SAI status code -- TODO
        confirm against the thrift interface)
    """
    return client.sai_thrift_api_uninitialize()


def sai_thrift_get_debug_counter_port_stats(client, port_oid, counter_ids):
    """
    Get port statistics for given debug counters

    Args:
        client (Client): SAI RPC client
        port_oid (sai_thrift_object_id_t): object_id IN argument
        counter_ids (list of sai_stat_id_t): requested counter ids

    Returns:
        dict: maps each requested counter id to the value at the same
        position in the RPC result
    """
    values = client.sai_thrift_get_port_stats(port_oid, counter_ids)
    return {cid: values[pos] for pos, cid in enumerate(counter_ids)}


def sai_thrift_get_debug_counter_switch_stats(client, counter_ids):
    """
    Get switch statistics for given debug counters

    Args:
        client (Client): SAI RPC client
        counter_ids (list of sai_stat_id_t): requested counter ids

    Returns:
        dict: maps each requested counter id to the value at the same
        position in the RPC result
    """
    values = client.sai_thrift_get_switch_stats(counter_ids)
    return {cid: values[pos] for pos, cid in enumerate(counter_ids)}


def sai_ipaddress(addr_str):
    """
    Set SAI IP address, assign appropriate type and return
    sai_thrift_ip_address_t object

    The family is chosen from the separators in the string. A ':' means
    IPv6 and is checked first, so IPv4-mapped forms such as
    '::ffff:1.2.3.4' are IPv6, as before. Otherwise a '.' means IPv4.

    Args:
        addr_str (str): IP address string

    Returns:
        sai_thrift_ip_address_t: object containing IP address family and number

    Raises:
        ValueError: if addr_str contains neither '.' nor ':'
    """
    if ':' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_str)
    elif '.' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_str)
    else:
        # Previously this fell through to a confusing UnboundLocalError.
        raise ValueError("Invalid IP address: {!r}".format(addr_str))

    return sai_thrift_ip_address_t(addr_family=family, addr=addr)


def sai_ipprefix(prefix_str):
    """
    Set IP address prefix and mask and return ip_prefix object

    The family is decided from the address part only. A ':' means IPv6 and
    is checked first, so IPv4-mapped prefixes such as '::ffff:10.0.0.0/104'
    are handled as IPv6 instead of crashing in the IPv4 mask conversion.

    Args:
        prefix_str (str): IP address and mask length in slash notation,
            e.g. '10.0.0.0/8' or '2001:db8::/32'

    Return:
        sai_thrift_ip_prefix_t: IP prefix object, or None (after printing a
        message) if prefix_str is not '<address>/<length>' or the address
        has neither '.' nor ':'
    """
    addr_mask = prefix_str.split('/')
    if len(addr_mask) != 2:
        print("Invalid IP prefix format")
        return None
    addr_str, mask_len = addr_mask

    if ':' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_str)
        mask = sai_thrift_ip_addr_t(
            ip6=num_to_dotted_quad(int(mask_len), ipv4=False))
    elif '.' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_str)
        mask = sai_thrift_ip_addr_t(ip4=num_to_dotted_quad(mask_len))
    else:
        # Previously an UnboundLocalError; follow the None-on-bad-format
        # convention used for the split check above.
        print("Invalid IP prefix format")
        return None

    return sai_thrift_ip_prefix_t(addr_family=family, addr=addr, mask=mask)


def num_to_dotted_quad(address, ipv4=True):
    """
    Convert a prefix length into the corresponding network mask string

    Args:
        address (int or str): prefix length (number of leading one bits);
            any value accepted by int()
        ipv4 (bool): truthy for an IPv4 dotted-quad mask; falsy for an IPv6
            mask written as eight colon-separated groups of four hex digits

    Returns:
        str: mask, e.g. '255.255.255.0' for 24 (IPv4) or
        'ffff:ffff:ffff:ffff:0000:0000:0000:0000' for 64 (IPv6)

    Raises:
        ValueError: if address is not an integer or is outside 0..32
            (IPv4) / 0..128 (IPv6)
    """
    width = 32 if ipv4 else 128
    prefix_len = int(address)
    if not 0 <= prefix_len <= width:
        raise ValueError(
            "Prefix length {} out of range 0..{}".format(prefix_len, width))

    full = 1 << width
    mask = full - (full >> prefix_len)

    if ipv4:
        return socket.inet_ntop(socket.AF_INET, struct.pack('>L', mask))

    # Zero-pad to all 32 hex digits so short masks (notably /0, whose value
    # is 0) still produce eight groups rather than a truncated string.
    digits = '{:032x}'.format(mask)
    return ':'.join(digits[pos:pos + 4] for pos in range(0, 32, 4))


def open_packet_socket(hostif_name):
    """
    Open a raw AF_PACKET socket bound to a host interface

    The socket is opened for Ethernet protocol 3 (ETH_P_ALL, i.e. all
    protocols) and switched to non-blocking mode.

    Args:
        hostif_name (str): name of the network interface to bind to

    Return:
        socket.socket: the bound, non-blocking raw socket

    Raises:
        OSError: if the socket cannot be created or bound; a socket that
            was already created is closed before the error propagates
    """
    eth_p_all = 3  # ETH_P_ALL: receive frames of every protocol
    sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW,
                         socket.htons(eth_p_all))
    try:
        sock.bind((hostif_name, eth_p_all))
        sock.setblocking(False)
    except BaseException:
        # Don't leak the file descriptor when binding fails.
        sock.close()
        raise

    return sock


def socket_verify_packet(pkt, sock, timeout=2):
    """
    Verify packet was received on a socket

    Reads frames from sock until one matches pkt or the timeout expires.
    A ptf Mask is matched with Mask.pkt_match(); anything else is compared
    byte-for-byte. Frames that cannot be read or parsed are skipped.

    Args:
        pkt (packet or ptf.mask.Mask): packet to match with
        sock (socket.socket): socket to read from, e.g. as returned by
            open_packet_socket()
        timeout (int): maximum time to wait, in seconds

    Return:
        bool: True if a matching packet arrived before the timeout
    """
    # Local import: a module-level 'select' could be shadowed by the
    # star imports at the top of this file.
    import select

    max_pkt_size = 9100
    deadline = time.monotonic() + timeout
    is_mask = isinstance(pkt, ptf.mask.Mask)

    if is_mask and not pkt.is_valid():
        return False

    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Wait until a frame is readable instead of spinning on recv() of a
        # non-blocking socket, which kept a CPU core busy for the timeout.
        try:
            readable, _, _ = select.select([sock], [], [], remaining)
        except (OSError, ValueError):
            # Closed or invalid socket: nothing can ever match.
            return False
        if not readable:
            return False

        try:
            received = Ether(sock.recv(max_pkt_size))
            if is_mask:
                matched = pkt.pkt_match(received)
            else:
                # Compare raw bytes; str() of a scapy packet is not meant
                # as its wire representation on Python 3.
                matched = bytes(received) == bytes(pkt)
        except Exception:
            # Best effort, as before: skip frames that fail to read/parse
            # (e.g. a spurious wakeup raising BlockingIOError).
            continue

        if matched:
            return True


def delay_wrapper(func, delay=2):
    """
    Wrap func so that every call first sleeps for `delay` seconds when the
    ptf test parameter 'target' is present and differs from "hw"

    Args:
        func (function): function to be wrapped; the wrapper keeps it in a
            closure variable named 'func'
        delay (int): sleep before each call, in seconds

    Return:
        wrapped_function: wrapper carrying func's metadata (functools.wraps)
    """
    @wraps(func)
    def wrapped_function(*args, **kwargs):
        """
        Sleep if the test target requires it, then call the wrapped function

        Args:
            args (tuple): positional arguments forwarded to func
            kwargs (dict): keyword arguments forwarded to func

        Return:
            status: whatever func returns
        """
        params = test_params_get()
        needs_delay = 'target' in params and params['target'] != "hw"
        if needs_delay:
            time.sleep(delay)
        return func(*args, **kwargs)

    return wrapped_function


# Re-bind the adapter's FDB flush so each call is preceded by a 2 s sleep
# whenever the 'target' test parameter is set to something other than "hw".
sai_thrift_flush_fdb_entries = delay_wrapper(
    sai_thrift_flush_fdb_entries, delay=2)


def warm_test(is_test_rebooting: bool = False, time_out=60, interval=1):
    """
    Method decorator for the method on warm testing.

    Behavior depends on the test instance's test_reboot_mode:
    - 'warm':
      1. Shut the switch down in warm mode.
      2. Write 'rebooting' to /tmp/warm_reboot.
      3. Poll that file until it contains 'post_reboot_done'.
      4. Recreate the RPC client and warm-start the switch.
      5. Run the test method.
    - Any other mode: just run the test method.

    args:
        is_test_rebooting: whether to also run the test case on every poll
            while the saiserver container is shut down
        time_out: maximum number of polls to wait for the saiserver
            container restart to complete (about time_out * interval
            seconds); if exceeded, TimeoutError is raised
        interval: seconds to sleep between polls

    Raises:
        TimeoutError: if 'post_reboot_done' does not appear in time.
            Previously this, and any failure of the test case during the
            rebooting stage, was caught and only printed.
    """
    def _check_run_case(f):
        @wraps(f)
        def test_director(inst, *args):
            if inst.test_reboot_mode == 'warm':
                _warm_shutdown_switch(inst)
                _wait_for_warm_reboot(inst, f, is_test_rebooting,
                                      time_out, interval)
                inst.createRpcClient()
                inst.warm_start_switch()
            return f(inst)
        return test_director
    return _check_run_case


def _warm_shutdown_switch(inst):
    """Shut the switch down in warm mode and mark the reboot file."""
    print("shutdown the swich in warm mode")
    sai_thrift_set_switch_attribute(inst.client, restart_warm=True)
    sai_thrift_set_switch_attribute(inst.client, pre_shutdown=True)
    sai_thrift_remove_switch(inst.client)
    sai_thrift_api_uninitialize(inst.client)
    # write content to reboot-requested
    print("write rebooting to file")
    with open('/tmp/warm_reboot', 'w+') as warm_file:
        warm_file.write('rebooting')


def _wait_for_warm_reboot(inst, f, is_test_rebooting, time_out, interval):
    """Poll /tmp/warm_reboot until 'post_reboot_done' or time out."""
    times = 0
    while True:
        print("reading content in the warm_reboot")
        try:
            with open('/tmp/warm_reboot', 'r') as warm_file:
                txt = warm_file.readline()
        except OSError as err:
            # NOTE(review): treat an unreadable file as "not done yet" and
            # keep polling; previously any error silently ended the wait.
            print(err)
            txt = ''
        if 'post_reboot_done' in txt:
            print("warm reboot is done, next, we will run the case")
            return
        if is_test_rebooting:
            print("running in the rebooting stage, text is ", txt)
            # Failures here now propagate instead of being swallowed.
            f(inst)
        times += 1
        time.sleep(interval)
        print("alreay wait for ", times)
        if times > time_out:
            raise TimeoutError("time out")

def _resolve_counter_ids(cnt_func, action, error_message):
    """
    Find the sai_adapter counter id list and id->name dict for a generated
    counter function named sai_thrift_<action>_<name>.

    Args:
        cnt_func (function): counter function, or a wrapper around one that
            keeps it in a closure variable named 'func' (as delay_wrapper does)
        action (str): verb expected after 'sai_thrift_' ('get' or 'clear')
        error_message (str): message of the ValueError raised on mismatch

    Returns:
        tuple: (id_list, id_dict) = (adapter.sai_<action>_<name>_counter_ids,
        adapter.sai_<action>_<name>_counter_ids_dict)

    Raises:
        ValueError: if no sai_thrift_<action>* function name can be found
        AttributeError: if sai_adapter lacks the matching list or dict
    """
    import inspect

    prefix = "sai_thrift_"
    expected_start = prefix + action
    fun_name = getattr(cnt_func, "__name__", "")
    if not fun_name.startswith(expected_start):
        # cannot get the func name directly
        # it should be a wrapper
        try:
            inner = inspect.getclosurevars(cnt_func).nonlocals.get("func")
        except TypeError:
            # Not a plain Python function; no closure to inspect.
            inner = None
        fun_name = getattr(inner, "__name__", "")
    if not fun_name.startswith(expected_start):
        raise ValueError(error_message)

    # Slice off the exact prefix: str.lstrip("sai_thrift_") strips a set of
    # characters and only worked by accident for 'get_'/'clear_' names.
    base_name = fun_name[len(prefix):]
    id_dict = getattr(adapter, "sai_{}_counter_ids_dict".format(base_name))
    id_list = getattr(adapter, "sai_{}_counter_ids".format(base_name))
    return id_list, id_dict


def query_counter(test, cnt_func, *args, **kwargs):
    """
    Get counter by each counter id for the counter function.
    This method depends on sai_adapater generation pattern.
    The cnt_func name must match sai_thrift_get_<name> (or cnt_func must
    wrap such a function in a closure variable 'func'). The adapter is then
    expected to provide a counter dict named
        sai_get_<name>_counter_ids_dict
    and a counter list named
        sai_get_<name>_counter_ids
    API errors are ignored while querying and the previous error handling is
    restored afterwards, even if a query raises.
    Args:
        test: object extends from base test
        cnt_func: counter function
        args: counter function parameters
        kwargs: counter function parameters with name; 'counter_ids' is
            overwritten with one id per call
    return:
        result: dict, counter name and value
    Raises:
        ValueError: if the counter query method name cannot be determined
    """
    id_list, id_dict = _resolve_counter_ids(
        cnt_func, "get",
        "Cannot get the expected counter query method name")

    result = {}
    # NOTE: collected but not returned (the return value stays a plain dict
    # for existing callers).
    supported_counters = []
    unsupported_counters = []
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            stats = cnt_func(test.client, *args, **kwargs)
            if test.status() == SAI_STATUS_SUCCESS:
                supported_counters.append(counter)
            else:
                unsupported_counters.append(counter)
            result[counter] = stats[counter]
    finally:
        # Without this, one failure would leave API errors ignored globally.
        restore_api_error_code()
    return result


def clear_counter(test, cnt_func, *args, **kwargs):
    """
    Clear counter by each counter id for the counter function.
    This method depends on sai_adapater generation pattern.
    The cnt_func name must match sai_thrift_clear_<name> (or cnt_func must
    wrap such a function in a closure variable 'func'). The adapter is then
    expected to provide a counter dict named
        sai_clear_<name>_counter_ids_dict
    and a counter list named
        sai_clear_<name>_counter_ids
    API errors are ignored while clearing and the previous error handling is
    restored afterwards, even if a call raises.
    Args:
        test: object extends from base test
        cnt_func: counter function
        args: counter function parameters
        kwargs: counter function parameters with name; 'counter_ids' is
            overwritten with one id per call
    return:
        None
    Raises:
        ValueError: if the counter clear method name cannot be determined
    """
    id_list, id_dict = _resolve_counter_ids(
        cnt_func, "clear",
        "Cannot get the expected counter clear method name")

    # NOTE: collected but not returned (kept for parity with query_counter).
    supported_counters = []
    unsupported_counters = []
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            cnt_func(test.client, *args, **kwargs)
            if test.status() == SAI_STATUS_SUCCESS:
                supported_counters.append(counter)
            else:
                unsupported_counters.append(counter)
    finally:
        restore_api_error_code()


# Adapter error-handling state saved by the most recent ignore_api_errors()
# call. Also re-applied by an unbalanced restore_api_error_code().
capture_status = True
expected_code = []
# Stack of saved (CATCH_EXCEPTIONS, EXPECTED_ERROR_CODE) pairs, innermost
# last, so nested ignore/restore pairs (e.g. query_counter() called while a
# test already ignores errors) unwind to the correct state.
_saved_api_error_states = []


def ignore_api_errors():
    """
    Ignore API errors.

    After this call adapter.CATCH_EXCEPTIONS is True and
    adapter.EXPECTED_ERROR_CODE is empty, so API errors are caught and not
    raised. Pair every call with restore_api_error_code(); pairs may nest.

    Returns:
        tuple: the (CATCH_EXCEPTIONS, EXPECTED_ERROR_CODE) values that were
        in effect before this call
    """
    global capture_status, expected_code
    capture_status = adapter.CATCH_EXCEPTIONS
    expected_code = adapter.EXPECTED_ERROR_CODE
    _saved_api_error_states.append((capture_status, expected_code))
    adapter.CATCH_EXCEPTIONS = True
    adapter.EXPECTED_ERROR_CODE = []
    return capture_status, expected_code


def restore_api_error_code():
    """
    Restore API error code and catch status.

    Undoes the most recent unmatched ignore_api_errors() call. If there is
    none, re-applies the last saved state, as this function did before.
    """
    global capture_status, expected_code
    if _saved_api_error_states:
        capture_status, expected_code = _saved_api_error_states.pop()
    adapter.CATCH_EXCEPTIONS = capture_status
    adapter.EXPECTED_ERROR_CODE = expected_code
