import signal
from msvcrt import get_osfhandle
from ctypes.wintypes import *
from ctypes import *
import os
import psutil
import struct


# ULONG_PTR must be an unsigned integer exactly as wide as a pointer; pick
# whichever ctypes integer type matches ``c_void_p`` on this interpreter.
if sizeof(c_ulong) == sizeof(c_void_p):
    ULONG_PTR = c_ulong
elif sizeof(c_ulonglong) == sizeof(c_void_p):
    ULONG_PTR = c_ulonglong
else:
    # Fail with a clear message instead of a NameError on the aliases below.
    raise TypeError("no unsigned ctypes integer type matches the pointer size")
# Aliases used by the Win32 prototypes declared in this file.
DWORD_PTR = ULONG_PTR
SIZE_T = ULONG_PTR
PVOID = c_void_p

class PROC_THREAD_ATTRIBUTE_LIST(Structure):
    """Opaque stand-in for a process/thread attribute list.

    No fields are declared: this file only ever handles pointers to it, and
    the real storage is a byte buffer whose size is reported by
    ``InitializeProcThreadAttributeList``.
    """

class SECURITY_ATTRIBUTES(Structure):
    """Field-less placeholder for the security-attribute parameters.

    It exists only so the two ``POINTER(SECURITY_ATTRIBUTES)`` parameters of
    ``CreateProcessW`` can be typed; this file always passes ``None`` there.
    """

class STARTUPINFOW(Structure):
    """ctypes layout of the startup-info structure given to ``CreateProcessW``.

    Within this file only ``cb``, ``cbReserved2`` and ``lpReserved2`` are
    assigned; every other member keeps its zero-initialised default.
    """

    _fields_ = [
        ('cb', DWORD),  # set to the size of the enclosing extended structure
        ('lpReserved', LPWSTR),
        ('lpDesktop', LPWSTR),
        ('lpTitle', LPWSTR),
        ('dwX', DWORD),
        ('dwY', DWORD),
        ('dwXSize', DWORD),
        ('dwYSize', DWORD),
        ('dwXCountChars', DWORD),
        ('dwYCountChars', DWORD),
        ('dwFillAttributes', DWORD),
        ('dwFlags', DWORD),
        ('wShowWindow', WORD),
        ('cbReserved2', WORD),  # byte length of the block behind lpReserved2
        ('lpReserved2', LPBYTE),  # packed count / flag bytes / OS handles
        ('hStdInput', HANDLE),
        ('hStdOutput', HANDLE),
        ('hStdError', HANDLE),
    ]

class STARTUPINFOEXW(Structure):
    """Extended startup info: a ``STARTUPINFOW`` followed by an attribute list.

    ``StartupInfo`` is the first member, so a pointer to it is also a pointer
    to the start of this structure.
    """

    _fields_ = [
        ('StartupInfo', STARTUPINFOW),
        ('lpAttributeList', POINTER(PROC_THREAD_ATTRIBUTE_LIST)),
    ]

class PROCESS_INFORMATION(Structure):
    """Output structure filled in by ``CreateProcessW``.

    Holds the new process's and its initial thread's handles and identifiers.
    """

    _fields_ = [
        ('hProcess', HANDLE),
        ('hThread', HANDLE),
        ('dwProcessId', DWORD),
        ('dwThreadId', DWORD),
    ]


