import base64
import logging
import string
from typing import cast

import requests
from Crypto.Cipher import AES
from Crypto.Protocol.KDF import PBKDF2
from malduck import UInt32
from malduck.extractor import Extractor
from malduck.procmem import ProcessMemory
from malduck.yara import YaraRuleMatch

from ..utils import Config, get_rule_metadata

# Module-level logger; the extractor writes its debug output and warnings here.
log: logging.Logger = logging.getLogger(__name__)

__author__ = "c3rb3ru5"
__version__ = "1.0.0"


class ASyncRAT(Extractor):

    """
    ASyncRAT Configuration Extractor

    Splits a block of string entries (located through the ``magic_cslr_0``
    YARA string) on double NUL bytes, derives an AES key from entry 7 and
    decrypts the remaining configuration values with AES-CBC.
    """

    family: str = "asyncrat"
    yara_rules: tuple = ("asyncrat",)

    AES_BLOCK_SIZE = 128  # bits; the IV is one block long
    AES_KEY_SIZE = 256  # bits; length of the PBKDF2-derived key
    AES_CIPHER_MODE = AES.MODE_CBC

    # PBKDF2 iteration count used to stretch the embedded key.
    PBKDF2_ITERATIONS = 50000
    # Bytes at the start of every encrypted blob that precede the IV.
    # NOTE(review): presumably an HMAC-SHA256 tag; it is skipped, never verified.
    HMAC_SIZE = 32
    # Seconds to wait for the pastebin server before giving up.
    PASTEBIN_TIMEOUT = 10

    @staticmethod
    def get_salt() -> bytes:
        """Return the fixed PBKDF2 salt."""
        return bytes.fromhex("BFEB1E56FBCD973BB219022430A57843003D5644D21E62B9D4F180E7E6C33941")

    @staticmethod
    def _strip_pkcs7(data: bytes, block_size: int) -> bytes:
        """
        Remove PKCS#7 padding from *data*.

        Padding that is not well-formed (e.g. output decrypted with a wrong key)
        is left in place instead of raising, so callers still get a best-effort
        result.
        """
        if not data:
            return data
        pad = data[-1]
        if 1 <= pad <= block_size and data[-pad:] == bytes([pad]) * pad:
            return data[:-pad]
        return data

    def decrypt(self, key: bytes, ciphertext: bytes) -> str:
        """
        Decrypt one encrypted configuration blob.

        :param key: raw key material, fed to PBKDF2 together with ``get_salt()``
        :param ciphertext: ``HMAC_SIZE`` skipped bytes, a one-block IV, then the
            AES-CBC ciphertext
        :return: the plaintext with well-formed PKCS#7 padding removed,
            non-ASCII bytes dropped and surrounding whitespace stripped
        :raises ValueError: (from pycryptodome) if the IV is truncated or the
            ciphertext is not a whole number of blocks
        """
        block_bytes = self.AES_BLOCK_SIZE // 8
        aes_key: bytes = PBKDF2(key, self.get_salt(), self.AES_KEY_SIZE // 8, self.PBKDF2_ITERATIONS)
        iv_end = self.HMAC_SIZE + block_bytes
        cipher = AES.new(aes_key, self.AES_CIPHER_MODE, ciphertext[self.HMAC_SIZE : iv_end])
        padded: bytes = cipher.decrypt(ciphertext[iv_end:])
        # Strip the block padding here so no caller has to filter stray padding
        # bytes out of the decoded text.
        plaintext: str = self._strip_pkcs7(padded, block_bytes).decode("ascii", "ignore").strip()
        return plaintext

    @staticmethod
    def get_string(data: list[bytes], index: int) -> str:
        """
        Return entry *index* without its first byte, decoded as UTF-8 with
        invalid sequences dropped.

        NOTE(review): the skipped byte is presumably a length prefix — confirm.
        """
        return data[index][1:].decode("utf-8", "ignore")

    def decrypt_config_item(self, key: bytes, data: list[bytes], index: int) -> str | bool:
        """
        Base64-decode and decrypt entry *index*.

        The values "true" and "false" (any case) are returned as booleans; any
        other plaintext is returned as a string.
        """
        plaintext = self.decrypt(key, base64.b64decode(self.get_string(data, index)))
        lowered = plaintext.lower()
        if lowered == "true":
            return True
        if lowered == "false":
            return False
        return plaintext

    @staticmethod
    def get_wide_string(data: list[bytes], index: int) -> str:
        """
        Return entry *index* (without its first byte) decoded as UTF-16LE.

        A ``\\x00`` is re-appended; presumably this restores the high byte of
        the last character that the double-NUL split consumed.
        """
        _data: bytes = data[index][1:] + b"\x00"
        # Explicit little-endian: plain "utf-16" falls back to the host byte
        # order when the data carries no BOM.
        return _data.decode("utf-16-le")

    def decrypt_config_item_list(self, key: bytes, data: list[bytes], index: int) -> list[str]:
        """
        Decrypt a comma-separated entry into a list of strings.

        An empty value or the literal "null" yields an empty list; items are
        whitespace-stripped and empty items are dropped.
        """
        result = self.decrypt_config_item_printable(key, data, index)
        if not result or result == "null":
            return []
        return [item.strip() for item in result.split(",") if item.strip()]

    def decrypt_config_item_printable(self, key: bytes, data: list[bytes], index: int) -> str:
        """Decrypt entry *index* and keep only characters in ``string.printable``."""
        plaintext = self.decrypt(key, base64.b64decode(data[index][1:]))
        return "".join(ch for ch in plaintext if ch in string.printable)

    @Extractor.rule
    def asyncrat(self, p: ProcessMemory, match: YaraRuleMatch) -> Config | bool:
        """Report the metadata of the matched ``asyncrat`` YARA rule."""
        _info: Config = get_rule_metadata(match)
        return _info

    @staticmethod
    def _read_string_entries(p: ProcessMemory, addr: int) -> list[bytes] | None:
        """
        Read the string block described by the header at *addr* and split it
        into raw entries on double NUL bytes.

        The uint32 at ``addr + 0x40`` is the block offset relative to *addr*
        and the uint32 at ``addr + 0x44`` is its size. Returns None if either
        field could not be read.
        """
        strings_offset = cast(UInt32, p.uint32v(addr + 0x40))
        strings_size = cast(UInt32, p.uint32v(addr + 0x44))
        # NOTE(review): assumes malduck's uint32v returns None for unreadable
        # memory (the casts above only satisfy the type checker) — confirm.
        if strings_offset is None or strings_size is None:
            return None
        raw: bytes = p.readv(addr + strings_offset, strings_size)
        return raw.split(b"\x00\x00")

    def _apply_pastebin(self, settings: dict) -> None:
        """
        Fetch ``settings["pastebin"]`` and, if it serves ``host:port``, replace
        ``settings["hosts"]`` and ``settings["ports"]`` with that endpoint.

        Request errors and unexpected bodies are logged; non-200 responses are
        ignored. In every failure case *settings* is left untouched.

        NOTE(review): this issues a live HTTP request to a URL decrypted from
        the sample, i.e. it contacts infrastructure chosen by the malware author.
        """
        url = settings["pastebin"]
        try:
            req = requests.get(url=url, timeout=self.PASTEBIN_TIMEOUT)
        except requests.exceptions.RequestException as error:
            log.warning(error)
            return
        if req.status_code != 200:
            return
        parts = req.content.split(b"\x3a")
        if len(parts) < 2:
            log.warning("unexpected pastebin response from %s", url)
            return
        settings["hosts"] = [parts[0].decode("ascii", "ignore").strip()]
        settings["ports"] = [parts[1].decode("ascii", "ignore").strip()]

    @Extractor.extractor("magic_cslr_0")
    def asyncrat_magic(self, p: ProcessMemory, addr: int) -> Config | None:
        """
        Extract the configuration from the string block referenced at *addr*.

        Entry indices used: 1 ports, 2 hosts, 3 version, 4 install,
        5 install folder, 6 install file, 7 base64 key, 8 mutex, 12 pastebin.
        When the pastebin value is set and not "null", it is fetched and may
        override hosts/ports.

        :return: ``{family: settings}``, or None if the block cannot be read or
            decoded
        """
        data = self._read_string_entries(p, addr)
        if data is None:
            return None
        try:
            key = base64.b64decode(self.get_string(data, 7))
            log.debug("extracted key: %s", str(key))
            settings = {
                "hosts": self.decrypt_config_item_list(key, data, 2),
                "ports": self.decrypt_config_item_list(key, data, 1),
                "version": self.decrypt_config_item_printable(key, data, 3),
                "install_folder": self.get_wide_string(data, 5),
                "install_file": self.get_wide_string(data, 6),
                "install": self.decrypt_config_item_printable(key, data, 4),
                "mutex": self.decrypt_config_item_printable(key, data, 8),
                "pastebin": self.decrypt_config_item_printable(key, data, 12),
            }
        except (IndexError, ValueError) as error:
            # Missing entries, invalid base64, bad ciphertext/IV lengths and
            # undecodable UTF-16 all surface as one of these exception types.
            log.warning("failed to decode asyncrat config at %#x: %s", addr, error)
            return None
        if settings["pastebin"] and settings["pastebin"] != "null":
            self._apply_pastebin(settings)
        return {self.family: settings}
