ptf/sai_utils.py (227 lines of code) (raw):
# Copyright 2021-present Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Thrift SAI interface basic utils.
"""
import time
import struct
import socket
import json
from functools import wraps
from ptf.packet import *
from ptf.testutils import *
from sai_thrift.sai_adapter import *
import sai_thrift.sai_adapter as adapter
from typing import List, Dict
from typing import TYPE_CHECKING
def sai_thrift_query_attribute_enum_values_capability(client,
                                                      obj_type,
                                                      attr_id=None,
                                                      max_cap_no=20):
    """
    Call the sai_thrift_query_attribute_enum_values_capability() RPC and
    return the enum capabilities reported for the given attribute.

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type
        attr_id (attr): SAI attribute id whose enum values are queried
        max_cap_no (int): maximum number of enum capabilities requested
            from the server (defaults to 20, the previously hard-coded
            limit)

    Returns:
        list: enum capabilities returned by the RPC (presumably the enum
        values supported for ``attr_id``)
    """
    return client.sai_thrift_query_attribute_enum_values_capability(
        obj_type, attr_id, max_cap_no)
def sai_thrift_object_type_get_availability(client,
                                            obj_type,
                                            attr_id=None,
                                            attr_type=None):
    """
    Thin client-side wrapper around the
    sai_thrift_object_type_get_availability() RPC.

    Args:
        client (Client): SAI RPC client
        obj_type (enum): SAI object type to query
        attr_id (attr): SAI attribute id narrowing the query
        attr_type (type): SAI attribute type narrowing the query

    Returns:
        uint: number of resources still available for the given parameters
    """
    rpc = client.sai_thrift_object_type_get_availability
    return rpc(obj_type, attr_id, attr_type)
def sai_thrift_object_type_query(client,
                                 obj_id=None):
    """
    Ask the SAI server which object type an object id belongs to, via the
    client's sai_object_type_query() RPC.

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id to look up

    Returns:
        uint: SAI object type of ``obj_id``
    """
    query = client.sai_object_type_query
    return query(obj_id)
def sai_thrift_switch_id_query(client,
                               obj_id=None):
    """
    Ask the SAI server which switch an object id belongs to, via the
    client's sai_switch_id_query() RPC.

    Args:
        client (Client): SAI RPC client
        obj_id (obj): SAI object id to look up

    Returns:
        switch object id returned by the sai_switch_id_query() RPC
    """
    query = client.sai_switch_id_query
    return query(obj_id)
def sai_thrift_api_uninitialize(client):
    """
    Invoke the sai_thrift_api_uninitialize() RPC on the given client.

    Args:
        client (Client): SAI RPC client

    Returns:
        whatever the RPC returns (presumably a SAI status code; confirm
        against the Thrift service definition)
    """
    uninitialize = client.sai_thrift_api_uninitialize
    return uninitialize()
def sai_thrift_get_debug_counter_port_stats(client, port_oid, counter_ids):
    """
    Read port statistics for a list of debug counters.

    Args:
        client (Client): SAI RPC client
        port_oid (sai_thrift_object_id_t): port object id
        counter_ids (sai_stat_id_t): list of requested counters

    Returns:
        dict: maps each requested counter id to the value at the same
        position in the sai_thrift_get_port_stats() reply
    """
    values = client.sai_thrift_get_port_stats(port_oid, counter_ids)
    return {cid: values[pos] for pos, cid in enumerate(counter_ids)}
def sai_thrift_get_debug_counter_switch_stats(client, counter_ids):
    """
    Read switch statistics for a list of debug counters.

    Args:
        client (Client): SAI RPC client
        counter_ids (sai_stat_id_t): list of requested counters

    Returns:
        dict: maps each requested counter id to the value at the same
        position in the sai_thrift_get_switch_stats() reply
    """
    values = client.sai_thrift_get_switch_stats(counter_ids)
    return {cid: values[pos] for pos, cid in enumerate(counter_ids)}
def sai_ipaddress(addr_str):
    """
    Build a sai_thrift_ip_address_t from an IP address string.

    The family is chosen from the text: IPv6 if it contains ':' (this
    includes IPv4-mapped forms such as "::ffff:1.2.3.4"), otherwise IPv4
    if it contains '.'.

    Args:
        addr_str (str): IP address string

    Returns:
        sai_thrift_ip_address_t: object containing IP address family and
        address

    Raises:
        ValueError: if ``addr_str`` contains neither '.' nor ':'
    """
    # ':' is checked first so IPv4-mapped IPv6 addresses stay IPv6.
    if ':' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_str)
    elif '.' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_str)
    else:
        raise ValueError("Invalid IP address: {!r}".format(addr_str))
    return sai_thrift_ip_address_t(addr_family=family, addr=addr)
def sai_ipprefix(prefix_str):
    """
    Build a sai_thrift_ip_prefix_t from an "address/length" string.

    The family is chosen from the address part: IPv6 if it contains ':'
    (this includes IPv4-mapped forms such as "::ffff:10.0.0.0/120"),
    otherwise IPv4 if it contains '.'. Exactly one family branch runs.
    Previously an IPv4-mapped prefix also entered the IPv4 branch and
    crashed computing a 32-bit mask from the IPv6 length.

    Args:
        prefix_str (str): IP address and prefix length in slash notation,
            e.g. "10.0.0.0/24" or "2001:db8::/32"

    Return:
        sai_thrift_ip_prefix_t: IP prefix object, or None (after printing
        "Invalid IP prefix format") when the string is not in slash
        notation or its address family cannot be determined
    """
    addr_mask = prefix_str.split('/')
    if len(addr_mask) != 2:
        print("Invalid IP prefix format")
        return None
    addr_str, prefix_len = addr_mask
    if ':' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV6
        addr = sai_thrift_ip_addr_t(ip6=addr_str)
        mask = sai_thrift_ip_addr_t(
            ip6=num_to_dotted_quad(int(prefix_len), ipv4=False))
    elif '.' in addr_str:
        family = SAI_IP_ADDR_FAMILY_IPV4
        addr = sai_thrift_ip_addr_t(ip4=addr_str)
        mask = sai_thrift_ip_addr_t(ip4=num_to_dotted_quad(prefix_len))
    else:
        print("Invalid IP prefix format")
        return None
    return sai_thrift_ip_prefix_t(addr_family=family, addr=addr, mask=mask)
def num_to_dotted_quad(address, ipv4=True):
    """
    Convert a prefix length into the textual form of its network mask.

    Args:
        address (int or str): prefix length (number of leading one bits);
            anything accepted by int() works, e.g. 24 or "24"
        ipv4 (bool): True for an IPv4 dotted-quad mask, False for an
            uncompressed IPv6 mask

    Returns:
        str: the mask, e.g. "255.255.255.0" for 24 (IPv4) or
        "ffff:ffff:ffff:ffff:0000:0000:0000:0000" for 64 (IPv6)

    Raises:
        ValueError: if the prefix length is outside 0..32 (IPv4) or
            0..128 (IPv6)
    """
    prefix_len = int(address)
    if ipv4 is True:
        if not 0 <= prefix_len <= 32:
            raise ValueError("Invalid IPv4 prefix length: {}".format(address))
        mask = (1 << 32) - (1 << 32 >> prefix_len)
        return socket.inet_ntop(socket.AF_INET, struct.pack('>L', mask))
    if not 0 <= prefix_len <= 128:
        raise ValueError("Invalid IPv6 prefix length: {}".format(address))
    mask = (1 << 128) - (1 << 128 >> prefix_len)
    # Zero-pad to all 32 hex digits so a /0 mask (value 0) renders as eight
    # zero groups instead of collapsing to an empty string.
    digits = '{:032x}'.format(mask)
    return ':'.join(digits[i:i + 4] for i in range(0, 32, 4))
def open_packet_socket(hostif_name):
    """
    Open a non-blocking raw AF_PACKET socket bound to a host interface.

    Args:
        hostif_name (str): name of the network interface to bind to

    Return:
        socket.socket: the bound, non-blocking raw socket

    Raises:
        OSError: if the socket cannot be created or bound (the partially
            set-up socket is closed first)
    """
    eth_p_all = 3  # value of Linux ETH_P_ALL: receive every protocol
    sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW,
                         socket.htons(eth_p_all))
    try:
        sock.bind((hostif_name, eth_p_all))
        sock.setblocking(0)
    except Exception:
        # Don't leak the descriptor when e.g. the interface does not exist;
        # the caller never receives the socket in that case.
        sock.close()
        raise
    return sock
def socket_verify_packet(pkt, sock, timeout=2):
    """
    Poll a socket until a packet matching ``pkt`` arrives or time runs out.

    Args:
        pkt (packet or ptf.mask.Mask): expected packet; a Mask is matched
            with its pkt_match(), anything else by comparing str() forms
        sock (socket.socket): socket to read from, e.g. one returned by
            open_packet_socket() (non-blocking)
        timeout (int): how long to keep polling, in seconds

    Return:
        bool: True if a matching packet was received before the deadline;
        False otherwise, or immediately if ``pkt`` is an invalid Mask
    """
    max_pkt_size = 9100
    poll_interval = 0.01  # back-off while no frame is queued, in seconds
    deadline = time.time() + timeout
    match = False
    if isinstance(pkt, ptf.mask.Mask) and not pkt.is_valid():
        return False
    while time.time() < deadline:
        try:
            received = Ether(sock.recv(max_pkt_size))
            if isinstance(pkt, ptf.mask.Mask):
                match = pkt.pkt_match(received)
            else:
                match = (str(received) == str(pkt))
            if match:
                break
        except BlockingIOError:
            # Nothing queued on the non-blocking socket yet: sleep briefly
            # instead of spinning the CPU until the deadline.
            time.sleep(poll_interval)
        except Exception:
            # Best effort: skip frames that fail to read or dissect and keep
            # polling. KeyboardInterrupt/SystemExit are no longer swallowed.
            pass
    return match
def delay_wrapper(func, delay=2):
    """
    Decorate ``func`` so that it first sleeps ``delay`` seconds whenever the
    PTF test parameter 'target' is present and not equal to "hw".

    Args:
        func (function): callable to decorate
        delay (int): pause length in seconds

    Return:
        wrapped_function: decorated callable carrying func's metadata
    """
    @wraps(func)
    def wrapped_function(*args, **kwargs):
        """
        Sleep when running against a non-hardware target, then call func.

        Args:
            args (tuple): positional arguments forwarded to func
            kwargs (dict): keyword arguments forwarded to func

        Return:
            status: whatever func returns
        """
        params = test_params_get()
        on_hardware = 'target' not in params.keys() or params['target'] == "hw"
        if not on_hardware:
            time.sleep(delay)
        return func(*args, **kwargs)
    return wrapped_function
# Rebind the FDB flush helper (presumably provided by the sai_adapter star
# import) so that, when the 'target' test parameter is set to something
# other than "hw", every call first sleeps for delay_wrapper's default
# delay of 2 seconds.
sai_thrift_flush_fdb_entries = delay_wrapper(sai_thrift_flush_fdb_entries)
def warm_test(is_test_rebooting: bool = False, time_out=60, interval=1):
    """
    Method decorator for tests that exercise warm reboot.

    Depends on the test instance attribute ``test_reboot_mode``. When it is
    'warm', the decorated method is run as follows:

    1. shut the switch down in warm mode (restart_warm, pre_shutdown,
       remove switch, API uninitialize);
    2. write 'rebooting' to /tmp/warm_reboot, then poll that file every
       ``interval`` seconds until it contains 'post_reboot_done';
    3. recreate the RPC client and warm-start the switch;
    4. run the decorated method.

    In any other mode the decorated method is simply called.

    Args:
        is_test_rebooting: if True, also run the decorated method on every
            poll iteration while the saiserver container is rebooting
        time_out: maximum number of poll iterations to wait for
            'post_reboot_done' (equal to seconds with the default
            ``interval`` of 1)
        interval: seconds to sleep between polls

    Raises:
        Exception: "time out" if the reboot is not reported as done within
            ``time_out`` poll iterations. Previously this was raised inside
            the polling try-block and immediately swallowed.
    """
    def _check_run_case(f):
        @wraps(f)
        def test_director(inst, *args):
            # Extra positional args are accepted but not forwarded; the
            # decorated method is always called as f(inst).
            if inst.test_reboot_mode == 'warm':
                print("shutdown the swich in warm mode")
                sai_thrift_set_switch_attribute(inst.client, restart_warm=True)
                sai_thrift_set_switch_attribute(inst.client, pre_shutdown=True)
                sai_thrift_remove_switch(inst.client)
                sai_thrift_api_uninitialize(inst.client)
                # write content to reboot-requested
                print("write rebooting to file")
                with open('/tmp/warm_reboot', 'w+') as warm_file:
                    warm_file.write('rebooting')
                times = 0
                timed_out = False
                try:
                    while True:
                        print("reading content in the warm_reboot")
                        with open('/tmp/warm_reboot', 'r') as warm_file:
                            txt = warm_file.readline()
                        if 'post_reboot_done' in txt:
                            print("warm reboot is done, next, we will run the case")
                            break
                        if is_test_rebooting:
                            print("running in the rebooting stage, text is ", txt)
                            f(inst)
                        times = times + 1
                        time.sleep(interval)
                        print("alreay wait for ", times)
                        if times > time_out:
                            timed_out = True
                            break
                except Exception as e:
                    # Best effort, as before: report file-access errors or
                    # failures of the in-reboot run and continue.
                    print(e)
                if timed_out:
                    # Raised outside the try so the timeout reaches the caller.
                    raise Exception("time out")
                inst.createRpcClient()
                inst.warm_start_switch()
            return f(inst)
        return test_director
    return _check_run_case
def query_counter(test, cnt_func, *args, **kwargs):
    """
    Query a counter function one counter id at a time.

    Relies on the sai_adapter generation pattern: ``cnt_func`` must be named
    (or be a wrapper closing over a function named) ``sai_thrift_get_<x>``.
    The adapter module must then expose ``sai_get_<x>_counter_ids_dict``
    (counter id -> counter name) and ``sai_get_<x>_counter_ids`` (list of
    counter ids).

    API errors are ignored while querying (via ignore_api_errors()) so that
    an unsupported counter does not abort the loop. The previous settings
    are restored afterwards, even if a query raises.

    Args:
        test: object extending the base test; must provide ``client`` and
            ``status()``
        cnt_func: counter query function
        args: positional arguments forwarded to cnt_func
        kwargs: keyword arguments forwarded to cnt_func; ``counter_ids`` is
            overwritten with a single id for every query

    Returns:
        dict: counter name -> ``stats[counter name]`` from each query

    Raises:
        ValueError: if the counter query function name cannot be resolved
    """
    import inspect  # local: not guaranteed by the module's own imports

    prefix = "sai_thrift_"
    fun_name = cnt_func.__name__
    if not fun_name.startswith("sai_thrift_get"):
        # Not the adapter function itself; it should be a wrapper closing
        # over the original as ``func`` (see delay_wrapper).
        wrapped = inspect.getclosurevars(cnt_func).nonlocals.get('func')
        if wrapped is not None:
            fun_name = wrapped.__name__
    if not fun_name.startswith("sai_thrift_get"):
        raise ValueError("Cannot get the expected counter query method name")
    # Slice off the prefix: str.lstrip() strips a character set, not a prefix.
    cnt_query_fun_name = fun_name[len(prefix):]
    id_dict = getattr(adapter, "sai_{}_counter_ids_dict".format(cnt_query_fun_name))
    id_list = getattr(adapter, "sai_{}_counter_ids".format(cnt_query_fun_name))

    result = {}
    # NOTE(review): these lists are filled but not returned (the return
    # value stays a dict for backward compatibility).
    supported_counters = []
    unsupported_counters = []
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            stats = cnt_func(test.client, *args, **kwargs)
            if test.status() == SAI_STATUS_SUCCESS:
                supported_counters.append(counter)
            else:
                unsupported_counters.append(counter)
            result[counter] = stats[counter]
    finally:
        restore_api_error_code()
    return result
def clear_counter(test, cnt_func, *args, **kwargs):
    """
    Clear counters through a counter clear function, one counter id at a time.

    Relies on the sai_adapter generation pattern: ``cnt_func`` must be named
    (or be a wrapper closing over a function named) ``sai_thrift_clear_<x>``.
    The adapter module must then expose ``sai_clear_<x>_counter_ids_dict``
    (counter id -> counter name) and ``sai_clear_<x>_counter_ids`` (list of
    counter ids).

    API errors are ignored during the loop (via ignore_api_errors()). The
    previous settings are restored afterwards, even if a call raises.

    Args:
        test: object extending the base test; must provide ``client`` and
            ``status()``
        cnt_func: counter clear function
        args: positional arguments forwarded to cnt_func
        kwargs: keyword arguments forwarded to cnt_func; ``counter_ids`` is
            overwritten with a single id for every call

    Returns:
        tuple: (supported_counters, unsupported_counters), lists of counter
        names for which test.status() was / was not SAI_STATUS_SUCCESS
        after the clear call

    Raises:
        ValueError: if the counter clear function name cannot be resolved
    """
    import inspect  # local: not guaranteed by the module's own imports

    prefix = "sai_thrift_"
    fun_name = cnt_func.__name__
    if not fun_name.startswith("sai_thrift_clear"):
        # Not the adapter function itself; it should be a wrapper closing
        # over the original as ``func`` (see delay_wrapper).
        wrapped = inspect.getclosurevars(cnt_func).nonlocals.get('func')
        if wrapped is not None:
            fun_name = wrapped.__name__
    if not fun_name.startswith("sai_thrift_clear"):
        raise ValueError("Cannot get the expected counter clear method name")
    # Slice off the prefix: str.lstrip() strips a character set, not a prefix.
    cnt_clear_fun_name = fun_name[len(prefix):]
    id_dict = getattr(adapter, "sai_{}_counter_ids_dict".format(cnt_clear_fun_name))
    id_list = getattr(adapter, "sai_{}_counter_ids".format(cnt_clear_fun_name))

    supported_counters = []
    unsupported_counters = []
    ignore_api_errors()
    try:
        for counter_id in id_list:
            kwargs["counter_ids"] = [counter_id]
            counter = id_dict[counter_id]
            cnt_func(test.client, *args, **kwargs)
            if test.status() == SAI_STATUS_SUCCESS:
                supported_counters.append(counter)
            else:
                unsupported_counters.append(counter)
    finally:
        restore_api_error_code()
    return supported_counters, unsupported_counters
# Adapter error-handling settings saved by the most recent
# ignore_api_errors() / restore_api_error_code() call (kept for backward
# compatibility with code that reads these globals).
capture_status = True
expected_code = []
# Stack of saved (CATCH_EXCEPTIONS, EXPECTED_ERROR_CODE) pairs. Nested
# ignore/restore pairs then each restore the state that was active when
# their own ignore_api_errors() ran.
_saved_error_states = []


def ignore_api_errors():
    """
    Ignore API errors.

    Saves the current adapter.CATCH_EXCEPTIONS and
    adapter.EXPECTED_ERROR_CODE, then sets them to True and [] so that,
    per the adapter's convention, API errors are caught rather than
    raised. Pair every call with restore_api_error_code().

    Returns:
        tuple: the saved (CATCH_EXCEPTIONS, EXPECTED_ERROR_CODE) values
    """
    global capture_status, expected_code
    capture_status = adapter.CATCH_EXCEPTIONS
    expected_code = adapter.EXPECTED_ERROR_CODE
    _saved_error_states.append((capture_status, expected_code))
    adapter.CATCH_EXCEPTIONS = True
    adapter.EXPECTED_ERROR_CODE = []
    return capture_status, expected_code


def restore_api_error_code():
    """
    Restore the adapter error settings saved by the matching
    ignore_api_errors() call.

    Without a pending save, this falls back to re-applying the module-level
    capture_status / expected_code values, which was the previous behavior.
    """
    global capture_status, expected_code
    if _saved_error_states:
        capture_status, expected_code = _saved_error_states.pop()
    adapter.CATCH_EXCEPTIONS = capture_status
    adapter.EXPECTED_ERROR_CODE = expected_code