class UnbufferedProcess:
    """Launch a Windows child process with piped stdin and stdout.

    The child is created with ``CreateProcessW``. Two anonymous pipes are
    opened; their child-side OS handles are passed both in an explicit
    inheritable-handle list and in the ``lpReserved2`` startup block, tagged
    ``FOPEN | FDEV``.
    NOTE(review): presumably the ``FDEV`` tag makes the child's C runtime
    treat fds 0 and 1 as devices and so avoid block buffering (hence the
    class name) -- confirm against the CRT sources.

    Attributes:
        path: Full command line handed to ``CreateProcessW``.
        exe_name: Base name of the first whitespace-separated token of
            ``path``.
        process_info: ``PROCESS_INFORMATION`` of the running child, or
            ``None`` when no child is tracked.
        stdin: File descriptor to write to the child's stdin (``None`` until
            ``run`` succeeds).
        stdout: File descriptor to read the child's stdout from (``None``
            until ``run`` succeeds).
    """

    def __init__(self, path):
        """Store the command line after checking that its executable exists.

        Args:
            path: Command line to launch. Its first whitespace-separated
                token must be an existing path.

        Raises:
            AssertionError: If the first token of ``path`` does not exist.
        """
        exe_path = path.split()[0]
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is kept so
        # existing callers keep working.
        if not os.path.exists(exe_path):
            raise AssertionError(f"{exe_path} is not a valid exe")
        self.path = path
        self.exe_name = os.path.basename(exe_path)
        self.process_info = None
        self.stdin = None
        self.stdout = None

    def run(self):
        """Create the child process and connect the stdin/stdout pipes.

        On success ``self.stdin`` / ``self.stdout`` are the parent's ends of
        the pipes and ``self.process_info`` describes the new process. On
        failure every descriptor opened here is closed again and
        ``self.stdin`` / ``self.stdout`` are reset to ``None``.

        Raises:
            OSError: If initialising the attribute list, updating it, or
                ``CreateProcessW`` fails (built from the Win32 last-error
                code).
        """
        kernel32 = WinDLL('kernel32', use_last_error=True)
        CreateProcessW = kernel32.CreateProcessW
        CreateProcessW.argtypes = [LPCWSTR, LPWSTR,
                                   POINTER(SECURITY_ATTRIBUTES),
                                   POINTER(SECURITY_ATTRIBUTES), BOOL, DWORD,
                                   LPVOID, LPCWSTR, POINTER(STARTUPINFOW),
                                   POINTER(PROCESS_INFORMATION)]
        CreateProcessW.restype = BOOL
        # Kept on the instance because stop() needs it later.
        self.CloseHandle = kernel32.CloseHandle
        self.CloseHandle.argtypes = [HANDLE]
        self.CloseHandle.restype = BOOL
        InitializeProcThreadAttributeList = kernel32.InitializeProcThreadAttributeList
        InitializeProcThreadAttributeList.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST), DWORD, DWORD, POINTER(SIZE_T)]
        InitializeProcThreadAttributeList.restype = BOOL
        UpdateProcThreadAttribute = kernel32.UpdateProcThreadAttribute
        UpdateProcThreadAttribute.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST), DWORD, DWORD_PTR, PVOID,
            SIZE_T, PVOID, POINTER(SIZE_T)]
        UpdateProcThreadAttribute.restype = BOOL
        DeleteProcThreadAttributeList = kernel32.DeleteProcThreadAttributeList
        DeleteProcThreadAttributeList.argtypes = [
            POINTER(PROC_THREAD_ATTRIBUTE_LIST)]
        DeleteProcThreadAttributeList.restype = None

        PROC_THREAD_ATTRIBUTE_HANDLE_LIST = 0x20002
        EXTENDED_STARTUPINFO_PRESENT = 0x00080000
        FOPEN = 0x01
        FDEV = 0x40

        self.stdin_r, self.stdin = os.pipe()
        self.stdout, self.stdout_w = os.pipe()
        try:
            # Only the child's ends of the pipes may be inherited.
            os.set_inheritable(self.stdin_r, True)
            os.set_inheritable(self.stdout_w, True)
            stdin_handle = get_osfhandle(self.stdin_r)
            stdout_handle = get_osfhandle(self.stdout_w)

            # The first call is a size query only (it reports the required
            # byte count in ``size``); the second call initialises the
            # freshly allocated buffer.
            size = SIZE_T()
            InitializeProcThreadAttributeList(None, 1, 0, byref(size))
            attr_list = cast(create_string_buffer(size.value),
                             POINTER(PROC_THREAD_ATTRIBUTE_LIST))
            if not InitializeProcThreadAttributeList(attr_list, 1, 0,
                                                     byref(size)):
                raise WinError(get_last_error())
            try:
                handle_list = (HANDLE * 2)(stdin_handle, stdout_handle)
                if not UpdateProcThreadAttribute(
                        attr_list, 0, PROC_THREAD_ATTRIBUTE_HANDLE_LIST,
                        handle_list, sizeof(handle_list), None, None):
                    raise WinError(get_last_error())

                # Block passed through lpReserved2: a handle count, one flag
                # byte per handle, then the OS handles themselves.
                buf = struct.pack('I', 2)
                buf += 2 * struct.pack('B', FOPEN | FDEV)
                buf += struct.pack('PP', stdin_handle, stdout_handle)

                startup_info_ex = STARTUPINFOEXW()
                startup_info_ex.lpAttributeList = attr_list
                startup_info = startup_info_ex.StartupInfo
                # cb must cover the extended structure, not just the
                # embedded STARTUPINFOW.
                startup_info.cb = sizeof(startup_info_ex)
                startup_info.cbReserved2 = len(buf)
                startup_info.lpReserved2 = cast(buf, LPBYTE)

                # CreateProcessW may write to the command-line buffer, so it
                # must not point at an immutable Python string.
                command_line = create_unicode_buffer(self.path)
                process_info = PROCESS_INFORMATION()
                if not CreateProcessW(None, command_line, None, None, True,
                                      EXTENDED_STARTUPINFO_PRESENT, None,
                                      None, byref(startup_info),
                                      byref(process_info)):
                    raise WinError(get_last_error())
            finally:
                DeleteProcThreadAttributeList(attr_list)
        except BaseException:
            # Launch failed: do not leak the parent's ends of the pipes.
            os.close(self.stdin)
            os.close(self.stdout)
            self.stdin = None
            self.stdout = None
            raise
        finally:
            # The parent never uses the child's ends; close them so the
            # pipes report EOF once the child goes away.
            os.close(self.stdin_r)
            os.close(self.stdout_w)
        self.process_info = process_info

    def stop(self):
        """Terminate the child and release its process and thread handles.

        Does nothing when no child is tracked (``run`` was never called, it
        failed, or ``stop`` already ran). ``self.stdin`` and ``self.stdout``
        are not closed here; they stay open for the caller.

        Raises:
            OSError: Whatever ``os.kill`` raises; the handles are still
                closed and ``process_info`` is still reset in that case.
        """
        info = self.process_info
        if info is None:
            return
        self.process_info = None
        try:
            os.kill(info.dwProcessId, signal.SIGTERM)
        finally:
            self.CloseHandle(info.hThread)
            self.CloseHandle(info.hProcess)
