#!/usr/bin/env python3

# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for license information.
# --------------------------------------------------------------------------

"""Azure VM utilities self-tests script."""
import argparse
import glob
import json
import logging
import os
import re
import subprocess
import time
import urllib.error
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Literal, Optional

logger = logging.getLogger("selftest")

# pylint: disable=broad-exception-caught
# pylint: disable=line-too-long
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-lines
# pylint: disable=too-many-locals

# sysfs directory listing the network interfaces known to the kernel.
SYS_CLASS_NET = "/sys/class/net"


@dataclass
class SkuConfig:
    """Expected disk-related characteristics of one Azure VM size.

    The defaults describe an x64 size without NVMe features, without local
    NVMe disks and without a SCSI temp disk.
    """

    vm_size: str
    vm_size_type: Literal["arm64", "x64"] = "x64"
    # Whether the size supports toggling the disk controller to NVMe (Eb[d]s_v5).
    nvme_controller_toggle_supported: bool = False
    # NVMe-only sizes (v6+).
    nvme_only: bool = False
    # Whether the size supports NVMe ID on local disks.
    nvme_id_enabled_local: bool = False
    # Whether the size supports NVMe ID on remote disks.
    nvme_id_enabled_remote: bool = False
    nvme_local_disk_count: int = 0
    nvme_local_disk_size_gib: int = 0
    # SCSI temp/resource disk size in GiB (0 means no temp disk).
    temp_disk_size_gib: int = 0


@dataclass
class V6SkuConfig(SkuConfig):
    """SkuConfig preset for v6 VM sizes: NVMe-only, with NVMe ID on local disks.

    Only defaults are overridden here; field order is inherited from SkuConfig.
    """

    nvme_only: bool = True
    nvme_id_enabled_local: bool = True
    nvme_id_enabled_remote: bool = False


def gb_to_gib(size_gb: int) -> int:
    """Convert a decimal-GB size to binary GiB, truncating the fraction.

    Sizes are documented in both units, so GB figures are converted here
    to compare against GiB measurements.
    """
    bytes_total = size_gb * 1000**3
    return int(bytes_total / 1024**3)


# Expected disk layout per Azure VM size. Each key equals its config's
# vm_size field. Local NVMe sizes documented in decimal GB are converted
# with gb_to_gib(); other sizes are given directly in GiB.
SKU_CONFIGS = {
    "Standard_B2ts_v2": SkuConfig(vm_size="Standard_B2ts_v2"),
    # NOTE(review): disabled entry; the reason is not recorded here.
    # "Standard_D2s_v3": SkuConfig(vm_size="Standard_D2s_v3", temp_disk_size_gib=16),
    "Standard_D2s_v4": SkuConfig(vm_size="Standard_D2s_v4"),
    "Standard_D2ds_v4": SkuConfig(vm_size="Standard_D2ds_v4", temp_disk_size_gib=75),
    "Standard_D2s_v5": SkuConfig(vm_size="Standard_D2s_v5"),
    "Standard_D2ds_v5": SkuConfig(vm_size="Standard_D2ds_v5", temp_disk_size_gib=75),
    "Standard_D2ads_v5": SkuConfig(vm_size="Standard_D2ads_v5", temp_disk_size_gib=75),
    "Standard_D16ads_v5": SkuConfig(
        vm_size="Standard_D16ads_v5", temp_disk_size_gib=600
    ),
    "Standard_L8s_v2": SkuConfig(
        vm_size="Standard_L8s_v2",
        temp_disk_size_gib=80,
        nvme_local_disk_count=1,
        nvme_local_disk_size_gib=gb_to_gib(1920),
    ),
    "Standard_L8s_v3": SkuConfig(
        vm_size="Standard_L8s_v3",
        temp_disk_size_gib=80,
        nvme_local_disk_count=1,
        nvme_local_disk_size_gib=gb_to_gib(1920),
    ),
    "Standard_L80s_v3": SkuConfig(
        vm_size="Standard_L80s_v3",
        nvme_controller_toggle_supported=True,
        temp_disk_size_gib=800,
        nvme_local_disk_count=10,
        nvme_local_disk_size_gib=gb_to_gib(1920),
    ),
    "Standard_E2bs_v5": SkuConfig(
        vm_size="Standard_E2bs_v5", nvme_controller_toggle_supported=True
    ),
    "Standard_E2bds_v5": SkuConfig(
        vm_size="Standard_E2bds_v5",
        nvme_controller_toggle_supported=True,
        temp_disk_size_gib=75,
    ),
    # x64 v6 sizes: V6SkuConfig sets nvme_only and nvme_id_enabled_local.
    "Standard_D2s_v6": V6SkuConfig(vm_size="Standard_D2s_v6"),
    "Standard_D2ds_v6": V6SkuConfig(
        vm_size="Standard_D2ds_v6",
        nvme_local_disk_count=1,
        nvme_local_disk_size_gib=110,
    ),
    "Standard_D16ds_v6": V6SkuConfig(
        vm_size="Standard_D16ds_v6",
        nvme_local_disk_count=2,
        nvme_local_disk_size_gib=440,
    ),
    "Standard_D32ds_v6": V6SkuConfig(
        vm_size="Standard_D32ds_v6",
        nvme_local_disk_count=4,
        nvme_local_disk_size_gib=440,
    ),
    "Standard_D2as_v6": V6SkuConfig(vm_size="Standard_D2as_v6"),
    "Standard_D2ads_v6": V6SkuConfig(
        vm_size="Standard_D2ads_v6",
        nvme_local_disk_count=1,
        nvme_local_disk_size_gib=110,
    ),
    "Standard_D16ads_v6": V6SkuConfig(
        vm_size="Standard_D16ads_v6",
        nvme_local_disk_count=2,
        nvme_local_disk_size_gib=440,
    ),
    "Standard_D32ads_v6": V6SkuConfig(
        vm_size="Standard_D32ads_v6",
        nvme_local_disk_count=4,
        nvme_local_disk_size_gib=440,
    ),
    # arm64 sizes.
    "Standard_D2pls_v5": SkuConfig(
        vm_size="Standard_D2pls_v5",
        vm_size_type="arm64",
    ),
    "Standard_D2plds_v5": SkuConfig(
        vm_size="Standard_D2plds_v5",
        vm_size_type="arm64",
        temp_disk_size_gib=75,
    ),
    "Standard_D8pls_v5": SkuConfig(
        vm_size="Standard_D8pls_v5",
        vm_size_type="arm64",
    ),
    "Standard_D8plds_v5": SkuConfig(
        vm_size="Standard_D8plds_v5",
        vm_size_type="arm64",
        temp_disk_size_gib=300,
    ),
    # NOTE(review): the arm64 *_v6 entries below use SkuConfig, not
    # V6SkuConfig, so nvme_only and nvme_id_enabled_local stay False for
    # them, unlike the x64 v6 entries above. Confirm this is intended.
    "Standard_D2pls_v6": SkuConfig(
        vm_size="Standard_D2pls_v6",
        vm_size_type="arm64",
    ),
    "Standard_D2plds_v6": SkuConfig(
        vm_size="Standard_D2plds_v6",
        vm_size_type="arm64",
        nvme_local_disk_count=1,
        nvme_local_disk_size_gib=110,
    ),
    "Standard_D16pls_v6": SkuConfig(
        vm_size="Standard_D16pls_v6",
        vm_size_type="arm64",
    ),
    "Standard_D16plds_v6": SkuConfig(
        vm_size="Standard_D16plds_v6",
        vm_size_type="arm64",
        nvme_local_disk_count=2,
        nvme_local_disk_size_gib=440,
    ),
}


