src/neper_healthcheck/neper_runner.py (266 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs the neper test.
This module controls execution of the neper test.
"""
from collections.abc import Callable
import os
import re
import time
import checker_common
# Identity of this pod within the job, read once at import time. Each value is
# None when the corresponding environment variable is unset.
JOB_NAME = os.getenv("JOB_NAME")
SERVICE_NAME = os.getenv("SERVICE_NAME")
POD_NAME = os.getenv("POD_NAME")
# Node label key that records the overall "pass"/"fail" outcome of the test.
_RESULT_LABEL_KEY = "aiinfra/neper-healthcheck-result"
# Key shared by the failure taint and the failure labels; per-interface labels
# append an "_eth<N>" suffix to it.
TAINT_KEY = "aiinfra/neper-healthcheck"
# Effect used when the failure taint is applied.
TAINT_EFFECT = "NoSchedule"
# Node label key that stores the Unix time (in seconds) of the latest run.
HEALTHCHECK_TIME_LABEL_KEY = "aiinfra/neper-healthcheck-runtime-sec"
# kubectl command templates, filled in with the % operator.
# Arguments: node name, label key, label value.
K_ADD_LABEL_FORMAT = "/scripts/kubectl label node %s %s=%s --overwrite"
# Arguments: node name, taint key, taint value, taint effect.
K_TAINT_NODE_FORMAT = "/scripts/kubectl taint node %s %s=%s:%s"
# Arguments: node name, label key (the trailing "-" removes the label).
K_REMOVE_LABEL_FORMAT = "/scripts/kubectl label node %s %s-"
# Arguments: node name, taint key (the trailing "-" removes the taint).
K_REMOVE_TAINT_NODE_FORMAT = "/scripts/kubectl taint node %s %s-"
def ensure_env_variables() -> None:
  """Verify that every environment variable the check relies on is present.

  Variables are checked in a fixed order and each one is echoed to stdout as
  soon as it has been found, so the output shows how far validation got.

  Raises:
    ValueError: For the first required variable that is not set.
  """
  required_envs = (
      "NODE_NAME",
      "NODE_IP",
      "GOOD_THROUGHPUT",
      "HEALTH_VALIDITY_HOURS",
      "POD_NAME",
      "JOB_NAME",
      "SERVICE_NAME",
      "DRY_RUN",
  )
  for name in required_envs:
    # Environment values are always strings, so None can only mean "unset".
    value = os.environ.get(name)
    if value is None:
      raise ValueError(f"Must set {name}")
    print(f"ENV {name}={value}")
def configure_ssh() -> None:
  """Switch sshd to port 222 and append matching client defaults for root.

  The sshd_config edit is best-effort (check=False); the service restart is
  run with the default checking of checker_common.run_command. The client
  settings are appended to /root/.ssh/config, so repeated calls add the same
  stanza again.
  """
  sshd_port_edit = "sed -i 's/#Port 22/Port 222/g' /etc/ssh/sshd_config"
  checker_common.run_command(sshd_port_edit, check=False)
  checker_common.run_command("service ssh restart")
  # Client-side defaults applied to every host this pod connects to.
  client_defaults = """
Host *
StrictHostKeyChecking no
User root
IdentityFile /root/.ssh/google_compute_engine
Port 222"""
  with open("/root/.ssh/config", "a") as ssh_config:
    ssh_config.write(client_defaults)
def get_host_to_ips() -> dict[str, list[str]]:
  """Look up the peer pod's node name and IP addresses over SSH.

  The peer is always pod index 1 of the job, addressed as
  "<JOB_NAME>-1.<SERVICE_NAME>".

  Returns:
    A dict with a single entry mapping the peer's node name to its list of IP
    addresses, or an empty dict when the node name or every IP line is blank.

  Raises:
    TimeoutError: Propagated from get_host_name / get_ip_addresses when the
      peer cannot be reached in time.
  """
  hosts: dict[str, list[str]] = {}
  pod_name = f"{JOB_NAME}-1.{SERVICE_NAME}"
  # Command output may carry a trailing newline; a node name must not.
  host_name = get_host_name(pod_name).strip()
  raw_ip_addresses = get_ip_addresses(pod_name)
  # Drop blank lines: "".split("\n") would otherwise yield [""], a truthy list
  # holding one empty "address".
  ip_addresses = [
      line.strip() for line in raw_ip_addresses.splitlines() if line.strip()
  ]
  if host_name and ip_addresses:
    hosts[host_name] = ip_addresses
    print(f"Got host information from pod: {pod_name} on host {host_name}")
    print(f"Got ip information from pod: {pod_name} on ip {ip_addresses}")
  else:
    print(
        f"No usable host/ip information from pod: {pod_name} (host"
        f" {host_name!r}, ips {ip_addresses})"
    )
  return hosts
def run_neper_test(
    hosts_to_ips: dict[str, list[str]],
) -> list[Callable[[], None]]:
  """Runs the Neper test in the role implied by this pod's name.

  A pod whose name contains "<JOB_NAME>-0" acts as the master: for every
  destination IP it runs the tcp_stream client, writes the output to a log file
  under /tmp, and then touches /master<N>.done on pod 1 over SSH. After all IPs
  of a host are done it hands that host's log files to process_test_result.

  Any other pod acts as a secondary: for every IP it starts tcp_stream in the
  background without --client and polls every 10 seconds until the matching
  /master<N>.done file appears.

  Args:
    hosts_to_ips: Mapping of remote node name to the IP addresses to test.

  Returns:
    Callables that delete the temporary files created here (log files on the
    master, done-marker files on a secondary). The caller is expected to invoke
    them once the test is over.
  """

  def cleanup_delete_temp_files(regex: str) -> Callable[[], None]:
    # Bind the path now so that each callable removes its own file.
    def delete_file() -> None:
      print(f"Deleting temporary files with path: {regex}")
      checker_common.run_command(f"rm {regex}")

    return delete_file

  cleanup_functions = []
  if f"{JOB_NAME}-0" in POD_NAME:  # master node
    print("I am a master that will run neper test on 2 nodes")
    self_host = checker_common.run_command("cat /host.name").stdout
    for host, ips in hosts_to_ips.items():
      print(f"Host: {host}")
      # Start a fresh list per host: process_test_result numbers interfaces by
      # position, so logs of a previous host must not leak into this one.
      log_files = []
      for count, dst_ip in enumerate(ips, start=1):
        log_file = f"/tmp/{self_host}_{host}_eth{count}.log"
        log_files.append(log_file)
        checker_common.run_command(
            "taskset -c 17-24,73-80 /scripts/tcp_stream -rw --client -H"
            f" '{dst_ip}' --skip-rx-copy --num-threads=16 --num-flows=200"
            f" --suicide-length=600 --test-length=30 > '{log_file}' 2>&1"
        )
        cleanup_functions.append(cleanup_delete_temp_files(log_file))
        # Tell the secondary that the run for this interface has finished.
        checker_common.run_command(
            f"ssh {JOB_NAME}-1.{SERVICE_NAME} -p 222 -- touch"
            f" /master{count}.done",
        )
      process_test_result(log_files, self_host, host)
  else:  # secondary nodes
    for ips in hosts_to_ips.values():
      for count, dst_ip in enumerate(ips, start=1):
        print(f"spinning up neper server for {dst_ip}...")
        checker_common.run_command(
            "taskset -c 17-24,73-80 /scripts/tcp_stream -rw --skip-rx-copy "
            "--num-threads=16 --num-flows=200 --suicide-length=600 "
            "--test-length=30 &"
        )
        # Block until the master drops the marker file for this interface.
        while not os.path.exists(f"/master{count}.done"):
          print(f"test for {dst_ip} not done")
          time.sleep(10)
        print(f"test for {dst_ip} is done")
        cleanup_functions.append(
            cleanup_delete_temp_files(f"/master{count}.done")
        )
  return cleanup_functions
def process_test_result(
    log_files: list[str], local_host: str, remote_host: str
) -> None:
  """Score each log file against GOOD_THROUGHPUT and mark both nodes.

  Log files are taken in order and mapped to eth1, eth2, ... For every
  interface the local and remote throughput are read from the same log; a value
  below the threshold puts a per-interface label carrying the measured
  throughput on that node, otherwise that label is removed. Afterwards each node
  gets its failure taint/label applied or cleared, a run-time label, a logged
  result, and a shared "pass"/"fail" result label ("pass" only when neither
  node failed).

  Args:
    log_files: Paths of the tcp_stream client logs, one per interface.
    local_host: Node name of the host that produced the logs.
    remote_host: Node name of the peer host.
  """
  threshold = int(os.environ["GOOD_THROUGHPUT"])

  def check_interface(role: str, node: str, eth: str, throughput: int) -> bool:
    """Labels or un-labels one interface of a node; True when it failed."""
    label_key = f"{TAINT_KEY}_{eth}"
    if throughput < threshold:
      print(
          f"{role} host {node} failed the neper test at {eth} with"
          f" throughput {throughput}. Adding node taints..."
      )
      checker_common.add_label(
          node, label_key, f"{throughput}", K_ADD_LABEL_FORMAT
      )
      return True
    # Healthy interface: clear any label left over from an earlier run.
    remove_label(node, label_key)
    return False

  local_test_failed = False
  remote_test_failed = False
  local_throughput_by_eth = {}
  remote_throughput_by_eth = {}
  for index, log_file in enumerate(log_files, start=1):
    eth = f"eth{index}"
    # Read both numbers before touching any label, local first.
    local_throughput = get_throughput(log_file, local=True)
    remote_throughput = get_throughput(log_file, local=False)
    local_throughput_by_eth[eth] = local_throughput
    remote_throughput_by_eth[eth] = remote_throughput
    if check_interface("local", local_host, eth, local_throughput):
      local_test_failed = True
    if check_interface("remote", remote_host, eth, remote_throughput):
      remote_test_failed = True

  # A failing node is labelled with the name of the peer it was tested against.
  apply_fail_label(local_test_failed, local_host, remote_host)
  apply_fail_label(remote_test_failed, remote_host, local_host)

  outcomes = (
      (local_host, local_test_failed, local_throughput_by_eth),
      (remote_host, remote_test_failed, remote_throughput_by_eth),
  )
  for node, _, _ in outcomes:
    add_healthcheck_time_label(node)

  # NOTE(review): run_neper_test starts tcp_stream with --client on the host
  # passed here as local_host, yet local_host is recorded as "server" - confirm
  # whether this mapping is intended.
  server_client = {"server": local_host, "client": remote_host}
  for node, failed, throughput_by_eth in outcomes:
    checker_common.log_results(
        test_name="neper",
        passed=not failed,
        node_name=node,
        workflow_id=os.environ.get("WORKFLOW_ID"),
        result_data={
            "throughput_by_eth": throughput_by_eth,
            "server_client": server_client,
        },
    )

  # Both nodes share one overall verdict.
  result = "fail" if local_test_failed or remote_test_failed else "pass"
  for node, _, _ in outcomes:
    checker_common.add_label(
        node,
        _RESULT_LABEL_KEY,
        result,
        K_ADD_LABEL_FORMAT,
    )
def get_throughput(log_file: str, local: bool) -> int:
  """Extract the local or remote throughput figure from a log file.

  Args:
    log_file: Path of the log file to read.
    local: When true, read the "local_throughput=" value; otherwise read the
      "remote_throughput=" value.

  Returns:
    The first matching integer value, or -1 when the log has no such entry.
  """
  with open(log_file, "r") as log:
    contents = log.read()
  if local:
    pattern = r"local_throughput=(\d+)"
  else:
    pattern = r"remote_throughput=(\d+)"
  found = re.search(pattern, contents)
  return int(found.group(1)) if found else -1
def get_ip_addresses(pod_name: str) -> str:
  """Read /tmp/ip_addrs from the given pod over SSH, retrying until it works.

  The command is retried once per second for as long as timeout_check allows.

  Args:
    pod_name (str): The pod (SSH host) whose /tmp/ip_addrs file is read.

  Returns:
    str: The raw stdout of the first successful read.

  Raises:
    TimeoutError: Raised by timeout_check once the retry window is used up.
  """
  fetch_command = f"ssh {pod_name} -p 222 -- cat /tmp/ip_addrs"
  began = time.time()
  while timeout_check(began, pod_name):
    attempt = checker_common.run_command(fetch_command, check=False)
    if attempt.returncode != 0:
      # Pod not reachable or file not there yet; wait and try again.
      time.sleep(1)
      continue
    return attempt.stdout
  return ""
def get_host_name(pod_name: str) -> str:
  """Read /host.name from the given pod over SSH, retrying until it works.

  The command is retried once per second for as long as timeout_check allows.

  Args:
    pod_name (str): The pod (SSH host) whose /host.name file is read.

  Returns:
    str: The raw stdout of the first successful read.

  Raises:
    TimeoutError: Raised by timeout_check once the retry window is used up.
  """
  read_command = f"ssh {pod_name} -p 222 -- cat /host.name"
  first_try_at = time.time()
  while timeout_check(first_try_at, pod_name):
    outcome = checker_common.run_command(read_command, check=False)
    succeeded = outcome.returncode == 0
    if succeeded:
      return outcome.stdout
    # Pod not reachable or file not written yet; pause before retrying.
    time.sleep(1)
  return ""
def timeout_check(start_time: float, pod_name: str) -> bool:
  """Tell a retry loop whether it may keep going, or abort it after 10 minutes.

  Args:
    start_time (float): time.time() value taken when the retries began.
    pod_name (str): Pod being contacted; used only in the error message.

  Returns:
    bool: Always True while fewer than 10 minutes have elapsed.

  Raises:
    TimeoutError: Once 10 minutes or more have passed since start_time.
  """
  waited = time.time() - start_time
  timed_out = waited >= 10 * 60  # 10 minutes
  if not timed_out:
    return True
  raise TimeoutError(
      f"10min Timeout reached while trying to ssh to pod {pod_name}"
  )
def taint_node(node_name: str, key: str, value: str, effect: str) -> None:
  """Taint a node through kubectl, unless DRY_RUN is set to "true".

  Args:
    node_name (str): The name of the node to be tainted.
    key (str): The taint key to be set.
    value (str): The taint value to be set.
    effect (str): The effect of the taint (e.g., "NoExecute", "NoSchedule").
  """
  if os.environ.get("DRY_RUN") == "true":
    # Dry run: leave the node untouched and stay silent.
    return
  print("adding taint %s=%s to node %s" % (key, value, node_name))
  taint_command = K_TAINT_NODE_FORMAT % (node_name, key, value, effect)
  checker_common.run_command(taint_command)
def remove_node_taint(node_name: str, taint_key: str) -> None:
  """Remove the taint with the given key from a node through kubectl.

  Args:
    node_name (str): The name of the node to update.
    taint_key (str): The key of the taint to remove.
  """
  # NOTE(review): unlike taint_node, this does not look at DRY_RUN - confirm
  # that removing taints during a dry run is intended.
  print("removing taint %s from node %s" % (taint_key, node_name))
  untaint_command = K_REMOVE_TAINT_NODE_FORMAT % (node_name, taint_key)
  checker_common.run_command(untaint_command)
def add_healthcheck_time_label(node_name: str) -> None:
  """Stamp a node with the current Unix time (whole seconds) as a label.

  Args:
    node_name (str): The name of the node to label.
  """
  now_seconds = int(time.time())
  checker_common.add_label(
      node_name, HEALTHCHECK_TIME_LABEL_KEY, f"{now_seconds}", K_ADD_LABEL_FORMAT
  )
def apply_fail_label(check_failed: bool, node_name: str, value: str) -> None:
  """Set or clear the failure taint and failure label of a node.

  Args:
    check_failed (bool): Whether the node failed the health check.
    node_name (str): The name of the node to update.
    value (str): Label value stored under TAINT_KEY when the check failed.
  """
  if not check_failed:
    # Healthy node: drop whatever an earlier failing run left behind.
    remove_node_taint(node_name, TAINT_KEY)
    remove_label(node_name, TAINT_KEY)
    return
  taint_node(node_name, TAINT_KEY, "failed", TAINT_EFFECT)
  checker_common.add_label(node_name, TAINT_KEY, value, K_ADD_LABEL_FORMAT)
def remove_label(node_name: str, label: str) -> None:
  """Remove a label from a node through kubectl.

  Args:
    node_name (str): The name of the node to update.
    label (str): The key of the label to remove.
  """
  print("removing label %s from node %s" % (label, node_name))
  unlabel_command = K_REMOVE_LABEL_FORMAT % (node_name, label)
  checker_common.run_command(unlabel_command)
def main() -> None:
  """Validate the environment, set up SSH, run the test, then clean up."""
  ensure_env_variables()
  configure_ssh()
  # Publish this pod's node name in /host.name, the file that get_host_name
  # reads from the peer pod over SSH.
  this_node = os.environ["NODE_NAME"]
  with open("/host.name", "w") as host_file:
    host_file.write(this_node)
  pending_cleanups = run_neper_test(get_host_to_ips())
  print("my job is done, running cleanups...")
  for run_cleanup in pending_cleanups:
    run_cleanup()
  print("cleanups are done... exiting...")
# Run the health check only when this file is executed as a script.
if __name__ == "__main__":
  main()