import logging
import ntpath
import os

import ida_dbg
import ida_ida
import ida_nalt
import idaapi
import idc

from kernel32 import Kernel32
from advapi32 import Advapi32

# Calling-convention identifiers (symbolic names only; add_hook_functions in
# this file does not take a calling-convention argument).
CALL_CONV_CDECL, CALL_CONV_STDCALL, CALL_CONV_FASTCALL, CALL_CONV_FLOAT = range(4)
# Presumably an argc marker for variadic functions (unused in this file).
VAR_ARGS = -1

# Per-module argument/return formatters, keyed by lower-case module name.
# FuncTracePlus looks functions up in each object's ``total_func`` mapping.
functions = dict(
    kernel32=Kernel32(),
    advapi32=Advapi32(),
)

# Log everything from DEBUG upwards, prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

class FuncTracePlus(ida_dbg.DBG_Hooks):
    """Debugger hook that logs calls to selected Windows API functions.

    How it works (see the methods below):

    * ``dbg_process_start`` puts a breakpoint on the program entry point.
    * When the entry point is hit, the import table is read
      (``dbg_get_imports``). Every function registered with
      ``add_hook_functions`` gets a breakpoint on the address stored in its
      import slot. ``kernel32!GetProcAddress`` is hooked as well, so
      functions resolved at run time can also be traced.
    * On entry to a traced function, its arguments are logged and a one-shot
      breakpoint is placed on its return address. When that breakpoint is
      hit, the return value (RAX/EAX) is logged and the breakpoint removed.
    * ``dbg_process_exit`` removes the tracer's breakpoints.

    :param ignore_bp: if True, breakpoints that do not belong to the tracer
        are resumed automatically; if False the debugger stops on them.
    """

    # Microsoft x64 convention: the first four integer arguments are passed
    # in these registers, in this order.
    _X64_REG_ARGS = ("RCX", "RDX", "R8", "R9")

    def __init__(self, ignore_bp=True):
        ida_dbg.DBG_Hooks.__init__(self)
        self.ignore_bp = ignore_bp
        self.is_64bit = self._target_is_64bit()
        # module name -> [{"name": ..., "argc": ...}], filled by add_hook_functions
        self.to_hook_func = {}
        self._reset_run_state()

    @staticmethod
    def _target_is_64bit():
        """Return True when the program in the database is 64-bit.

        ``idaapi.BADADDR`` only reflects the bitness of the IDA kernel running
        this script. IDA 9 ships only the 64-bit kernel, which also opens
        32-bit programs. The database flag is therefore preferred whenever
        this IDA version provides ``ida_ida.inf_is_64bit``.
        """
        inf_is_64bit = getattr(ida_ida, "inf_is_64bit", None)
        if inf_is_64bit is not None:
            return bool(inf_is_64bit())
        return idaapi.BADADDR == 0xffffffffffffffff

    def _reset_run_state(self):
        """(Re)initialise everything learned while debugging one process.

        Addresses can differ between runs (e.g. relocated modules), so
        nothing here is carried over to the next debugging session.
        """
        # module name -> [{"ea", "ordinal", "name"?, "va"?}] import entries
        self.module_funcs = {}
        # module base address -> lower-case module file name without extension
        self.module_address_name = {}
        self.entry_points = []
        # traced function address -> "module!function"
        self.break_points = {}
        # return address -> {"module_name", "func_name", "debug_info"}
        self.break_points_ret_addr = {}
        self.getproc_func_ea = None
        self.getproc_func_addr = None
        # GetProcAddress return address -> same record shape as above
        self.getproc_func_ret_addrs = {}

    # ------------------------------------------------------------------
    # Process / thread / module events
    # ------------------------------------------------------------------
    def dbg_process_start(self, pid, tid, ea, modinfo_name, modinfo_base, modinfo_size):
        """Handle process start: break on the program entry point."""
        ep = ida_ida.inf_get_start_ip()
        self.entry_points.append(ep)
        idc.add_bpt(ep)

        logging.info(
            f"process start => pid : {pid}, tid : {tid}, modinfo_name : {modinfo_name}, "
            f"modinfo_base : {hex(modinfo_base)}, modinfo_size : {hex(modinfo_size)}"
        )

    def dbg_process_exit(self, pid, tid, ea, exit_code):
        """Handle process exit: delete the tracer's breakpoints and forget per-run state.

        Pending return-address breakpoints are removed too. They may lie in
        the program's own code and would otherwise be left behind.
        """
        owned = set(self.entry_points)
        owned.update(self.break_points)
        owned.update(self.break_points_ret_addr)
        owned.update(self.getproc_func_ret_addrs)
        if self.getproc_func_addr is not None:
            owned.add(self.getproc_func_addr)
        for addr in owned:
            if ida_dbg.exist_bpt(addr):
                idc.del_bpt(addr)
        self._reset_run_state()
        logging.info(f"process exit => pid : {pid}, tid : {tid}, ea : {hex(ea)}, exit_code : {exit_code}")

    def dbg_thread_start(self, pid, tid, ea):
        logging.info(f"thread start => pid : {pid}, tid : {tid}, ea : {hex(ea)}")

    def dbg_thread_exit(self, pid, tid, ea, exit_code):
        logging.info(f"thread exit => pid : {pid}, tid : {tid}, ea : {hex(ea)}, exit_code : {exit_code}")

    def dbg_library_load(self, pid, tid, ea, modinfo_name, modinfo_base, modinfo_size):
        """Remember the base address -> name of every loaded module.

        GetProcAddress receives a module handle, which is the module's base
        address. This map turns that handle back into a module name.
        """
        # ntpath splits on both "\\" and "/", so Windows module paths are
        # handled even when IDA itself runs on Linux/macOS (remote debugging).
        base_name = ntpath.splitext(ntpath.basename(modinfo_name))[0]
        self.module_address_name[modinfo_base] = base_name.lower()

        logging.info(
            f"library load => pid : {pid}, tid : {tid}, ea : {hex(ea)}, "
            f"modinfo_name : {modinfo_name}, modinfo_base : {hex(modinfo_base)}, "
            f"modinfo_size : {hex(modinfo_size)}"
        )

    def dbg_library_unload(self, pid, tid, ea, info):
        logging.info(f"unload library => pid : {pid}, tid : {tid}, ea : {hex(ea)}, info : {info}")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def is_hit_entry_point(self, bptea):
        """Return True if *bptea* is one of the entry-point breakpoints."""
        return bptea in self.entry_points

    def get_ptr_from_addr(self, addr):
        """Return the pointer-sized value stored at *addr* in debuggee memory."""
        ida_dbg.refresh_debugger_memory()
        if self.is_64bit:
            return idc.get_qword(addr)
        return idc.get_wide_dword(addr)

    def hook_funcs_action(self):
        """Put breakpoints on the registered functions and on GetProcAddress.

        Called when the entry point is reached, after ``dbg_get_imports``
        has recorded the import slots. The loader has filled them by then.
        """
        for module_name, hook_funcs in self.to_hook_func.items():
            for hook_func in hook_funcs:
                self._hook_imported_func(module_name, hook_func["name"])
        self._hook_getprocaddress()

    def _hook_imported_func(self, module_name, func_name):
        """Break on the target of every import slot of *module_name*!*func_name*."""
        for func in self.module_funcs.get(module_name, []):
            # Imports by ordinal are recorded without a "name" key.
            if func.get("name") != func_name:
                continue
            target_func_addr = self.get_ptr_from_addr(func["ea"])
            if not target_func_addr:
                logging.warning(f"hook:{module_name}!{func_name} => import slot {hex(func['ea'])} is empty")
                continue
            func["va"] = target_func_addr
            self.break_points[target_func_addr] = "%s!%s" % (module_name, func_name)
            logging.debug("hook:%s!%s => %s" % (module_name, func_name, hex(target_func_addr)))
            idc.add_bpt(target_func_addr)

    def _hook_getprocaddress(self):
        """Break on kernel32!GetProcAddress so run-time imports can be traced."""
        for func in self.module_funcs.get("kernel32", []):
            if func.get("name") == "GetProcAddress":
                self.getproc_func_ea = func["ea"]
                break
        else:
            logging.warning("kernel32!GetProcAddress is not imported; "
                            "functions resolved at run time will not be traced")
            return
        self.getproc_func_addr = self.get_ptr_from_addr(self.getproc_func_ea)
        logging.debug(f"hook:kernel32!GetProcAddress => {hex(self.getproc_func_addr)}")
        idc.add_bpt(self.getproc_func_addr)

    def find_hook_func_breakpoint(self, bptea):
        """Return True if *bptea* is the entry breakpoint of a traced function."""
        # break_points is keyed by address; its values are "module!func" names.
        return bptea in self.break_points

    def get_func_args(self, nth):
        """Read argument *nth* (1-based) of the function whose entry was just hit.

        ``nth == 0`` returns the return address. The stack offsets assume the
        thread is stopped on the function's first instruction, before its
        prologue moves the stack pointer.

        64-bit targets follow the Microsoft x64 convention. Arguments 1-4 are
        in RCX, RDX, R8 and R9. [RSP] holds the return address and is followed
        by a 32-byte home area reserved for those four registers, so argument
        ``nth >= 5`` is at ``RSP + nth * 8`` (the 5th at RSP+0x28).
        32-bit targets are assumed to pass every argument on the stack
        (cdecl/stdcall), with argument ``nth`` at ``ESP + nth * 4``.
        """
        ida_dbg.refresh_debugger_memory()
        if not self.is_64bit:
            esp = idc.get_reg_value("ESP")
            return idc.get_wide_dword(esp + nth * 4)
        if 1 <= nth <= len(self._X64_REG_ARGS):
            return idc.get_reg_value(self._X64_REG_ARGS[nth - 1])
        rsp = idc.get_reg_value("RSP")
        return idc.get_qword(rsp + nth * 8)

    def get_register_value(self, register_name):
        """Read a general register by its suffix, e.g. "AX" -> RAX or EAX."""
        prefix = "R" if self.is_64bit else "E"
        return idc.get_reg_value(prefix + register_name)

    def get_args_number(self, bptea):
        """Return the argc registered for the function hooked at *bptea*, or None."""
        module_name, func_name = self.break_points[bptea].split("!")
        for hook_func in self.to_hook_func.get(module_name, []):
            if hook_func["name"] == func_name:
                return hook_func["argc"]
        return None

    @staticmethod
    def _custom_formatter(module_name, func_name):
        """Return the module-specific formatter for *func_name*, or None."""
        module = functions.get(module_name)
        if module is not None and func_name in module.total_func:
            return module.total_func[func_name]
        return None

    # ------------------------------------------------------------------
    # Breakpoint dispatch
    # ------------------------------------------------------------------
    def hit_getprocaddress_or_ret_bp(self, bptea):
        """Handle GetProcAddress entry / return breakpoints; return True if handled."""
        if bptea == self.getproc_func_addr:
            self._on_getprocaddress_entry()
            return True
        if bptea in self.getproc_func_ret_addrs:
            self._on_getprocaddress_return(bptea)
            return True
        return False

    def _on_getprocaddress_entry(self):
        """Record which function is being resolved and break on the caller's return address."""
        module_name = self.module_address_name.get(self.get_func_args(1))
        if module_name is None:
            return  # handle of a module whose load we did not see
        ret_addr = self.get_func_args(0)
        func_name_addr = self.get_func_args(2)
        # lpProcName is an ordinal (not a string pointer) when its high-order
        # bits are zero; those lookups cannot be matched by name.
        if func_name_addr < 0x10000:
            return
        raw_name = idc.get_strlit_contents(func_name_addr)
        if raw_name is None:
            return
        func_name = raw_name.decode("utf-8", errors="replace")

        funcs = self.module_funcs.setdefault(module_name, [])
        if not any(f.get("name") == func_name for f in funcs):
            funcs.append({"name": func_name})
        self.getproc_func_ret_addrs[ret_addr] = {
            "module_name": module_name,
            "func_name": func_name,
            "debug_info": f"hook function exec: {hex(ret_addr)} => kernel32!",
        }
        idc.add_bpt(ret_addr)

    def _on_getprocaddress_return(self, bptea):
        """Log GetProcAddress's result and hook the returned function if it is traced."""
        record = self.getproc_func_ret_addrs.pop(bptea)
        module_name = record["module_name"]
        func_name = record["func_name"]
        target_func_addr = self.get_register_value("AX")
        if record["debug_info"] is not None:
            logging.debug(record["debug_info"]
                          + f"GetProcAddress({module_name},{func_name}) = {hex(target_func_addr)}")

        for func in self.module_funcs.get(module_name, []):
            if func.get("name") == func_name:
                func["va"] = target_func_addr
                break

        # GetProcAddress returns NULL on failure; never break at address 0.
        wanted = any(h["name"] == func_name for h in self.to_hook_func.get(module_name, []))
        if target_func_addr and wanted:
            self.break_points[target_func_addr] = "%s!%s" % (module_name, func_name)
            if not ida_dbg.exist_bpt(target_func_addr):
                logging.debug("hook:%s!%s => %s" % (module_name, func_name, hex(target_func_addr)))
                idc.add_bpt(target_func_addr)
        idc.del_bpt(bptea)

    def hit_general_hook_func(self, bptea):
        """Handle traced-function entry / return breakpoints; return True if handled."""
        if bptea in self.break_points:
            self._on_hooked_entry(bptea)
            return True
        if bptea in self.break_points_ret_addr:
            self._on_hooked_return(bptea)
            return True
        return False

    def _on_hooked_entry(self, bptea):
        """Build the call's log prefix and break on its return address."""
        qualified_name = self.break_points[bptea]
        module_name, func_name = qualified_name.split("!")
        ret_addr = self.get_func_args(0)
        args_number = self.get_args_number(bptea)
        debug_info = None  # nothing is logged on return when argc is unknown
        if args_number is not None:
            debug_info = f"hook function exec: {hex(ret_addr)} => {qualified_name}("
            formatter = self._custom_formatter(module_name, func_name)
            if formatter is not None:
                debug_info += formatter()
            else:
                # join() keeps "(" and yields "name()" for argc == 0.
                debug_info += ",".join(
                    hex(self.get_func_args(i)) for i in range(1, args_number + 1)
                ) + ")"

        self.break_points_ret_addr[ret_addr] = {
            "module_name": module_name,
            "func_name": func_name,
            "debug_info": debug_info,
        }
        idc.add_bpt(ret_addr)

    def _on_hooked_return(self, bptea):
        """Log the traced call's return value and remove the one-shot breakpoint."""
        record = self.break_points_ret_addr.pop(bptea)
        debug_info = record["debug_info"]
        if debug_info is not None:
            formatter = self._custom_formatter(record["module_name"], record["func_name"])
            if formatter is not None:
                debug_info += formatter(True)
            else:
                # Full RAX/EAX, not the 16-bit AX sub-register.
                debug_info += f"={hex(self.get_register_value('AX'))}"
            logging.debug(debug_info)
        idc.del_bpt(bptea)

    def dbg_bpt(self, tid, bptea):
        """Dispatch a breakpoint hit and resume unless the user should see it."""
        if self.is_hit_entry_point(bptea):
            self.dbg_get_imports()
            self.hook_funcs_action()
            handled = True
        else:
            handled = (self.hit_getprocaddress_or_ret_bp(bptea)
                       or self.hit_general_hook_func(bptea))

        # Foreign breakpoints stop the debugger unless ignore_bp is set.
        if handled or self.ignore_bp:
            ida_dbg.request_continue_process()
            ida_dbg.run_requests()
        return 0

    def dbg_exception(self, pid, tid, ea, exc_code, exc_can_cont, exc_ea, exc_info):
        logging.debug("Exception: pid=%d tid=%d ea=0x%x exc_code=0x%x can_continue=%d exc_ea=0x%x exc_info=%s" % (
            pid, tid, ea, exc_code & idaapi.BADADDR, exc_can_cont, exc_ea, exc_info))

    def dbg_get_imports(self):
        """Record the program's import table in ``module_funcs``.

        Each entry is ``{"ea": slot address, "ordinal": ..., "name": ...}``.
        "name" is omitted for imports by ordinal. Import descriptors that
        share a module name are merged instead of overwriting each other.
        """
        imports = {}
        for i in range(ida_nalt.get_import_module_qty()):
            raw_name = ida_nalt.get_import_module_name(i)
            if not raw_name:
                logging.warning("Failed to get import module name for #%d" % i)
                raw_name = "<unnamed>"
            entries = imports.setdefault(raw_name.lower(), [])

            def imp_cb(ea, name, ordinal, entries=entries):
                entry = {"ea": ea, "ordinal": ordinal}
                if name:
                    entry["name"] = name
                entries.append(entry)
                return True  # continue enumeration

            ida_nalt.enum_import_names(i, imp_cb)
        self.module_funcs.update(imports)

    def add_hook_functions(self, module_name, func_name, argc):
        """Register *module_name*!*func_name* (taking *argc* arguments) for tracing."""
        self.to_hook_func.setdefault(module_name, []).append({"name": func_name, "argc": argc})