def device_sort(devices: List[str]) -> List[str]:
    """Return a new list of *devices* in natural order ("nvme0n2" before "nvme0n10")."""

    def chunk_key(name: str):
        # Splitting on a capturing group alternates non-digit and digit runs,
        # e.g. "nvme0n10" -> ["nvme", "0", "n", "10", ""]; digit runs become
        # ints so they compare numerically rather than lexically.
        chunks = re.split("([0-9]+)", name)
        return [int(chunk) if chunk.isdigit() else chunk for chunk in chunks]

    ordered = list(devices)
    ordered.sort(key=chunk_key)
    return ordered


def get_disk_size_gb(disk_path: str) -> int:
    """Return the size of *disk_path* in decimal GB (10**9 bytes), rounded down.

    The byte count comes from ``lsblk -b -n -o SIZE -d``.

    Raises:
        subprocess.CalledProcessError: if lsblk exits non-zero.
        FileNotFoundError: if lsblk is not installed.
    """
    command = ["lsblk", "-b", "-n", "-o", "SIZE", "-d", disk_path]
    try:
        proc = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as error:
        logger.error("error while fetching disk size: %r", error)
        raise
    except FileNotFoundError:
        logger.error("lsblk command not found")
        raise

    logger.debug("lsblk output: %r", proc)
    size_bytes = int(proc.stdout.strip())
    size_gb = size_bytes // 1000**3
    return size_gb


