ptf/sai_utils.py (227 lines of code) (raw):

# Copyright 2021-present Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Thrift SAI interface basic utils.
"""

import time
import struct
import socket
import json
import inspect
from functools import wraps

from ptf.packet import *
from ptf.testutils import *
from sai_thrift.sai_adapter import *
import sai_thrift.sai_adapter as adapter
from typing import List, Dict
from typing import TYPE_CHECKING


def sai_thrift_query_attribute_enum_values_capability(client, obj_type,
                                                      attr_id=None,
                                                      max_cap_no=20):
    """
    Call the sai_thrift_query_attribute_enum_values_capability() RPC and
    return the list of supported enum values for the given attribute

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type
        attr_id (attr): SAI attribute name
        max_cap_no (int): capacity passed as the third RPC argument
            (presumably the maximum number of enum values returned);
            defaults to 20, the previously hard-coded value

    Returns:
        list: list of supported enum values for the attribute
    """
    enum_cap_list = client.sai_thrift_query_attribute_enum_values_capability(
        obj_type, attr_id, max_cap_no)
    return enum_cap_list


def sai_thrift_object_type_get_availability(client,
                                            obj_type,
                                            attr_id=None,
                                            attr_type=None):
    """
    sai_thrift_object_type_get_availability() RPC client function
    implementation

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type
        attr_id (attr): SAI attribute name
        attr_type (type): SAI attribute type

    Returns:
        uint: number of available resources with given parameters
    """
    availability_cnt = client.sai_thrift_object_type_get_availability(
        obj_type, attr_id, attr_type)
    return availability_cnt


def sai_thrift_object_type_query(client, obj_id=None):
    """
    sai_thrift_object_type_query() RPC client function implementation

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id

    Returns:
        uint: object type
    """
    obj_type = client.sai_object_type_query(obj_id)
    return obj_type


def sai_thrift_switch_id_query(client, obj_id=None):
    """
    sai_thrift_switch_id_query() RPC client function implementation

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id

    Returns:
        the result of client.sai_switch_id_query(); presumably the object id
        of the switch owning ``obj_id`` (confirm against the thrift IDL)
    """
    switch_obj_id = client.sai_switch_id_query(obj_id)
    return switch_obj_id


def sai_thrift_api_uninitialize(client):
    """
    sai_thrift_api_uninitialize() RPC client function implementation

    Args:
        client (Client): SAI RPC client

    Returns:
        the result of client.sai_thrift_api_uninitialize(); presumably a SAI
        status code (confirm against the thrift IDL)
    """
    status = client.sai_thrift_api_uninitialize()
    return status


def _map_counters(counter_ids, counters):
    """Pair each requested counter id with the value at the same index."""
    return {counter_id: counters[idx]
            for idx, counter_id in enumerate(counter_ids)}


def sai_thrift_get_debug_counter_port_stats(client, port_oid, counter_ids):
    """
    Get port statistics for given debug counters

    Args:
        client (Client): SAI RPC client
        port_oid (sai_thrift_object_id_t): object_id IN argument
        counter_ids (list): list of requested counter ids (sai_stat_id_t)

    Returns:
        dict: counter values keyed by the requested counter ids
    """
    counters = client.sai_thrift_get_port_stats(port_oid, counter_ids)
    return _map_counters(counter_ids, counters)


def sai_thrift_get_debug_counter_switch_stats(client, counter_ids):
    """
    Get switch statistics for given debug counters

    Args:
        client (Client): SAI RPC client
        counter_ids (list): list of requested counter ids (sai_stat_id_t)

    Returns:
        dict: counter values keyed by the requested counter ids
    """
    counters = client.sai_thrift_get_switch_stats(counter_ids)
    return _map_counters(counter_ids, counters)


def sai_ipaddress(addr_str):
    """
    Set SAI IP address, assign appropriate type and return
    sai_thrift_ip_address_t object

    A string containing ':' is treated as IPv6 (this also covers
    IPv4-mapped forms like '::ffff:1.2.3.4'); otherwise a string containing
    '.' is treated as IPv4.

    Args:
        addr_str (str): IP address string

    Returns:
        sai_thrift_ip_address_t: object containing IP address family and
            number

    Raises:
        ValueError: if ``addr_str`` contains neither '.' nor ':'
    """
    if ':' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_str)
    elif '.' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_str)
    else:
        raise ValueError("Invalid IP address format: {}".format(addr_str))

    return sai_thrift_ip_address_t(addr_family=family, addr=addr)


def sai_ipprefix(prefix_str):
    """
    Set IP address prefix and mask and return ip_prefix object

    A prefix containing ':' is treated as IPv6, otherwise one containing '.'
    is treated as IPv4.

    Args:
        prefix_str (str): IP address and mask string (with slash notation)

    Return:
        sai_thrift_ip_prefix_t: IP prefix object, or None (after printing a
            message) if the string is not in address/length form
    """
    addr_mask = prefix_str.split('/')
    if len(addr_mask) != 2:
        print("Invalid IP prefix format")
        return None

    if ':' in prefix_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_mask[0])
        mask = num_to_dotted_quad(int(addr_mask[1]), ipv4=False)
        mask = sai_thrift_ip_addr_t(ip6=mask)
    elif '.' in prefix_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_mask[0])
        mask = num_to_dotted_quad(addr_mask[1])
        mask = sai_thrift_ip_addr_t(ip4=mask)
    else:
        # Neither address family could be recognised; report it the same
        # way as a malformed prefix instead of failing on unbound locals.
        print("Invalid IP prefix format")
        return None

    ip_prefix = sai_thrift_ip_prefix_t(
        addr_family=family, addr=addr, mask=mask)
    return ip_prefix


def num_to_dotted_quad(address, ipv4=True):
    """
    Helper function to convert a prefix length into a netmask string

    Args:
        address (str): prefix length (number of leading one bits); any value
            accepted by int() works
        ipv4 (bool): determines what IP version is handled

    Returns:
        str: IPv4 netmask in dotted-quad form, or IPv6 netmask written as
            eight fully expanded 4-hex-digit groups separated by ':'

    Raises:
        ValueError: if the prefix length is outside 0..32 (IPv4) or
            0..128 (IPv6)
    """
    width = 32 if ipv4 else 128
    prefix_len = int(address)
    if not 0 <= prefix_len <= width:
        raise ValueError("Invalid prefix length {} for a {}-bit address"
                         .format(prefix_len, width))

    # ``prefix_len`` leading one bits followed by zero bits.
    mask = (1 << width) - (1 << width >> prefix_len)

    if ipv4:
        return socket.inet_ntop(socket.AF_INET, struct.pack('>L', mask))

    # Zero-pad to 32 hex digits so a /0 mask still yields eight groups.
    hex_digits = '{:032x}'.format(mask)
    return ':'.join(hex_digits[pos:pos + 4] for pos in range(0, 32, 4))


def open_packet_socket(hostif_name):
    """
    Open a raw Linux packet socket bound to a host interface

    The socket receives all Ethernet protocols (ETH_P_ALL) and is switched
    to non-blocking mode. It is closed again if binding fails.

    Args:
        hostif_name (str): socket interface name

    Return:
        socket.socket: the bound, non-blocking socket
    """
    eth_p_all = 3  # ETH_P_ALL from <linux/if_ether.h>
    sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW,
                         socket.htons(eth_p_all))
    try:
        sock.bind((hostif_name, eth_p_all))
        sock.setblocking(False)
    except OSError:
        sock.close()
        raise
    return sock


def socket_verify_packet(pkt, sock, timeout=2):
    """
    Verify packet was received on a socket

    Reads frames from ``sock`` until one matches ``pkt`` or ``timeout``
    seconds elapse. Frames that cannot be parsed or compared are skipped.

    Args:
        pkt (packet): packet or ptf.mask.Mask to match with
        sock (socket.socket): non-blocking socket, e.g. one returned by
            open_packet_socket()
        timeout (int): timeout in seconds

    Return:
        bool: True if packet matched; False on timeout or if ``pkt`` is an
            invalid mask
    """
    max_pkt_size = 9100
    poll_delay = 0.001
    deadline = time.time() + timeout
    is_mask = isinstance(pkt, ptf.mask.Mask)
    match = False

    if is_mask and not pkt.is_valid():
        return False

    while time.time() < deadline:
        try:
            packet_from_tap_device = Ether(sock.recv(max_pkt_size))
            if is_mask:
                match = pkt.pkt_match(packet_from_tap_device)
            else:
                # Compare the wire bytes; str() of a scapy packet is not a
                # reliable serialisation on Python 3.
                match = bytes(packet_from_tap_device) == bytes(pkt)
        except BlockingIOError:
            # Nothing queued on the non-blocking socket yet; back off
            # briefly instead of spinning the CPU.
            time.sleep(poll_delay)
            continue
        except Exception:
            # Best effort: skip frames that fail to parse or compare.
            continue
        if match:
            break
    return match


def delay_wrapper(func, delay=2):
    """
    A wrapper extending given function by a delay

    The delay is applied before calling ``func`` only when the PTF test
    parameters contain a 'target' entry whose value is not "hw".

    Note: query_counter()/clear_counter() locate the wrapped function through
    this closure's ``func`` variable, so keep that name.

    Args:
        func (function): function to be wrapped
        delay (int): delay period in sec

    Return:
        wrapped_function: wrapped function
    """
    @wraps(func)
    def wrapped_function(*args, **kwargs):
        """
        A wrapper function adding a delay

        Args:
            args (tuple): function arguments
            kwargs (dict): keyword function arguments

        Return:
            status: original function return value
        """
        test_params = test_params_get()
        if 'target' in test_params and test_params['target'] != "hw":
            time.sleep(delay)
        return func(*args, **kwargs)

    return wrapped_function


# Delay FDB flushes on non-"hw" targets (see delay_wrapper); presumably to
# give software targets time to settle before the flush - confirm.
sai_thrift_flush_fdb_entries = delay_wrapper(sai_thrift_flush_fdb_entries)


_WARM_REBOOT_FILE = '/tmp/warm_reboot'


def warm_test(is_test_rebooting: bool = False, time_out=60, interval=1):
    """
    Method decorator for warm-reboot testing.

    Reads the test instance attribute ``test_reboot_mode``. When it is
    'warm', the wrapper:

    1. shuts the switch down for a warm restart;
    2. writes 'rebooting' to /tmp/warm_reboot;
    3. polls that file until it contains 'post_reboot_done';
    4. recreates the RPC client and warm-starts the switch.

    In every mode the decorated method is then run and its result returned.

    Args:
        is_test_rebooting (bool): whether to also run the decorated method on
            every poll while the saiserver container is rebooting
        time_out (int): maximum number of polls to wait for the reboot to
            complete (roughly seconds when interval is 1); when exceeded a
            TimeoutError is raised
        interval (int): delay in seconds between polls

    Returns:
        function: decorator to apply to a test method
    """
    def _check_run_case(f):
        @wraps(f)
        def test_director(inst, *args):
            if inst.test_reboot_mode == 'warm':
                print("shutdown the swich in warm mode")
                sai_thrift_set_switch_attribute(inst.client, restart_warm=True)
                sai_thrift_set_switch_attribute(inst.client, pre_shutdown=True)
                sai_thrift_remove_switch(inst.client)
                sai_thrift_api_uninitialize(inst.client)

                # write content to reboot-requested; the loop below polls
                # the same file for 'post_reboot_done'
                print("write rebooting to file")
                with open(_WARM_REBOOT_FILE, 'w+') as warm_file:
                    warm_file.write('rebooting')

                times = 0
                timed_out = False
                try:
                    while True:
                        print("reading content in the warm_reboot")
                        with open(_WARM_REBOOT_FILE, 'r') as warm_file:
                            txt = warm_file.readline()
                        if 'post_reboot_done' in txt:
                            print("warm reboot is done, next, we will run the case")
                            break
                        if is_test_rebooting:
                            print("running in the rebooting stage, text is ", txt)
                            f(inst)
                        times = times + 1
                        time.sleep(interval)
                        print("alreay wait for ", times)
                        if times > time_out:
                            timed_out = True
                            break
                except Exception as e:
                    # Best effort (pre-existing behaviour): a failed read or
                    # a failure in the rebooting-stage run is reported and
                    # the warm start below is still attempted.
                    print(e)
                # Raised outside the try so the documented timeout is not
                # swallowed by the best-effort handler above.
                if timed_out:
                    raise TimeoutError("time out")

                inst.createRpcClient()
                inst.warm_start_switch()
            return f(inst)
        return test_director
    return _check_run_case


_SAI_THRIFT_PREFIX = "sai_thrift_"


def _resolve_counter_ids(cnt_func, expected_prefix, error_msg):
    """
    Find the adapter counter-id dict and list that belong to ``cnt_func``.

    Args:
        cnt_func: adapter counter function, possibly wrapped by delay_wrapper
        expected_prefix (str): required name prefix, e.g. "sai_thrift_get"
        error_msg (str): message for the ValueError raised on mismatch

    Returns:
        tuple: (sai_<name>_counter_ids_dict, sai_<name>_counter_ids) looked
            up on the adapter module, where <name> is the function name
            without its "sai_thrift_" prefix

    Raises:
        ValueError: if neither the function nor the function it wraps has
            the expected name prefix
    """
    fun_name = cnt_func.__name__
    if not fun_name.startswith(expected_prefix):
        # cannot get the func name directly, it should be a wrapper
        # (e.g. from delay_wrapper) that closes over a variable named 'func'
        fun_name = inspect.getclosurevars(cnt_func).nonlocals['func'].__name__
        if not fun_name.startswith(expected_prefix):
            raise ValueError(error_msg)

    # Slice off the prefix; str.lstrip() would strip a character set.
    base_name = fun_name[len(_SAI_THRIFT_PREFIX):]
    id_dict = getattr(adapter, "sai_{}_counter_ids_dict".format(base_name))
    id_list = getattr(adapter, "sai_{}_counter_ids".format(base_name))
    return id_dict, id_list


def query_counter(test, cnt_func, *args, **kwargs):
    """
    Get counter by each counter id for the counter function.

    This method depends on the sai_adapter generation pattern. The cnt_func
    name must follow the pattern sai_thrift_get_<counter_query_func>; the
    adapter is then expected to provide a counter dict named
    sai_get_<counter_query_func>_counter_ids_dict and a counter list named
    sai_get_<counter_query_func>_counter_ids.

    Each counter id is queried separately with API errors suppressed (see
    ignore_api_errors()); the previous error-handling state is restored even
    if the counter function raises.

    Args:
        test: object extends from base test
        cnt_func: counter function
        args: counter function parameters
        kwargs: counter function parameters with name

    Returns:
        dict: counter name -> value; None when the call returned no value
            for that counter

    Raises:
        ValueError: if cnt_func is not a sai_thrift_get* function (or a
            wrapper of one)
    """
    id_dict, id_list = _resolve_counter_ids(
        cnt_func, "sai_thrift_get",
        "Cannot get the expected counter query method name")

    result = {}
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            stats = cnt_func(test.client, *args, **kwargs)
            # A failing call may return no entry for this counter.
            result[counter] = (stats[counter]
                               if stats and counter in stats else None)
    finally:
        restore_api_error_code()
    return result


def clear_counter(test, cnt_func, *args, **kwargs):
    """
    Clear counter by each counter id for the counter function.

    This method depends on the sai_adapter generation pattern. The cnt_func
    name must follow the pattern sai_thrift_clear_<counter_query_func>; the
    adapter is then expected to provide a counter dict named
    sai_clear_<counter_query_func>_counter_ids_dict and a counter list named
    sai_clear_<counter_query_func>_counter_ids.

    Each counter id is cleared separately with API errors suppressed (see
    ignore_api_errors()); the previous error-handling state is restored even
    if the counter function raises.

    Args:
        test: object extends from base test
        cnt_func: counter function
        args: counter function parameters
        kwargs: counter function parameters with name

    Returns:
        tuple: (supported_counters, unsupported_counters), lists of counter
            names split by whether test.status() reported
            SAI_STATUS_SUCCESS after the clear call

    Raises:
        ValueError: if cnt_func is not a sai_thrift_clear* function (or a
            wrapper of one)
    """
    id_dict, id_list = _resolve_counter_ids(
        cnt_func, "sai_thrift_clear",
        "Cannot get the expected counter clear method name")

    supported_counters = []
    unsupported_counters = []
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            cnt_func(test.client, *args, **kwargs)
            if test.status() == SAI_STATUS_SUCCESS:
                supported_counters.append(counter)
            else:
                unsupported_counters.append(counter)
    finally:
        restore_api_error_code()
    return supported_counters, unsupported_counters


# Adapter error-handling state saved by ignore_api_errors() and put back by
# restore_api_error_code().
capture_status = True
expected_code = []


def ignore_api_errors():
    """
    Ignore API errors.

    Saves adapter.CATCH_EXCEPTIONS and adapter.EXPECTED_ERROR_CODE into this
    module's globals, then sets CATCH_EXCEPTIONS to True and clears
    EXPECTED_ERROR_CODE so API errors are caught and not raised. Call
    restore_api_error_code() to undo. Calls do not nest: a second call
    overwrites the saved state.

    Returns:
        tuple: the saved (capture_status, expected_code) values
    """
    global capture_status, expected_code
    capture_status = adapter.CATCH_EXCEPTIONS
    expected_code = adapter.EXPECTED_ERROR_CODE
    adapter.CATCH_EXCEPTIONS = True
    adapter.EXPECTED_ERROR_CODE = []
    return capture_status, expected_code


def restore_api_error_code():
    """
    Restore the adapter error code list and catch status saved by the last
    ignore_api_errors() call.
    """
    adapter.CATCH_EXCEPTIONS = capture_status
    adapter.EXPECTED_ERROR_CODE = expected_code