import signal
from msvcrt import get_osfhandle
from ctypes.wintypes import *
from ctypes import *
import os
import psutil
import struct


# Pointer-sized unsigned integer aliases used by the Win32 prototypes.
# c_ulong is only pointer-sized on 32-bit Windows (it stays 32 bits on 64-bit
# Windows), so fall back to c_ulonglong when it is too narrow.
if sizeof(c_ulong) == sizeof(c_void_p):
    ULONG_PTR = c_ulong
elif sizeof(c_ulonglong) == sizeof(c_void_p):
    ULONG_PTR = c_ulonglong
else:
    # Without this branch ULONG_PTR would be unbound and the alias lines below
    # would fail with an unhelpful NameError.
    raise ImportError(
        f"no unsigned integer type matches the pointer size ({sizeof(c_void_p)} bytes)"
    )
DWORD_PTR = ULONG_PTR
SIZE_T = ULONG_PTR
PVOID = c_void_p


class PROC_THREAD_ATTRIBUTE_LIST(Structure):
    """Opaque Win32 ``PROC_THREAD_ATTRIBUTE_LIST``.

    Intentionally declared without fields: it is only ever used through a
    pointer into a raw buffer whose size ``InitializeProcThreadAttributeList``
    reports.
    """


class SECURITY_ATTRIBUTES(Structure):
    """ctypes mirror of the Win32 ``SECURITY_ATTRIBUTES`` structure.

    This module passes ``None`` (NULL) wherever a pointer to this structure
    is expected. The real layout is still declared so that a caller passing an
    actual instance hands Windows a correctly sized structure rather than a
    zero-byte one. ``nLength`` must be set to ``sizeof(SECURITY_ATTRIBUTES)``.
    """

    _fields_ = [
        ("nLength", DWORD),
        ("lpSecurityDescriptor", LPVOID),
        ("bInheritHandle", BOOL),
    ]


class STARTUPINFOW(Structure):
    """ctypes mirror of the Win32 ``STARTUPINFOW`` structure.

    Windows reads this structure by offset, so field order and types must
    match the SDK layout exactly. The field the SDK calls ``dwFillAttribute``
    is named ``dwFillAttributes`` here; only the name differs.
    """

    _fields_ = [
        # Byte size of the structure being passed (set by the caller).
        ("cb", DWORD),
        *((name, LPWSTR) for name in ("lpReserved", "lpDesktop", "lpTitle")),
        *(
            (name, DWORD)
            for name in (
                "dwX",
                "dwY",
                "dwXSize",
                "dwYSize",
                "dwXCountChars",
                "dwYCountChars",
                "dwFillAttributes",
                "dwFlags",
            )
        ),
        ("wShowWindow", WORD),
        # Reserved for the C runtime; this module uses them to hand the child
        # an fd table (byte count + pointer to the blob).
        ("cbReserved2", WORD),
        ("lpReserved2", LPBYTE),
        *((name, HANDLE) for name in ("hStdInput", "hStdOutput", "hStdError")),
    ]


class STARTUPINFOEXW(Structure):
    """ctypes mirror of the Win32 ``STARTUPINFOEXW`` structure.

    Used with the ``EXTENDED_STARTUPINFO_PRESENT`` creation flag; in that case
    ``StartupInfo.cb`` must hold ``sizeof(STARTUPINFOEXW)``.
    """

    _fields_ = [
        # Embedded by value, so it sits at offset 0 of this structure.
        ("StartupInfo", STARTUPINFOW),
        # Opaque attribute list (e.g. the inheritable-handle whitelist).
        ("lpAttributeList", POINTER(PROC_THREAD_ATTRIBUTE_LIST)),
    ]


class PROCESS_INFORMATION(Structure):
    """ctypes mirror of Win32 ``PROCESS_INFORMATION``, filled in by ``CreateProcessW``.

    The two handles belong to the caller and are released with
    ``CloseHandle``; the ids identify the new process and its primary thread.
    """

    _fields_ = [(name, HANDLE) for name in ("hProcess", "hThread")] + [
        (name, DWORD) for name in ("dwProcessId", "dwThreadId")
    ]


