import logging

from malduck import enhex, ipv4
from malduck.extractor import Extractor
from malduck.procmem import ProcessMemory
from malduck.yara import YaraRuleMatch

from ...utils import Config, get_rule_metadata

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Module metadata.
__author__ = "c3rb3ru5"
__version__ = "1.0.0"


class DridexLoader(Extractor):
    """
    DridexLoader Configuration Extractor

    Every ``@Extractor.extractor`` callback below receives the address of a
    matched YARA string and stores values read relative to that address on
    the instance. ``dridex_loader_final`` then builds the configuration
    dictionary from whatever state the callbacks collected.
    """

    family: str = "dridex"
    yara_rules: tuple = ("dridex_loader",)
    LEN_BLOB_KEY = 40
    LEN_BOT_KEY = 107
    # ``dridex_loader_final`` rejects host counts above this bound.
    MAX_IP_COUNT = 10

    # Defaults for the state filled in by the callbacks. ``c2_rva`` and
    # ``delta`` need defaults too: without them ``dridex_loader_final`` raises
    # AttributeError when an ip_count string matched but no c2parse string did.
    c2_rva: int | None = None
    delta: int = 0
    botnet_rva: int | None = None
    botnet_id: int | None = None
    rc4_key = None
    ip_count = 0

    def _load_botnet_id(self, p: ProcessMemory) -> None:
        """Read the 16-bit botnet id at ``self.botnet_rva`` when that is set."""
        if self.botnet_rva:
            self.botnet_id = p.uint16v(self.botnet_rva)

    def _load_ip_count(self, p: ProcessMemory, pointer_addr: int) -> None:
        """Follow the 32-bit value at *pointer_addr* and read one byte as the host count."""
        ip_count_rva = p.uint32v(pointer_addr)
        if ip_count_rva:
            self.ip_count = p.uint8v(ip_count_rva)

    @Extractor.extractor("c2parse_6")
    def c2parse_6(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address (match + 44) and botnet id address (match - 7)."""
        self.c2_rva = p.uint32v(addr + 44)
        self.botnet_rva = p.uint32v(addr - 7)
        # Resolve the botnet id the same way c2parse_5 and get_botnet_id do;
        # otherwise the address stored above is never used.
        self._load_botnet_id(p)
        self.delta = 0

    @Extractor.extractor("c2parse_5")
    def c2parse_5(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address (match + 75) and botnet id address (match + 3)."""
        self.c2_rva = p.uint32v(addr + 75)
        self.botnet_rva = p.uint32v(addr + 3)
        self._load_botnet_id(p)
        self.delta = 0

    @Extractor.extractor("c2parse_4")
    def c2parse_4(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address stored at match + 6."""
        self.c2_rva = p.uint32v(addr + 6)
        self.delta = 0

    @Extractor.extractor("c2parse_3")
    def c2parse_3(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address stored at match + 60; entries carry 2 extra bytes."""
        self.c2_rva = p.uint32v(addr + 60)
        self.delta = 2

    @Extractor.extractor("c2parse_2")
    def c2parse_2(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address stored at match + 47."""
        self.c2_rva = p.uint32v(addr + 47)
        self.delta = 0

    @Extractor.extractor("c2parse_1")
    def c2parse_1(self, p: ProcessMemory, addr: int) -> None:
        """Record the C2 table address stored at match + 27; entries carry 2 extra bytes."""
        self.c2_rva = p.uint32v(addr + 27)
        self.delta = 2

    @Extractor.extractor("botnet_id")
    def get_botnet_id(self, p: ProcessMemory, addr: int) -> None:
        """Record the botnet id address stored at match + 23 and read the id."""
        self.botnet_rva = p.uint32v(addr + 23)
        self._load_botnet_id(p)

    def get_rc4_rva(self, p: ProcessMemory, rc4_decode: int) -> int | None:
        """
        Return the 32-bit value embedded near the *rc4_decode* match.

        The byte at match + 8 selects the operand offset: non-zero reads the
        value at match + 5, otherwise the value at match + 3.
        """
        zb = p.uint8v(rc4_decode + 8, True)
        if zb:
            return p.uint32v(rc4_decode + 5)
        return p.uint32v(rc4_decode + 3)

    # NOTE(review): both rc4_key_* callbacks store the integer returned by
    # get_rc4_rva in ``self.rc4_key``, and dridex_loader_final passes that value
    # to enhex(). Confirm whether the key bytes should be read from memory here.
    @Extractor.extractor("rc4_key_1")
    def rc4_key_1(self, p: ProcessMemory, addr: int) -> None:
        """Store the value resolved by ``get_rc4_rva`` for this match."""
        self.rc4_key = self.get_rc4_rva(p, addr)

    @Extractor.extractor("rc4_key_2")
    def rc4_key_2(self, p: ProcessMemory, addr: int) -> None:
        """Store the value resolved by ``get_rc4_rva`` for this match."""
        self.rc4_key = self.get_rc4_rva(p, addr)

    @Extractor.extractor("ip_count_1")
    def ip_count_1(self, p: ProcessMemory, addr: int) -> None:
        """Read the host count through the address stored at match + 3."""
        self._load_ip_count(p, addr + 3)

    @Extractor.extractor("ip_count_2")
    def ip_count_2(self, p: ProcessMemory, addr: int) -> None:
        """Read the host count through the address stored at match + 2."""
        self._load_ip_count(p, addr + 2)

    @Extractor.extractor("ip_count_3")
    def ip_count_3(self, p: ProcessMemory, addr: int) -> None:
        """Read the host count through the address stored at match + 2."""
        self._load_ip_count(p, addr + 2)

    @staticmethod
    def match_exists(matches, name):
        """Return True when any item yielded by ``matches.elements`` equals *name*."""
        return any(element == name for element in matches.elements)

    @Extractor.rule
    def dridex_loader(self, p: ProcessMemory, match: YaraRuleMatch) -> Config | bool:
        """Return the metadata of the matched ``dridex_loader`` rule."""
        return get_rule_metadata(match)

    @Extractor.needs_pe
    @Extractor.final
    def dridex_loader_final(self, p: ProcessMemory) -> dict | None:
        """
        Build the configuration from the state collected by the callbacks.

        Returns:
            A dict with ``"family"``, a per-family entry holding ``"hosts"``
            (``"ip:port"`` strings) and, when available, ``"botnet_id"``,
            plus a top-level ``"rc4_key"`` when one was recorded. ``None``
            when the memory is empty, the host count is missing or above
            ``MAX_IP_COUNT``, or no host could be read.
        """
        if not p.memory:
            return None
        if not self.ip_count or self.ip_count > self.MAX_IP_COUNT:
            return None
        log.debug("ip_count: %d", self.ip_count)

        hosts = []
        if self.c2_rva:
            # Walk the table with a local cursor so the recorded start address
            # on the instance is left untouched.
            cursor = self.c2_rva
            # Each entry is a 4-byte address and a 2-byte port, followed by
            # ``delta`` extra bytes.
            stride = 6 + self.delta
            for _ in range(self.ip_count):
                ip = ipv4(p.readv(cursor, 4))
                port = p.uint16v(cursor + 4)
                log.debug("found c2 ip: %s:%s", ip, port)
                if ip is not None and port is not None:
                    hosts.append(f"{ip}:{port}")
                cursor += stride

        if not hosts:
            return None

        config = {
            "family": self.family,
            self.family: {"hosts": hosts},
        }
        if self.rc4_key:
            config["rc4_key"] = enhex(self.rc4_key)
        if self.botnet_id is not None:
            log.debug("found botnet_id: %s", self.botnet_id)
            config[self.family]["botnet_id"] = self.botnet_id
        return config