def get_disk_size_gib(disk_path: str) -> int:
    """Return the size of *disk_path* in binary GiB (2**30 bytes), rounded down.

    The byte count comes from ``lsblk -b -n -o SIZE -d``.

    Raises:
        subprocess.CalledProcessError: if lsblk exits non-zero.
        FileNotFoundError: if lsblk is not installed.
    """
    lsblk_args = ["lsblk", "-b", "-n", "-o", "SIZE", "-d", disk_path]
    try:
        completed = subprocess.run(
            lsblk_args,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except FileNotFoundError:
        logger.error("lsblk command not found")
        raise
    except subprocess.CalledProcessError as error:
        logger.error("error while fetching disk size: %r", error)
        raise

    logger.debug("lsblk output: %r", completed)
    return int(completed.stdout.strip()) // 1024**3


def get_imds_metadata() -> Dict:
    """Fetch the instance metadata document from Azure IMDS.

    Retries once per second, for up to 300 seconds, on network/HTTP errors
    and on responses whose body is not valid JSON.

    Returns:
        The decoded JSON metadata.

    Raises:
        RuntimeError: if no valid response was obtained before the deadline.
    """
    url = "http://169.254.169.254/metadata/instance?api-version=2021-02-01"
    headers = {"Metadata": "true"}

    req = urllib.request.Request(url, headers=headers)

    last_error: Optional[Exception] = None
    deadline = time.time() + 300
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(req, timeout=60) as response:
                if response.status != 200:
                    raise urllib.error.HTTPError(
                        url,
                        response.status,
                        "Failed to fetch metadata",
                        response.headers,
                        None,
                    )
                metadata = json.load(response)
                logger.debug("fetched IMDS metadata: %r", metadata)
                return metadata
        # URLError/HTTPError are OSError subclasses; catching OSError also
        # retries socket timeouts and connection resets raised while reading
        # the body, and ValueError covers a truncated or invalid JSON body.
        except (OSError, ValueError) as error:
            last_error = error
            logger.error("error fetching IMDS metadata: %r", error)
            time.sleep(1)

    raise RuntimeError(f"failed to fetch IMDS metadata: {last_error}")


def get_local_nvme_disks() -> List[str]:
    """Return namespace devices of all local NVMe controllers, naturally sorted.

    Both "Microsoft NVMe Direct Disk" and "Microsoft NVMe Direct Disk v2"
    controllers are included.
    """
    local_models = ("Microsoft NVMe Direct Disk", "Microsoft NVMe Direct Disk v2")

    controllers: List[str] = []
    for local_model in local_models:
        controllers.extend(get_nvme_controllers_with_model(local_model))

    namespaces: List[str] = []
    for controller in controllers:
        namespaces.extend(get_nvme_namespace_devices(controller))

    return device_sort(namespaces)


def get_remote_nvme_disks() -> List[str]:
    """Return namespace devices of the remote-disk NVMe controller, naturally sorted.

    Remote disks are those behind the "MSFT NVMe Accelerator v1.0" controller.

    Raises:
        AssertionError: if more than one such controller is present.
    """
    controllers = get_nvme_controllers_with_model("MSFT NVMe Accelerator v1.0")
    assert (
        len(controllers) <= 1
    ), f"unexpected number of remote controllers {controllers}"

    namespaces: List[str] = []
    for controller in controllers:
        namespaces += get_nvme_namespace_devices(controller)
    return device_sort(namespaces)


def get_nvme_controllers_with_model(model: str) -> List[str]:
    """Return names (e.g. "nvme0") of NVMe controllers whose sysfs model is *model*.

    Each /sys/class/nvme/nvme*/model file is read and stripped before
    comparison. Controllers without a model file are skipped.
    """
    matches: List[str] = []
    for controller_path in glob.glob(os.path.join("/sys/class/nvme", "nvme*")):
        logger.debug("checking controller: %s", controller_path)
        model_path = os.path.join(controller_path, "model")
        try:
            with open(model_path, "r", encoding="utf-8") as model_file:
                controller_model = model_file.read().strip()
        except FileNotFoundError:
            logger.debug("model file not found: %s", model_path)
            continue

        logger.debug("controller: %s model: %s", controller_path, controller_model)
        if controller_model == model:
            matches.append(controller_path.split("/")[-1])

    return device_sort(matches)


def get_nvme_namespace_devices_with_model(model: str) -> List[str]:
    """Return namespace devices under every controller of *model*, naturally sorted."""
    controllers = get_nvme_controllers_with_model(model)
    logger.debug("controllers found for model=%s: %r", model, controllers)

    namespaces: List[str] = []
    for controller in controllers:
        namespaces.extend(get_nvme_namespace_devices(controller))
    return device_sort(namespaces)


def get_nvme_namespace_devices(controller: str) -> List[str]:
    """Return namespace device names found under an NVMe controller in sysfs.

    *controller* may be a bare name ("nvme0") or a path; only its last path
    component is used. Entries under /sys/class/nvme/<controller> that match
    ``nvme*`` and are directories are returned, naturally sorted.
    """
    controller_dir = os.path.join("/sys/class/nvme", controller.split("/")[-1])
    logger.debug("checking namespaces under %s", controller_dir)

    found: List[str] = []
    for entry in glob.glob(os.path.join(controller_dir, "nvme*")):
        logger.debug("checking namespace device: %s", entry)
        if not os.path.isdir(entry):
            continue
        found.append(entry.split("/")[-1])

    return device_sort(found)


def get_root_block_device() -> str:
    """Return the SOURCE of the "/" mount as printed by ``findmnt`` (stripped).

    Raises:
        subprocess.CalledProcessError: if findmnt exits non-zero.
        FileNotFoundError: if findmnt is not installed.
    """
    try:
        result = subprocess.run(
            ["findmnt", "-n", "-o", "SOURCE", "/"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as error:
        logger.error("error while fetching root block device: %r", error)
        raise
    except FileNotFoundError:
        logger.error("findmnt command not found")
        raise

    logger.debug("findmnt output: %r", result)
    return result.stdout.strip()


def get_scsi_resource_disk() -> Optional[str]:
    """Return the device name of the SCSI resource disk, or None if not found.

    Candidates are checked in order: the cloud-init udev link, then the gen2
    and gen1 VMBus by-path links (the gen1 entry is a glob pattern). The
    first candidate whose resolved target exists wins; its basename is returned.
    """
    candidates = [
        # cloud-init udev rules
        "/dev/disk/cloud/azure_resource",
        # gen2
        "/dev/disk/by-path/acpi-VMBUS:00-vmbus-f8b3781a1e824818a1c363d806ec15bb-lun-1",
        # gen1
        "/dev/disk/by-path/acpi-VMBUS:01-vmbus-000000000001*-lun-0",
    ]

    for candidate in candidates:
        if "*" in candidate:
            hits = glob.glob(candidate)
        elif os.path.exists(candidate):
            hits = [candidate]
        else:
            hits = []

        for hit in hits:
            target = os.path.realpath(hit)
            if os.path.exists(target):
                return target.split("/")[-1]

    logger.info("no SCSI resource disk found")
    return None


# Symlink checked for the resource (temp) disk; resolved when present.
DEV_DISK_AZURE_RESOURCE = "/dev/disk/azure/resource"


@dataclass(eq=True, repr=True)
class DiskInfo:
    """Snapshot of the disks visible on this VM, grouped by kind.

    Disk fields hold bare device names (last path component, e.g. "nvme0n1"),
    except the /dev/disk/azure fields, which hold full paths.
    """

    root_device: str
    # Every symlink found (recursively) under /dev/disk/azure, naturally sorted.
    dev_disk_azure_links: List[str] = field(default_factory=list)
    dev_disk_azure_resource_disk: Optional[str] = None  # resolved path
    dev_disk_azure_resource_disk_size_gib: int = 0
    # Size of each local NVMe disk (all are asserted to be equal in gather()).
    nvme_local_disk_size_gib: int = 0
    nvme_local_disks_v1: List[str] = field(default_factory=list)
    nvme_local_disks_v2: List[str] = field(default_factory=list)
    nvme_local_disks: List[str] = field(default_factory=list)  # v1 + v2
    nvme_remote_data_disks: List[str] = field(default_factory=list)
    # Every remote NVMe disk: the OS disk first, then the data disks.
    nvme_remote_disks: List[str] = field(default_factory=list)
    nvme_remote_os_disk: Optional[str] = None
    root_device_is_nvme: bool = False
    scsi_resource_disk: Optional[str] = None
    scsi_resource_disk_size_gib: int = 0

    @staticmethod
    def _split_remote_disks(remote_disks: List[str]) -> tuple:
        """Split sorted remote NVMe disks into ``(os_disk, data_disks)``.

        The first remote namespace is taken as the OS disk. The input list is
        not modified, and the returned data-disk list is a new list.
        """
        if not remote_disks:
            return None, []
        return remote_disks[0], remote_disks[1:]

    @classmethod
    def gather(cls) -> "DiskInfo":
        """Probe /dev, sysfs, lsblk and findmnt and return a populated DiskInfo.

        Raises:
            AssertionError: if local NVMe disks differ in size, or if more than
                one remote NVMe controller is found.
        """
        dev_disk_azure_links = device_sort(
            [
                link
                for link in glob.glob(
                    os.path.join("/dev/disk/azure", "**"), recursive=True
                )
                if os.path.islink(link)
            ]
        )

        dev_disk_azure_resource_disk = None
        dev_disk_azure_resource_disk_size_gib = 0
        if os.path.exists(DEV_DISK_AZURE_RESOURCE):
            dev_disk_azure_resource_disk = os.path.realpath(DEV_DISK_AZURE_RESOURCE)
            dev_disk_azure_resource_disk_size_gib = get_disk_size_gib(
                dev_disk_azure_resource_disk
            )

        nvme_local_disks_v1 = get_nvme_namespace_devices_with_model(
            "Microsoft NVMe Direct Disk"
        )
        nvme_local_disks_v2 = get_nvme_namespace_devices_with_model(
            "Microsoft NVMe Direct Disk v2"
        )
        nvme_local_disks = device_sort(nvme_local_disks_v1 + nvme_local_disks_v2)
        nvme_local_disk_size_gib = 0
        if nvme_local_disks:
            # Query each disk once (one lsblk call per disk) and reuse the sizes.
            local_disk_sizes = [
                get_disk_size_gib(f"/dev/{disk}") for disk in nvme_local_disks
            ]
            nvme_local_disk_size_gib = min(local_disk_sizes)
            local_disk_max_size = max(local_disk_sizes)
            assert (
                nvme_local_disk_size_gib == local_disk_max_size
            ), f"local disk size mismatch: {nvme_local_disk_size_gib} != {local_disk_max_size} for {nvme_local_disks}"

        # Slice rather than pop() so nvme_remote_disks still lists every remote
        # disk (OS disk included) and is not shared with nvme_remote_data_disks.
        nvme_remote_disks = get_remote_nvme_disks()
        nvme_remote_os_disk, nvme_remote_data_disks = cls._split_remote_disks(
            nvme_remote_disks
        )

        root_device = get_root_block_device()
        root_device_is_nvme = root_device.startswith("/dev/nvme")
        root_device = root_device.split("/")[-1]

        scsi_resource_disk = get_scsi_resource_disk()
        scsi_resource_disk_size_gib = (
            get_disk_size_gib(f"/dev/{scsi_resource_disk}") if scsi_resource_disk else 0
        )

        disk_info = cls(
            dev_disk_azure_links=dev_disk_azure_links,
            dev_disk_azure_resource_disk=dev_disk_azure_resource_disk,
            dev_disk_azure_resource_disk_size_gib=dev_disk_azure_resource_disk_size_gib,
            nvme_local_disk_size_gib=nvme_local_disk_size_gib,
            nvme_local_disks_v1=nvme_local_disks_v1,
            nvme_local_disks_v2=nvme_local_disks_v2,
            nvme_local_disks=nvme_local_disks,
            nvme_remote_os_disk=nvme_remote_os_disk,
            nvme_remote_data_disks=nvme_remote_data_disks,
            nvme_remote_disks=nvme_remote_disks,
            root_device=root_device,
            root_device_is_nvme=root_device_is_nvme,
            scsi_resource_disk=scsi_resource_disk,
            scsi_resource_disk_size_gib=scsi_resource_disk_size_gib,
        )

        logger.info("disks info: %r", disk_info)
        return disk_info


@dataclass
class AzureNvmeIdDevice:
    """One device entry reported by azure-nvme-id."""

    device: str  # device path, e.g. "/dev/nvme0n1"
    model: Optional[str]  # controller model; only set when parsed from --format json
    nvme_id: str  # identifiers as "key=value,..." (empty when none reported)
    type: Optional[str]  # "type" property, e.g. "os", "data" or "local"
    index: Optional[int]  # "index" property
    lun: Optional[int]  # "lun" property
    name: Optional[str]  # "name" property
    extra: Dict[str, str]  # properties not consumed by the fields above


@dataclass(repr=True)
class AzureNvmeIdInfo:
    """Captured results of azure-nvme-id invocations, plus their validators.

    gather() runs ``azure-nvme-id`` with no arguments, with ``--format json``,
    ``--help``, ``--version`` and with an invalid argument (``zzz``), recording
    stdout, stderr and return code of each. validate() checks all of them.
    """

    # azure-nvme-id (plain text); disks are keyed by device name, e.g. "nvme0n1".
    azure_nvme_id_stdout: str
    azure_nvme_id_stderr: str
    azure_nvme_id_returncode: int
    azure_nvme_id_disks: Dict[str, AzureNvmeIdDevice]

    # azure-nvme-id --format json
    azure_nvme_id_json_stdout: str
    azure_nvme_id_json_stderr: str
    azure_nvme_id_json_returncode: int
    azure_nvme_id_json_disks: Dict[str, AzureNvmeIdDevice]

    # azure-nvme-id --help
    azure_nvme_id_help_stdout: str
    azure_nvme_id_help_stderr: str
    azure_nvme_id_help_returncode: int

    # azure-nvme-id --version
    azure_nvme_id_version_stdout: str
    azure_nvme_id_version_stderr: str
    azure_nvme_id_version_returncode: int
    azure_nvme_id_version: str  # second space-separated token of stdout

    # azure-nvme-id zzz (invalid argument)
    azure_nvme_id_zzz_stdout: str
    azure_nvme_id_zzz_stderr: str
    azure_nvme_id_zzz_returncode: int

    def _validate_azure_nvme_disks(
        self, azure_nvme_id_disks: Dict[str, AzureNvmeIdDevice], disk_info: "DiskInfo"
    ) -> None:
        """Check parsed azure-nvme-id devices against the disks in *disk_info*.

        Every local (v1 and v2) and remote NVMe disk in *disk_info* must be
        reported, with the identifiers expected for its kind.

        Raises:
            AssertionError: on a missing device or unexpected identifiers.
        """
        for device_name, disk_cfg in azure_nvme_id_disks.items():
            assert (
                f"/dev/{device_name}" == disk_cfg.device
            ), f"device key mismatch: {device_name} {disk_cfg}"
            assert disk_cfg.device.startswith(
                "/dev/nvme"
            ), f"unexpected device: {disk_cfg}"

        # v2 local disks carry type, index and name.
        for device_name in disk_info.nvme_local_disks_v2:
            assert (
                device_name in azure_nvme_id_disks
            ), f"missing azure-nvme-id for {device_name}"
            disk_cfg = azure_nvme_id_disks[device_name]
            assert disk_cfg.type == "local", f"unexpected local disk type {disk_cfg}"
            assert disk_cfg.name, f"unexpected local disk name {disk_cfg}"
            assert disk_cfg.index, f"unexpected local disk index {disk_cfg}"
            assert disk_cfg.lun is None, f"unexpected local disk lun {disk_cfg}"
            assert disk_cfg.nvme_id, f"unexpected local disk id {disk_cfg}"
            assert not disk_cfg.extra, f"unexpected local disk extra {disk_cfg}"

        # v1 local disks carry only a type.
        for device_name in disk_info.nvme_local_disks_v1:
            assert (
                device_name in azure_nvme_id_disks
            ), f"missing azure-nvme-id for {device_name}"
            disk_cfg = azure_nvme_id_disks[device_name]
            assert disk_cfg.type == "local", f"unexpected disk type {disk_cfg}"
            assert not disk_cfg.name, f"unexpected disk name {disk_cfg}"
            assert not disk_cfg.index, f"unexpected disk index {disk_cfg}"
            assert disk_cfg.lun is None, f"unexpected local disk lun {disk_cfg}"
            assert disk_cfg.nvme_id, f"unexpected disk id {disk_cfg}"
            assert not disk_cfg.extra, f"unexpected disk extra {disk_cfg}"

        # Remote disks are "os" (no lun) or "data" (lun >= 0).
        for device_name in disk_info.nvme_remote_disks:
            assert (
                device_name in azure_nvme_id_disks
            ), f"missing azure-nvme-id for {device_name}"
            disk_cfg = azure_nvme_id_disks[device_name]
            assert disk_cfg.type in (
                "os",
                "data",
            ), f"unexpected remote disk type {disk_cfg}"
            if disk_cfg.type == "data":
                assert (
                    disk_cfg.lun is not None and disk_cfg.lun >= 0
                ), f"unexpected remote disk lun {disk_cfg}"
            else:
                assert disk_cfg.lun is None, f"unexpected remote disk lun {disk_cfg}"
            assert not disk_cfg.name, f"unexpected remote disk name {disk_cfg}"
            assert disk_cfg.nvme_id, f"unexpected remote disk id {disk_cfg}"
            assert not disk_cfg.extra, f"unexpected remote disk extra {disk_cfg}"

        logger.info("validate_azure_nvme_disks OK: %r", azure_nvme_id_disks)

    def validate_azure_nvme_id(self, disk_info: "DiskInfo") -> None:
        """Validate plain-text azure-nvme-id outputs against *disk_info*."""
        assert self.azure_nvme_id_returncode == 0, "azure-nvme-id failed"
        if not os.path.exists("/sys/class/nvme"):
            assert (
                self.azure_nvme_id_stderr
                == "no NVMe devices in /sys/class/nvme: No such file or directory\n"
            ), f"unexpected azure-nvme-id stderr without /sys/class/nvme: {self.azure_nvme_id_stderr}"
        else:
            assert (
                self.azure_nvme_id_stderr == ""
            ), f"unexpected azure-nvme-id stderr: {self.azure_nvme_id_stderr}"

        self._validate_azure_nvme_disks(self.azure_nvme_id_disks, disk_info)
        logger.info("validate_azure_nvme_id OK: %r", self.azure_nvme_id_stdout)

    def validate_azure_nvme_id_help(self) -> None:
        """Validate azure-nvme-id --help outputs."""
        assert self.azure_nvme_id_help_returncode == 0, "azure-nvme-id --help failed"
        assert (
            self.azure_nvme_id_help_stderr == ""
        ), f"unexpected azure-nvme-id --help stderr: {self.azure_nvme_id_help_stderr!r}"
        assert (
            self.azure_nvme_id_help_stdout
            and self.azure_nvme_id_help_stdout.startswith("Usage: azure-nvme-id ")
        ), f"unexpected azure-nvme-id --help stdout: {self.azure_nvme_id_help_stdout!r}"

        logger.info(
            "validate_azure_nvme_id_help OK: %r", self.azure_nvme_id_help_stdout
        )

    def validate_azure_nvme_id_json(self, disk_info: "DiskInfo") -> None:
        """Validate azure-nvme-id --format json outputs against *disk_info*."""
        assert self.azure_nvme_id_json_returncode == 0, "azure-nvme-id failed"
        if not os.path.exists("/sys/class/nvme"):
            assert (
                self.azure_nvme_id_json_stderr
                == "no NVMe devices in /sys/class/nvme: No such file or directory\n"
            ), f"unexpected azure-nvme-id stderr without /sys/class/nvme: {self.azure_nvme_id_json_stderr}"
        else:
            assert (
                self.azure_nvme_id_json_stderr == ""
            ), f"unexpected azure-nvme-id stderr: {self.azure_nvme_id_json_stderr}"

        # Validate the JSON-parsed disks (not the plain-text ones).
        self._validate_azure_nvme_disks(self.azure_nvme_id_json_disks, disk_info)

        assert all(
            disk.model
            in (
                "MSFT NVMe Accelerator v1.0",
                "Microsoft NVMe Direct Disk",
                "Microsoft NVMe Direct Disk v2",
            )
            for disk in self.azure_nvme_id_json_disks.values()
        ), "missing model in azure-nvme-id --format json"
        logger.info(
            "validate_azure_nvme_id_json OK: %r", self.azure_nvme_id_json_stdout
        )

    def validate_azure_nvme_id_version(self) -> None:
        """Validate azure-nvme-id --version outputs."""
        assert (
            self.azure_nvme_id_version_returncode == 0
        ), "azure-nvme-id --version failed"
        assert (
            self.azure_nvme_id_version_stderr == ""
        ), f"unexpected azure-nvme-id stderr: {self.azure_nvme_id_version_stderr}"
        assert self.azure_nvme_id_version_stdout, "missing azure-nvme-id version stdout"
        # NOTE: re.match anchors at the start and `\.*` matches zero or more
        # dots, so these effectively check the version starts with "0" or "v".
        assert re.match(
            r"azure-nvme-id [0v]\.*", self.azure_nvme_id_version_stdout.strip()
        ), f"unexpected azure-nvme-id version stdout: {self.azure_nvme_id_version_stdout}"
        assert re.match(
            r"[0v]\.*", self.azure_nvme_id_version
        ), f"unexpected azure-nvme-id version: {self.azure_nvme_id_version}"

        logger.info("validate_azure_nvme_id_version OK: %s", self.azure_nvme_id_version)

    def validate_azure_nvme_id_zzz_invalid_arg(self) -> None:
        """Validate azure-nvme-id handles invalid arguments."""
        assert (
            self.azure_nvme_id_zzz_returncode == 1
        ), f"azure-nvme-id zzz rc={self.azure_nvme_id_zzz_returncode}"
        assert (
            self.azure_nvme_id_zzz_stderr == "invalid argument: zzz\n"
        ), f"unexpected azure-nvme-id zzz stderr: {self.azure_nvme_id_zzz_stderr!r}"
        assert (
            self.azure_nvme_id_zzz_stdout
            and self.azure_nvme_id_zzz_stdout.startswith("Usage: azure-nvme-id ")
        ), (f"unexpected azure-nvme-id zzz stdout: {self.azure_nvme_id_zzz_stdout!r}")

        logger.info(
            "validate_azure_nvme_id_invalid_arg OK: %r", self.azure_nvme_id_zzz_stdout
        )

    def validate(self, disk_info: "DiskInfo") -> None:
        """Run every azure-nvme-id validation."""
        self.validate_azure_nvme_id_help()
        self.validate_azure_nvme_id_version()
        self.validate_azure_nvme_id_zzz_invalid_arg()
        self.validate_azure_nvme_id(disk_info)
        self.validate_azure_nvme_id_json(disk_info)

    @classmethod
    def gather(cls) -> "AzureNvmeIdInfo":
        """Run azure-nvme-id in each mode and capture its outputs.

        Raises:
            FileNotFoundError: if azure-nvme-id is not installed.
            AssertionError: if --version output is not two space-separated tokens.
        """
        proc = subprocess.run(["azure-nvme-id"], capture_output=True, check=False)
        azure_nvme_id_stdout = proc.stdout.decode("utf-8")
        azure_nvme_id_stderr = proc.stderr.decode("utf-8")
        azure_nvme_id_returncode = proc.returncode
        azure_nvme_id_disks = cls.parse_azure_nvme_id_output(azure_nvme_id_stdout)

        proc = subprocess.run(
            ["azure-nvme-id", "--format", "json"], capture_output=True, check=False
        )
        azure_nvme_id_json_stdout = proc.stdout.decode("utf-8")
        azure_nvme_id_json_stderr = proc.stderr.decode("utf-8")
        azure_nvme_id_json_returncode = proc.returncode
        azure_nvme_id_json_disks = cls.parse_azure_nvme_id_json_output(
            azure_nvme_id_json_stdout
        )

        proc = subprocess.run(
            ["azure-nvme-id", "--help"], capture_output=True, check=False
        )
        azure_nvme_id_help_stdout = proc.stdout.decode("utf-8")
        azure_nvme_id_help_stderr = proc.stderr.decode("utf-8")
        azure_nvme_id_help_returncode = proc.returncode

        proc = subprocess.run(
            ["azure-nvme-id", "--version"], capture_output=True, check=False
        )
        azure_nvme_id_version_stdout = proc.stdout.decode("utf-8")
        azure_nvme_id_version_stderr = proc.stderr.decode("utf-8")
        azure_nvme_id_version_returncode = proc.returncode
        azure_nvme_id_version = cls.parse_azure_nvme_id_version(
            azure_nvme_id_version_stdout
        )

        proc = subprocess.run(
            ["azure-nvme-id", "zzz"], capture_output=True, check=False
        )
        azure_nvme_id_zzz_stdout = proc.stdout.decode("utf-8")
        azure_nvme_id_zzz_stderr = proc.stderr.decode("utf-8")
        azure_nvme_id_zzz_returncode = proc.returncode

        azure_nvme_id_info = cls(
            azure_nvme_id_stdout=azure_nvme_id_stdout,
            azure_nvme_id_stderr=azure_nvme_id_stderr,
            azure_nvme_id_returncode=azure_nvme_id_returncode,
            azure_nvme_id_help_stdout=azure_nvme_id_help_stdout,
            azure_nvme_id_help_stderr=azure_nvme_id_help_stderr,
            azure_nvme_id_help_returncode=azure_nvme_id_help_returncode,
            azure_nvme_id_disks=azure_nvme_id_disks,
            azure_nvme_id_json_stdout=azure_nvme_id_json_stdout,
            azure_nvme_id_json_stderr=azure_nvme_id_json_stderr,
            azure_nvme_id_json_returncode=azure_nvme_id_json_returncode,
            azure_nvme_id_json_disks=azure_nvme_id_json_disks,
            azure_nvme_id_version_stdout=azure_nvme_id_version_stdout,
            azure_nvme_id_version_stderr=azure_nvme_id_version_stderr,
            azure_nvme_id_version_returncode=azure_nvme_id_version_returncode,
            azure_nvme_id_version=azure_nvme_id_version,
            azure_nvme_id_zzz_returncode=azure_nvme_id_zzz_returncode,
            azure_nvme_id_zzz_stdout=azure_nvme_id_zzz_stdout,
            azure_nvme_id_zzz_stderr=azure_nvme_id_zzz_stderr,
        )
        logger.info("azure-nvme-id info: %r", azure_nvme_id_info)
        return azure_nvme_id_info

    @staticmethod
    def parse_azure_nvme_id_json_output(output: str) -> Dict[str, AzureNvmeIdDevice]:
        """Parse azure-nvme-id --format json output into devices keyed by name.

        ``nvme_id`` is rebuilt as "key=value,..." from *all* properties, which
        matches the identifier format of the plain-text output.

        Example output:
        [
            {
                "path": "/dev/nvme0n33",
                "model": "MSFT NVMe Accelerator v1.0",
                "properties": {
                    "type": "data",
                    "lun": 31
                },
                "vs": ""
            },
            {
                "path": "/dev/nvme1n1",
                "model": "Microsoft NVMe Direct Disk v2",
                "properties": {
                    "type": "local",
                    "index": 1,
                    "name": "nvme-440G-1"
                },
                "vs": "type=local,index=1,name=nvme-440G-1"
            }
        ]
        """
        devices = {}

        for device in json.loads(output):
            device_path = device["path"]
            # A missing model is left as None so validation reports it.
            model = device.get("model")
            # Copy so the pops below do not mutate the parsed JSON.
            properties = dict(device.get("properties") or {})
            # Build the identifier before popping known keys; otherwise it would
            # be empty for every disk whose properties are all recognized.
            nvme_id = ",".join(f"{k}={v}" for k, v in properties.items())
            device_type = properties.pop("type", None)
            device_index = (
                int(properties.pop("index")) if "index" in properties else None
            )
            device_lun = int(properties.pop("lun")) if "lun" in properties else None
            device_name = properties.pop("name", None)
            azure_nvme_id_device = AzureNvmeIdDevice(
                device=device_path,
                model=model,
                nvme_id=nvme_id,
                type=device_type,
                index=device_index,
                lun=device_lun,
                name=device_name,
                extra=properties,
            )

            key = device_path.split("/")[-1]
            devices[key] = azure_nvme_id_device

        return devices

    @staticmethod
    def parse_azure_nvme_id_output(output: str) -> Dict[str, AzureNvmeIdDevice]:
        """Parse plain-text azure-nvme-id output into devices keyed by name.

        Blank lines are ignored.

        Example output:
        /dev/nvme0n1: type=os
        /dev/nvme0n2: type=data,lun=0
        /dev/nvme0n3: type=data,lun=1
        /dev/nvme1n1: type=local,index=1,name=nvme-440G-1
        /dev/nvme2n1: type=local,index=2,name=nvme-440G-2
        /dev/nvme3n1:

        Raises:
            ValueError: if a line cannot be parsed.
        """
        devices = {}

        for line in output.splitlines():
            stripped = line.strip()
            if not stripped:
                continue

            parts = stripped.split(":", 1)
            if parts[-1] == "":
                # "/dev/nvme3n1:" has no identifiers.
                parts.pop()

            device = parts[0].strip()
            if len(parts) == 2:
                nvme_id = parts[1].strip()
                properties = dict(kv.split("=", 1) for kv in nvme_id.split(","))
            elif len(parts) == 1:
                nvme_id = ""
                properties = {}
            else:
                raise ValueError(f"unexpected azure-nvme-id output: {line}")

            device_type = properties.pop("type", None)
            device_index = (
                int(properties.pop("index")) if "index" in properties else None
            )
            device_lun = int(properties.pop("lun")) if "lun" in properties else None
            device_name = properties.pop("name", None)
            azure_nvme_id_device = AzureNvmeIdDevice(
                device=device,
                model=None,
                nvme_id=nvme_id,
                type=device_type,
                index=device_index,
                lun=device_lun,
                name=device_name,
                extra=properties,
            )

            key = device.split("/")[-1]
            devices[key] = azure_nvme_id_device

        return devices

    @staticmethod
    def parse_azure_nvme_id_version(azure_nvme_id_version_output: str) -> str:
        """Return the version token from "azure-nvme-id <version>" output.

        Raises:
            AssertionError: if the output is not exactly two space-separated tokens.
        """
        parts = azure_nvme_id_version_output.strip().split(" ")
        assert (
            len(parts) == 2
        ), f"unexpected azure-nvme-id version output: {azure_nvme_id_version_output}"
        return parts[1]


@dataclass(eq=True, repr=True)
class NetworkInterface:
    """A single network interface as discovered by NetworkInfo."""

    name: str  # interface name (directory name under SYS_CLASS_NET)
    driver: str  # basename of the device's driver symlink target
    mac: str  # stripped contents of the interface's `address` sysfs file
    ipv4_addrs: List[str]  # IPv4 addresses parsed from `ip -4 addr show`
    udev_properties: Dict[str, str]  # KEY=VALUE pairs from `udevadm info --query=property`


@dataclass(eq=True, repr=True)
class NetworkInfo:
    """Network information gathered from sysfs, udev and `ip`, with validation."""

    # Discovered interfaces keyed by interface name.
    interfaces: Dict[str, NetworkInterface] = field(default_factory=dict)

    @classmethod
    def enumerate_interfaces(cls) -> Dict[str, NetworkInterface]:
        """Retrieve all Ethernet interfaces on the system.

        Only entries under SYS_CLASS_NET that have a ``device`` entry and whose
        ``device/driver`` is a symlink are returned; others are skipped (with a
        debug log for the missing-driver case).  The result is keyed by
        interface name, in sorted name order.

        Raises subprocess.CalledProcessError if `ip addr` fails for a kept
        interface (see get_ipv4_addresses).
        """
        interfaces: Dict[str, NetworkInterface] = {}
        # os.listdir() returns entries in arbitrary order; sort so enumeration,
        # validation and log output are deterministic.
        interface_names = sorted(
            interface
            for interface in os.listdir(SYS_CLASS_NET)
            if os.path.exists(os.path.join(SYS_CLASS_NET, interface, "device"))
        )

        for interface_name in interface_names:
            sys_path = Path(SYS_CLASS_NET, interface_name)
            driver_path = sys_path / "device" / "driver"
            if not driver_path.is_symlink():
                logger.debug(
                    "ignoring interface %s without driver symlink", interface_name
                )
                continue

            # Query udev only for interfaces we keep: it spawns a subprocess.
            udev_properties = cls.query_udev_properties(interface_name)
            # The driver name is the last component of the symlink target.
            driver = os.path.basename(os.readlink(driver_path))
            mac = (sys_path / "address").read_text().strip()
            ipv4_addrs = cls.get_ipv4_addresses(interface_name)
            interfaces[interface_name] = NetworkInterface(
                name=interface_name,
                driver=driver,
                mac=mac,
                ipv4_addrs=ipv4_addrs,
                udev_properties=udev_properties,
            )

        return interfaces

    @staticmethod
    def get_ipv4_addresses(interface_name: str) -> List[str]:
        """Get the IPv4 addresses of a given network interface using `ip addr`.

        Returns the dotted-quad addresses found after ``inet`` in the output.
        Logs and re-raises subprocess.CalledProcessError if `ip` fails.
        """
        try:
            result = subprocess.run(
                ["ip", "-4", "addr", "show", interface_name],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )

            ipv4_addresses = re.findall(r"inet (\d+\.\d+\.\d+\.\d+)", result.stdout)
            return ipv4_addresses
        except subprocess.CalledProcessError as error:
            logger.error("failed to get IPv4 address for %s: %r", interface_name, error)
            raise

    @staticmethod
    def query_udev_properties(interface_name: str) -> Dict[str, str]:
        """Query all udev properties for a given interface using udevadm.

        Each ``KEY=VALUE`` output line becomes a dict entry (split on the first
        ``=``); lines without ``=`` are ignored.  On udevadm failure the error
        is logged and an empty dict is returned (best effort).
        """
        try:
            result = subprocess.run(
                [
                    "udevadm",
                    "info",
                    "--query=property",
                    f"--path={SYS_CLASS_NET}/{interface_name}",
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
            properties: Dict[str, str] = {}
            for line in result.stdout.splitlines():
                if "=" in line:
                    key, value = line.split("=", 1)
                    properties[key] = value
            return properties
        except subprocess.CalledProcessError as error:
            logger.error(
                "Failed to query udev properties for %s: %r", interface_name, error
            )
            return {}

    def _validate_interface(self, interface: NetworkInterface) -> None:
        """Ensure the required properties are set for hv_netvsc, mlx4, mlx5, and mana devices.

        Also checks IPv4 assignment: hv_netvsc interfaces, and mana interfaces
        with no hv_netvsc interface sharing their MAC, must have IPv4 addresses;
        every other interface must have none.
        """
        if interface.driver in ["mlx4_core", "mlx5_core", "mana"]:
            assert (
                interface.udev_properties.get("NM_UNMANAGED") == "1"
                and interface.udev_properties.get("AZURE_UNMANAGED_SRIOV") == "1"
                and interface.udev_properties.get("ID_NET_MANAGED_BY") == "unmanaged"
            ), f"missing required properties for network interface: {interface}"
        elif interface.driver == "hv_netvsc":
            assert (
                "AZURE_UNMANAGED_SRIOV" not in interface.udev_properties
            ), f"unexpected AZURE_UNMANAGED_SRIOV property: {interface}"
            assert (
                interface.udev_properties.get("ID_NET_MANAGED_BY") != "unmanaged"
            ), f"hv_netvsc interface should be managed: {interface}"

        # A mana interface paired (by MAC) with an hv_netvsc interface is not
        # expected to carry addresses itself; an unpaired one is.
        mana_has_synthetic_netvsc = interface.driver == "mana" and any(
            i.driver == "hv_netvsc" and i.mac == interface.mac
            for i in self.interfaces.values()
        )
        if interface.driver == "hv_netvsc" or (
            interface.driver == "mana" and not mana_has_synthetic_netvsc
        ):
            assert interface.ipv4_addrs, f"missing IPv4 addresses for {interface}"
        else:
            assert (
                not interface.ipv4_addrs
            ), f"unexpected IPv4 addresses for {interface}"

        logger.info("validate_interface %s OK: %r", interface.name, interface)

    def validate(self) -> None:
        """Validate network configuration of every discovered interface."""
        for interface in self.interfaces.values():
            self._validate_interface(interface)

    @classmethod
    def gather(cls) -> "NetworkInfo":
        """Gather networking information from the running system."""
        return cls(interfaces=cls.enumerate_interfaces())


class AzureVmUtilsValidator:
    """Validate Azure VM utilities.

    Gathers azure-nvme-id, disk and network state plus IMDS metadata at
    construction time, then cross-checks them in validate().
    """

    def __init__(
        self,
        *,
        skip_imds_validation: bool = False,
        skip_network_validation: bool = False,
        skip_symlink_validation: bool = False,
    ) -> None:
        self.azure_nvme_id_info = AzureNvmeIdInfo.gather()
        self.disk_info = DiskInfo.gather()
        self.net_info = NetworkInfo.gather()
        self.skip_imds_validation = skip_imds_validation
        self.skip_network_validation = skip_network_validation
        self.skip_symlink_validation = skip_symlink_validation

        try:
            self.imds_metadata = get_imds_metadata()
        except Exception as error:
            logger.error("failed to fetch IMDS metadata: %r", error)
            if not self.skip_imds_validation:
                raise
            # Tolerated failure: continue with empty metadata, which leaves
            # vm_size and sku_config as None below.
            self.imds_metadata = {}

        self.vm_size = self.imds_metadata.get("compute", {}).get("vmSize")
        self.sku_config = SKU_CONFIGS.get(self.vm_size)

        logger.info("sku config: %r", self.sku_config)

    def validate_dev_disk_azure_links_data(self) -> None:
        """Validate /dev/disk/azure/data links.

        All data disks should have by-lun if azure-vm-utils is installed.
        Future variants of remote disks will include by-name.

        The by-lun link count is checked against the NVMe remote data disk
        count (when NVMe remote disks exist) and against the IMDS data disk
        list; each IMDS disk's size is then compared with the linked device.
        """
        imds_data_disks = (
            self.imds_metadata.get("compute", {})
            .get("storageProfile", {})
            .get("dataDisks", [])
        )
        expected_data_disks = len(imds_data_disks)
        data_disks = [
            link
            for link in self.disk_info.dev_disk_azure_links
            if link.startswith("/dev/disk/azure/data/by-lun")
        ]
        if self.disk_info.nvme_remote_disks:
            assert len(data_disks) == len(
                self.disk_info.nvme_remote_data_disks
            ), f"unexpected number of data disks: {data_disks} configured={self.disk_info.nvme_remote_data_disks}"

        assert (
            len(data_disks) == expected_data_disks
        ), f"unexpected number of data disks: {data_disks} IMDS configured={imds_data_disks} (note that IMDS may not be accurate)"

        # Verify disk sizes match up with IMDS configuration.
        for imds_disk in imds_data_disks:
            lun = imds_disk.get("lun")
            # Disk size is actually reported in GiB not GB.
            expected_size_gib = int(imds_disk.get("diskSizeGB"))
            disk_path = f"/dev/disk/azure/data/by-lun/{lun}"
            actual_size_gib = get_disk_size_gib(disk_path)
            assert (
                actual_size_gib == expected_size_gib
            ), f"disk size mismatch for {disk_path}: expected {expected_size_gib} GiB, found {actual_size_gib} GiB"

        logger.info("validate_dev_disk_azure_links_data OK: %r", data_disks)

    def validate_dev_disk_azure_links_local(self) -> None:
        """Validate /dev/disk/azure/local links.

        All local disks should have by-serial if azure-vm-utils is installed.
        If NVMe id is supported, by-index and by-name will be available as well.

        For each of by-index, by-name and by-serial, the link count must match
        the detected disks and, when the SKU config enables local NVMe ID, the
        SKU's local disk count.  by-name links must look like
        ``nvme-<size>G-<index>`` and resolve to the same device as the
        corresponding by-index link.
        """
        local_disks = sorted(
            link
            for link in self.disk_info.dev_disk_azure_links
            if link.startswith("/dev/disk/azure/local")
        )

        for key in ["index", "name", "serial"]:
            local_disks_by_key = sorted(
                link
                for link in self.disk_info.dev_disk_azure_links
                if link.startswith(f"/dev/disk/azure/local/by-{key}")
            )
            # by-serial exists for every local disk; by-index/by-name only for
            # disks reported via the NVMe ID interface (nvme_local_disks_v2).
            if key == "serial":
                expected_count = len(self.disk_info.nvme_local_disks)
            else:
                expected_count = len(self.disk_info.nvme_local_disks_v2)

            assert (
                len(local_disks_by_key) == expected_count
            ), f"unexpected number of local disks by-{key}: {local_disks_by_key} (expected {expected_count})"

            if self.sku_config and self.sku_config.nvme_id_enabled_local:
                # Report the SKU-derived count, which is what is being checked.
                sku_expected_count = self.sku_config.nvme_local_disk_count
                assert (
                    len(local_disks_by_key) == sku_expected_count
                ), f"unexpected number of local disks by sku for by-{key}: {local_disks_by_key} (expected {sku_expected_count})"

            if key == "name":
                for disk in local_disks_by_key:
                    name = disk.split("/")[-1]
                    assert name.startswith(
                        "nvme-"
                    ), f"unexpected local disk name: {name}"
                    match = re.match(r"nvme-(\d+)G-(\d+)", name)
                    assert (
                        match
                    ), f"local disk name does not conform to expected pattern: {name}"
                    size, index = match.groups()
                    assert (
                        size.isdigit() and index.isdigit()
                    ), f"invalid size or index in local disk name: {name}"

                    # Cross-check by-index links with by-name links.
                    by_index_path = f"/dev/disk/azure/local/by-index/{index}"
                    assert os.path.realpath(by_index_path) == os.path.realpath(
                        disk
                    ), f"mismatch between by-index and by-name links: {by_index_path} != {disk}"

        logger.info("validate_dev_disk_azure_links_local OK: %r", local_disks)

    def validate_dev_disk_azure_links_os(self) -> None:
        """Validate /dev/disk/azure/os link exists."""
        os_disk = "/dev/disk/azure/os"
        assert os_disk in self.disk_info.dev_disk_azure_links, f"missing {os_disk}"

        logger.info("validate_dev_disk_azure_links_os OK: %r", os_disk)

    def validate_dev_disk_azure_links_resource(self) -> None:
        """Validate /dev/disk/azure/resource link.

        The link must exist when the SKU config declares a temp disk or a SCSI
        resource disk was detected, and must be absent otherwise.
        """
        resource_disk = "/dev/disk/azure/resource"
        expected = (self.sku_config and self.sku_config.temp_disk_size_gib) or bool(
            self.disk_info.scsi_resource_disk
        )
        if expected:
            assert (
                resource_disk in self.disk_info.dev_disk_azure_links
            ), f"missing {resource_disk}"
        else:
            assert (
                resource_disk not in self.disk_info.dev_disk_azure_links
            ), f"unexpected {resource_disk}"

        logger.info("validate_dev_disk_azure_links_resource OK: %r", resource_disk)

    def validate_networking(self) -> None:
        """Validate networking configuration."""
        self.net_info.validate()
        logger.info("validate_networking OK: %r", self.net_info)

    def validate_nvme_local_disks(self) -> None:
        """Validate NVMe local disks.

        NOTE(review): currently performs no checks (it only logs) and is not
        invoked by validate().
        """
        logger.info("validate_nvme_local_disks OK: %r", self.disk_info.nvme_local_disks)

    def validate_scsi_resource_disk(self) -> None:
        """Validate SCSI resource disk symlink and size."""
        assert (
            self.disk_info.scsi_resource_disk_size_gib
            == self.disk_info.dev_disk_azure_resource_disk_size_gib
        ), f"resource disk size mismatch: {self.disk_info}"
        if self.disk_info.scsi_resource_disk:
            assert (
                f"/dev/{self.disk_info.scsi_resource_disk}"
                == self.disk_info.dev_disk_azure_resource_disk
            ), f"unexpected resource disk path: {self.disk_info}"
        else:
            assert (
                self.disk_info.scsi_resource_disk is None
                and self.disk_info.dev_disk_azure_resource_disk is None
            ), f"unexpected resource disk path: {self.disk_info}"

        logger.info(
            "validate_scsi_resource_disk OK: /dev/disk/azure/resource => %s",
            self.disk_info.dev_disk_azure_resource_disk,
        )

    def validate_sku_config(self) -> None:
        """Validate detected disks against the SKU config (skipped if unknown)."""
        if not self.sku_config:
            logger.warning(
                "validate_sku_config SKIPPED: no sku configuration for VM size %r",
                self.vm_size,
            )
            return

        assert (
            self.sku_config.vm_size == self.vm_size
        ), f"vm size mismatch: {self.sku_config.vm_size} != {self.vm_size}"
        assert (
            len(self.disk_info.nvme_local_disks)
            == self.sku_config.nvme_local_disk_count
        ), f"local disk count mismatch: {len(self.disk_info.nvme_local_disks)} != {self.sku_config.nvme_local_disk_count}"
        assert (
            self.disk_info.nvme_local_disk_size_gib
            == self.sku_config.nvme_local_disk_size_gib
        ), f"local disk size mismatch: {self.disk_info.nvme_local_disk_size_gib} != {self.sku_config.nvme_local_disk_size_gib}"
        assert (
            self.disk_info.scsi_resource_disk_size_gib
            == self.sku_config.temp_disk_size_gib
        ), f"temp disk size mismatch: {self.disk_info.scsi_resource_disk_size_gib} != {self.sku_config.temp_disk_size_gib}"
        assert (
            self.disk_info.dev_disk_azure_resource_disk_size_gib
            == self.sku_config.temp_disk_size_gib
        ), f"temp disk size mismatch: {self.disk_info.dev_disk_azure_resource_disk_size_gib} != {self.sku_config.temp_disk_size_gib}"

        logger.info("validate_sku_config OK: %r", self.sku_config)

    def validate(self) -> None:
        """Run validations, honoring the skip flags given at construction."""
        self.azure_nvme_id_info.validate(self.disk_info)

        if self.skip_symlink_validation:
            logger.info("validate_dev_disk_azure_links_data SKIPPED")
            logger.info("validate_dev_disk_azure_links_local SKIPPED")
            logger.info("validate_dev_disk_azure_links_os SKIPPED")
            logger.info("validate_dev_disk_azure_links_resource SKIPPED")
            logger.info("validate_scsi_resource_disk SKIPPED")
        else:
            self.validate_dev_disk_azure_links_data()
            self.validate_dev_disk_azure_links_local()
            self.validate_dev_disk_azure_links_os()
            self.validate_dev_disk_azure_links_resource()
            self.validate_scsi_resource_disk()

        if self.skip_network_validation:
            logger.info("validate_networking SKIPPED")
        else:
            self.validate_networking()

        self.validate_sku_config()

        logger.info("success!")


def main() -> None:
    """Parse command-line flags, configure logging, and run the validator."""
    parser = argparse.ArgumentParser(
        description="Azure VM utilities self-tests script."
    )
    # Every option is a boolean switch; register them in a fixed order.
    boolean_flags = (
        ("--debug", "Enable debug logging"),
        (
            "--skip-imds-validation",
            "Skip imds validation (allow for running tests outside Azure VM)",
        ),
        (
            "--skip-network-validation",
            "Skip network validation (allow for running test without reboot after install)",
        ),
        (
            "--skip-symlink-validation",
            "Skip symlink validation (allow for running test without reboot after install)",
        ),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    args = parser.parse_args()

    log_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format="[%(asctime)s] %(message)s", level=log_level)

    AzureVmUtilsValidator(
        skip_imds_validation=args.skip_imds_validation,
        skip_network_validation=args.skip_network_validation,
        skip_symlink_validation=args.skip_symlink_validation,
    ).validate()


if __name__ == "__main__":
    # Run the self-tests only when executed as a script, not when imported.
    main()