class UnbufferedProcess:
    """Run a Windows executable with its stdin/stdout connected to pipes.

    The child receives the pipe ends as C-runtime fds 0 and 1 through the
    ``STARTUPINFO.lpReserved2`` fd table, each flagged ``FOPEN | FDEV``.
    Flagging them as devices presumably makes an MSVCRT-based child treat
    stdout as interactive and avoid full buffering (hence the class name);
    TODO confirm against the child's runtime. Handle inheritance is restricted
    to exactly those two pipe ends via ``PROC_THREAD_ATTRIBUTE_HANDLE_LIST``.

    After :meth:`run` succeeds, ``self.stdin`` is the parent's write end and
    ``self.stdout`` the parent's read end. Both are raw OS file descriptors
    from :func:`os.pipe`, owned by the caller.

    Attributes:
        path: Full command line. Its first whitespace-separated token must be
            an existing file, so executable paths containing spaces are not
            supported.
        exe_name: Basename of that first token.
        process_info: ``PROCESS_INFORMATION`` of the running child, or
            ``None`` when no child is running.
    """

    # Win32 / C-runtime constants used when spawning the child.
    _PROC_THREAD_ATTRIBUTE_HANDLE_LIST = 0x20002
    _EXTENDED_STARTUPINFO_PRESENT = 0x00080000
    _FOPEN = 0x01  # CRT fd-table flag: fd is open
    _FDEV = 0x40  # CRT fd-table flag: fd refers to a device

    def __init__(self, path):
        """Validate *path* and record it; nothing is started until :meth:`run`."""
        exe = path.split()[0]
        assert os.path.exists(exe), f"{exe} is not a valid exe"
        self.path = path
        self.exe_name = os.path.basename(exe)
        self.process_info = None
        self.stdin = None
        self.stdout = None

    @staticmethod
    def _load_kernel32():
        """Return kernel32 with prototypes declared for every call used here."""
        kernel32 = WinDLL("kernel32", use_last_error=True)

        kernel32.CreateProcessW.argtypes = [
            LPCWSTR,
            LPWSTR,
            POINTER(SECURITY_ATTRIBUTES),
            POINTER(SECURITY_ATTRIBUTES),
            BOOL,
            DWORD,
            LPVOID,
            LPCWSTR,
            POINTER(STARTUPINFOW),
            POINTER(PROCESS_INFORMATION),
        ]
        kernel32.CreateProcessW.restype = BOOL

        kernel32.CloseHandle.argtypes = [HANDLE]
        kernel32.CloseHandle.restype = BOOL

        kernel32.InitializeProcThreadAttributeList.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST),
            DWORD,
            DWORD,
            POINTER(SIZE_T),
        ]
        kernel32.InitializeProcThreadAttributeList.restype = BOOL

        kernel32.UpdateProcThreadAttribute.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST),
            DWORD,
            DWORD_PTR,
            PVOID,
            SIZE_T,
            PVOID,
            POINTER(SIZE_T),
        ]
        kernel32.UpdateProcThreadAttribute.restype = BOOL

        kernel32.DeleteProcThreadAttributeList.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST)
        ]
        kernel32.DeleteProcThreadAttributeList.restype = None
        return kernel32

    @staticmethod
    def _win_check(ok):
        """Raise the calling thread's last Win32 error as ``OSError`` if *ok* is falsy."""
        if not ok:
            raise WinError(get_last_error())

    @classmethod
    def _crt_fd_table(cls, handles):
        """Encode *handles* as the C-runtime fd table stored in ``lpReserved2``.

        Layout, packed with no padding between the parts: an unsigned int
        count, then one flag byte per fd, then one pointer-sized OS handle per
        fd. Child fd ``i`` maps to ``handles[i]``.
        """
        count = len(handles)
        flags = cls._FOPEN | cls._FDEV
        return (
            struct.pack("I", count)
            + bytes([flags]) * count
            + struct.pack(f"{count}P", *handles)
        )

    def _spawn(self, kernel32, handles):
        """Create the child with *handles* inherited as its fds 0..n-1.

        Returns:
            The filled-in ``PROCESS_INFORMATION``.

        Raises:
            OSError: if any Win32 call fails.
        """
        size = SIZE_T()
        # Size query only: this call is expected to fail while reporting the
        # required buffer size, so its result is deliberately not checked.
        kernel32.InitializeProcThreadAttributeList(None, 1, 0, byref(size))
        attr_buf = create_string_buffer(size.value)
        attr_list = cast(attr_buf, POINTER(PROC_THREAD_ATTRIBUTE_LIST))
        # The second call actually initializes the buffer.
        self._win_check(
            kernel32.InitializeProcThreadAttributeList(attr_list, 1, 0, byref(size))
        )
        try:
            # Must outlive CreateProcessW: the attribute list keeps a pointer
            # to this array rather than copying it.
            handle_list = (HANDLE * len(handles))(*handles)
            self._win_check(
                kernel32.UpdateProcThreadAttribute(
                    attr_list,
                    0,
                    self._PROC_THREAD_ATTRIBUTE_HANDLE_LIST,
                    handle_list,
                    sizeof(handle_list),
                    None,
                    None,
                )
            )

            # Kept referenced until CreateProcessW returns, because
            # lpReserved2 points into it.
            fd_table = self._crt_fd_table(handles)

            startup_info_ex = STARTUPINFOEXW()
            startup_info_ex.lpAttributeList = attr_list
            startup_info = startup_info_ex.StartupInfo  # view into startup_info_ex
            startup_info.cb = sizeof(startup_info_ex)
            startup_info.cbReserved2 = len(fd_table)
            startup_info.lpReserved2 = cast(fd_table, LPBYTE)

            # CreateProcessW may write into lpCommandLine, so pass a mutable buffer.
            command_line = create_unicode_buffer(self.path)
            process_info = PROCESS_INFORMATION()
            self._win_check(
                kernel32.CreateProcessW(
                    None,
                    command_line,
                    None,
                    None,
                    True,
                    self._EXTENDED_STARTUPINFO_PRESENT,
                    None,
                    None,
                    byref(startup_info),
                    byref(process_info),
                )
            )
            return process_info
        finally:
            kernel32.DeleteProcThreadAttributeList(attr_list)

    def run(self):
        """Start the child process with fds 0/1 bound to freshly created pipes.

        Raises:
            OSError: if a Win32 call fails. In that case every pipe fd created
                here is closed again, and ``self.stdin``/``self.stdout`` are
                reset to ``None``.
        """
        kernel32 = self._load_kernel32()
        # Kept on the instance because stop() uses it.
        self.CloseHandle = kernel32.CloseHandle

        self.stdin_r, self.stdin = os.pipe()
        self.stdout, self.stdout_w = os.pipe()
        try:
            # os.pipe() creates non-inheritable fds; only the child's ends
            # are made inheritable.
            os.set_inheritable(self.stdin_r, True)
            os.set_inheritable(self.stdout_w, True)
            handles = [get_osfhandle(self.stdin_r), get_osfhandle(self.stdout_w)]
            self.process_info = self._spawn(kernel32, handles)
        except BaseException:
            os.close(self.stdin)
            os.close(self.stdout)
            self.stdin = None
            self.stdout = None
            raise
        finally:
            # The parent must drop its copies of the child's ends. Otherwise
            # the pipes never signal EOF once the child's side closes.
            os.close(self.stdin_r)
            os.close(self.stdout_w)

    def stop(self):
        """Terminate the child, if one is running, and release its handles.

        On Windows, ``os.kill`` with ``SIGTERM`` terminates the process
        unconditionally. The process and thread handles are closed even if
        ``os.kill`` raises; the error is then re-raised. Calling ``stop`` when
        no child is running does nothing. The pipe fds in ``self.stdin`` and
        ``self.stdout`` are left for the caller to close.
        """
        process_info = self.process_info
        if process_info is None:
            return
        self.process_info = None
        try:
            os.kill(process_info.dwProcessId, signal.SIGTERM)
        finally:
            self.CloseHandle(process_info.hThread)
            self.CloseHandle(process_info.hProcess)