if __name__ == "__main__":
    # IDA runs scripts in a persistent namespace. Without this, re-running the
    # script would leave the previous tracer's hooks installed next to the
    # new ones, and every debugger event would be handled twice.
    _previous_tracer = globals().get("fh")
    if isinstance(_previous_tracer, ida_dbg.DBG_Hooks):
        _previous_tracer.unhook()

    fh = FuncTracePlus(ignore_bp=False)
    # Functions to trace: (module name, function name, number of arguments).
    # fh.add_hook_functions("user32", "MessageBoxA", 4)
    fh.add_hook_functions("kernel32", "CreateFileA", 7)
    fh.add_hook_functions("kernel32", "CreateFileW", 7)
    fh.add_hook_functions("kernel32", "ReadFile", 5)
    fh.add_hook_functions("kernel32", "WriteFile", 5)
    fh.add_hook_functions("kernel32", "CloseHandle", 1)

    # fh.add_hook_functions("kernel32", "CreateThread", 6)
    # fh.add_hook_functions("kernel32", "CreateProcessA", 10)
    # fh.add_hook_functions("kernel32", "CreateProcessW", 10)
    # fh.add_hook_functions("kernel32", "VirtualAlloc", 4)
    # fh.add_hook_functions("kernel32", "VirtualFree", 3)
    # fh.add_hook_functions("kernel32", "VirtualProtect", 4)

    # fh.add_hook_functions("kernel32", "IsDebuggerPresent", 0)

    # Install the hooks last, once the function list is complete.
    fh.hook()
