import subprocess
import time
from functools import partial
import tkinter as tk
from tkinter import ttk
from tkinter import PhotoImage, BitmapImage
from tkinter.messagebox import showerror
import sys
import os
import logging
from threading import Thread
from PIL import Image, ImageTk

# Detect import error for HID - usually if hidapi.dll is not found
try:
    import hid
    from hid import HIDException
except ImportError as err:
    showerror(title="Failed to start", message="Could not import \"hid\" Python library. Is hidapi.dll installed?\n\n"
              f"Error message:\n{err}")
    # Nothing below can work without hid (tkgui uses hid.Device and
    # HIDException), so stop here instead of failing later with a NameError.
    sys.exit(1)
    
class tkgui(tk.Frame):
    """Test GUI for displaying patterns/images on the Beyond headset.

    Shows one thumbnail button per image in ``test-images/`` plus a
    "Vertical Stripe Test" and a "Stop Display" button. A display button
    enables direct mode, switches the headset video mode over USB HID and
    runs ``direct_mode_dx12.exe`` on a worker thread until Stop is pressed.
    """

    def __init__(self, parent, image_files):
        """Build and pack the widgets.

        :param parent: Tk container this frame is packed into.
        :param image_files: file names, relative to ``test-images/``, shown as
            128x128 thumbnail buttons laid out four per row.
        """
        super().__init__(parent)

        self.frame_images = ttk.Frame(self, relief=tk.GROOVE, borderwidth=6)
        self.frame_vs_test = ttk.Frame(self, relief=tk.GROOVE, borderwidth=6)
        self.frame_stop = ttk.Frame(self, relief=tk.GROOVE, borderwidth=6)

        label_images = ttk.Label(self, text = "Images")

        self.btn_vs_test = ttk.Button(self.frame_vs_test, text="Vertical Stripe Test", command=self.vert_stripe_test)
        self.btn_stop = ttk.Button(self.frame_stop, text="Stop Display", command=self.vid_stop)

        label_images.pack()
        self.frame_images.pack(fill=tk.BOTH, padx=2, pady=2)
        self.frame_vs_test.pack(fill=tk.BOTH, padx=25, pady=2)
        self.frame_stop.pack(fill=tk.BOTH, padx=25, pady=2)

        self.btn_vs_test.pack(fill=tk.BOTH, padx=10, pady=2)
        self.btn_stop.pack(fill=tk.BOTH, padx=10, pady=2)

        # create image display buttons
        self.img_buttons = []
        for img in image_files:
            image_path = os.path.abspath(os.path.join('test-images', img))
            # load file in PIL and shrink it to a 128x128 thumbnail
            with Image.open(image_path) as im:
                thumb = im.resize((128, 128), Image.Resampling.BILINEAR)
                imgphoto = ImageTk.PhotoImage(thumb)

            newbtn = ttk.Button(self.frame_images, image=imgphoto,
                                command=partial(self.show_image, image_path))
            # keep a reference, otherwise the PhotoImage is garbage-collected
            # and the button shows a blank image
            newbtn.image = imgphoto
            self.img_buttons.append(newbtn)

        # lay the thumbnails out in a grid, four per row
        for num, btn in enumerate(self.img_buttons):
            rw, col = divmod(num, 4)
            btn.grid(column=col, row=rw, padx=2, pady=2)
        for col in range(4):
            self.frame_images.columnconfigure(col, weight=1)

        self.pack(fill=tk.BOTH)

        # Placeholder that is never started, so is_alive() is False until a
        # display thread has been launched.
        self.displaythread = Thread()
        # Set by vid_stop(); polled by the display thread to end the display.
        self.quitnow = False

    def thread_cleanup(self):
        """Block until the previous display thread (if any) has finished."""
        if self.displaythread.is_alive():
            self.displaythread.join()

    def change_button_state(self, is_running):
        """Disable the display buttons while a display runs; re-enable after.

        The Stop button is never touched so it stays usable.
        """
        # tk.NORMAL is the value that re-enables a widget; tk.ACTIVE would
        # additionally set the ttk "active" (hover) state flag.
        new_state = tk.DISABLED if is_running else tk.NORMAL
        self.btn_vs_test["state"] = new_state
        for btn in self.img_buttons:
            btn["state"] = new_state

    def open_beyond(self) -> "hid.Device | None":
        """Open the Beyond's USB HID interface (VID 0x35bd, PID 0x0101).

        :returns: the open device, or None after showing an error dialog when
            it cannot be opened. The caller owns the handle and should close it.
        """
        try:
            bigs = hid.Device(vid=0x35bd, pid=0x0101)
        except HIDException as hiderr:
            # NOTE(review): this runs on the display thread; Tk dialogs are
            # normally shown from the main thread.
            showerror("Could not connect", "Could not connect to Beyond through USB.\nPlease check all connections.\n\n"
                      f"Error message:\n{hiderr}")
            # logging.error(f"Could not connect to Beyond through USB. Error message: {hiderr}")
            return None

        return bigs

    # def vid_60_1920(self):
    #     self.quitnow = False
    #     print('switching to 60Hz, 1920')
    #     self.change_button_state(True)
    #     self.thread_cleanup()
    #     self.displaythread = Thread(target=self.runnable_60_1920)
    #     self.displaythread.start()

    def vert_stripe_test(self):
        """Button handler: show the vertical stripe pattern on the headset."""
        print('Displaying vertical stripe pattern')
        self._start_display_thread(self.runnable_vert_stripe_test)

    def show_image(self, filename):
        """Button handler: show image *filename* on the headset."""
        print(f'Displaying "{filename}"')
        self._start_display_thread(self.runnable_show_image, filename)

    def _start_display_thread(self, target, *args):
        """Enable direct mode, then run target(*args) on a new display thread.

        Nothing is started if enabling direct mode fails.
        """
        self.quitnow = False
        if self.enable_direct_mode():
            self.change_button_state(True)
            self.thread_cleanup()
            self.displaythread = Thread(target=target, args=args)
            self.displaythread.start()

    # def vid_75_2544(self):
    #     self.quitnow = False
    #     print('switching to 75Hz, 2544')
    #     if self.enable_direct_mode():
    #         self.change_button_state(True)
    #         self.thread_cleanup()
    #         self.displaythread = Thread(target=self.runnable_75_2544)
    #         self.displaythread.start()

    # def vid_90_1920(self):
    #     self.quitnow = False
    #     print('switching to 90Hz, 1920')
    #     if self.enable_direct_mode():
    #         self.change_button_state(True)
    #         self.thread_cleanup()
    #         self.displaythread = Thread(target=self.runnable_90_1920)
    #         self.displaythread.start()

    def vid_stop(self):
        """Button handler: ask the running display thread to stop."""
        self.quitnow = True
        print('stopping video')

    def enable_direct_mode(self):
        """Run ``direct_mode_dx12.exe -enable``.

        :returns: False (after an error dialog) if the tool cannot be launched,
            otherwise True. The tool's output is currently not validated.
        """
        cmd = ["direct_mode_dx12.exe", "-enable"]
        try:
            r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as oserr:
            showerror("Direct Mode problem", f"Could not run direct_mode_dx12.exe:\n{oserr}")
            return False
    #    success = False
    #    if "NVIDIA DirectMode VR ENABLED successfully" in r.stdout.decode():
    #        success = True
    #    if(not success):
    #        showerror("Direct Mode problem", f"Error from direct_mode_dx12.exe: {r.stderr.decode()}")
    #        return False
        return True

    def disable_direct_mode(self):
        """Run ``direct_mode_dx12.exe -disable``.

        :returns: False (after an error dialog) if the tool cannot be launched,
            otherwise True. The tool's output is currently not validated.
        """
        cmd = ["direct_mode_dx12.exe", "-disable"]
        try:
            r = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as oserr:
            showerror("Direct Mode problem", f"Could not run direct_mode_dx12.exe:\n{oserr}")
            return False
    #    success = False
    #    if "NVIDIA DirectMode VR DISABLED successfully" in r.stdout.decode():
    #        success = True
    #    if(not success):
    #        showerror("Direct Mode problem", f"Error from direct_mode_dx12.exe: {r.stderr.decode()}")
    #        return False
        return True

    # def runnable_60_1920(self):
    #     # enters FATP mode with 1920x1920 @60Hz
    #     # starts a direct mode display with the vertical stripe pattern
    #     # waits for user to press stop

    #     bigs = self.open_beyond()
    #     if(bigs is None):
    #         self.change_button_state(False)
    #     else:
    #         bigs.send_feature_report(bytes([0, ord('@'), 1])) # FATP mode 1 - just 1920x1920 60Hz, no DSC
    #         # note - this resolution doesn't allow for Direct Mode
    #         # so after the reset, it will simply appear as a desktop monitor

    #         # wait a moment for the VXR changes to take effect
    #         if not self.delay_with_quit(3):
    #              # early quit occurred
    #             self.change_button_state(False)
    #             return

    #         self.change_button_state(False)

    def runnable_vert_stripe_test(self):
        """Display-thread body: 2544x2544 @60Hz vertical stripe pattern until Stop."""
        self._run_display(self.start_vspattern_with_dsc)

    def runnable_show_image(self, filename):
        """Display-thread body: show *filename* at 2544x2544 @60Hz until Stop."""
        self._run_display(self.start_show_image, filename)

    def _enter_fatp_mode(self) -> bool:
        """Switch the headset to FATP mode 2 (2544x2544 60Hz, DSC required).

        The HID handle is closed as soon as the report has been sent.

        :returns: False if the headset could not be opened, True otherwise.
        """
        bigs = self.open_beyond()
        print('Entering 2544x2544 60Hz video mode')
        if bigs is None:
            return False
        try:
            bigs.send_feature_report(bytes([0, ord('@'), 2]))  # FATP mode 2 - just 2544x2544 60Hz, DSC required
        finally:
            bigs.close()
        return True

    def _run_display(self, start_proc, *args):
        """Shared display-thread body.

        Enters FATP mode 2, waits 3 s for it to settle, starts the display
        process via start_proc(*args) and keeps it running until quitnow is
        set. The buttons are re-enabled on every exit path, including when an
        exception propagates.
        """
        proc = None
        try:
            if not self._enter_fatp_mode():
                return
            # wait a moment for the VXR changes to take effect
            if not self.delay_with_quit(3):
                return  # early quit occurred
            proc = start_proc(*args)
            while not self.quitnow:
                time.sleep(0.005)  # keep this thread from overusing the cpu
        finally:
            if proc is not None:
                # stops the process and re-enables the buttons
                self.stop_direct_mode_process(proc)
            else:
                self.change_button_state(False)

    # def runnable_75_2544(self):
    #     # enters FATP mode with 2544x2544 @75Hz
    #     # starts a direct mode display with the vertical stripe pattern
    #     # waits for user to press stop

    #     bigs = self.open_beyond()
    #     if(bigs is None):
    #         self.change_button_state(False)
    #     else:
    #         bigs.send_feature_report(bytes([0, ord('d'), 2])) # EDID switch 2 - just 2544x2544 75Hz

    #         # wait a moment for the VXR changes to take effect
    #         if not self.delay_with_quit(3):
    #              # early quit occurred
    #             self.change_button_state(False)
    #             return

    #         proc = self.start_vspattern_with_dsc()

    #         while not self.quitnow:
    #             time.sleep(0.005) # keep this process from overusing the cpu

    #         self.stop_direct_mode_process(proc)

    # def runnable_90_1920(self):
    #     # enters FATP mode with 1920x1920 @90Hz
    #     # starts a direct mode display with the vertical stripe pattern
    #     # waits for user to press stop

    #     bigs = self.open_beyond()
    #     if(bigs is None):
    #         self.change_button_state(False)
    #     else:
    #         bigs.send_feature_report(bytes([0, ord('d'), 1])) # EDID switch 1 - just 1920x1920 90Hz

    #         # wait a moment for the VXR changes to take effect
    #         if not self.delay_with_quit(3):
    #              # early quit occurred
    #             self.change_button_state(False)
    #             return

    #         proc = self.start_vspattern_with_dsc()

    #         while not self.quitnow:
    #             time.sleep(0.005) # keep this process from overusing the cpu

    #         self.stop_direct_mode_process(proc)

    def start_vspattern_with_dsc(self) -> subprocess.Popen:
        """Launch ``direct_mode_dx12.exe -vspattern`` with piped stdio."""
        # NOTE(review): stdout/stderr are piped but only drained at stop time;
        # if the tool prints a lot while running it could block on a full pipe.
        cmd = ["direct_mode_dx12.exe", "-vspattern"]
        return subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def start_show_image(self, filename) -> subprocess.Popen:
        """Launch ``direct_mode_dx12.exe -image <filename>`` with piped stdio."""
        cmd = ["direct_mode_dx12.exe", "-image", filename]
        return subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def stop_direct_mode_process(self, popen_obj: subprocess.Popen):
        """Stop a display process and re-enable the buttons.

        First asks the process to quit by typing "q<enter>" on its stdin. If it
        has not exited within 1 s it is terminated, and killed as a last
        resort. The buttons are re-enabled even if stopping raises.
        """
        try:
            popen_obj.communicate('q\r\n'.encode(), timeout=1)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout; do it here and
            # finish communicating so the process is reaped and pipes closed.
            popen_obj.terminate()
            try:
                popen_obj.communicate(timeout=1)
            except subprocess.TimeoutExpired:
                popen_obj.kill()
                popen_obj.communicate()
        finally:
            self.change_button_state(False)

    def delay_with_quit(self, timeout_seconds):
        """Sleep for *timeout_seconds*, returning early if quitnow is set.

        quitnow may be set from any thread (e.g. by vid_stop).

        :returns: True if the full timeout elapsed, False if quitnow was set.
        """
        start_time = time.monotonic_ns()

        while (timeout_seconds * 1e9) > (time.monotonic_ns() - start_time):
            time.sleep(0.005)  # keep this thread from overusing the cpu
            if self.quitnow:
                return False
        return True

def supported_image_file(filename: str) -> bool:
    """Return True if *filename* has one of the accepted image extensions.

    The comparison is case-insensitive, so e.g. ``PHOTO.JPG`` is accepted
    (upper-case extensions are common on Windows).
    """
    ext = os.path.splitext(filename)[1].lower()
    return ext in {'.bmp', '.gif', '.jpg', '.png', '.tif', '.tiff'}

if __name__ == '__main__':
    # collect the supported image files in test-images/, sorted so the
    # thumbnail order is the same on every run
    image_files = sorted(
        name for name in os.listdir(r'test-images')
        if os.path.isfile(os.path.join('test-images', name)) and supported_image_file(name)
    )

    tkroot = tk.Tk()
    tkroot.title('Beyond Display Test 60Hz v2')
    # size the window to fit the thumbnails: 4 per row, 130 px per row,
    # and at least one row even when there are no images
    rows = max(1, -(-len(image_files) // 4))
    vert_size = 200 + 130 * rows
    tkroot.geometry(f'640x{vert_size}')
    mygui = tkgui(tkroot, image_files)

    def wait_then_close():
        """Destroy the window once the display thread has finished."""
        if mygui.displaythread.is_alive():
            # Poll instead of join(): the worker still calls into Tk
            # (change_button_state), so blocking the event loop could deadlock.
            tkroot.after(50, wait_then_close)
        else:
            tkroot.destroy()

    def on_close():
        """Stop any running display before closing, so the non-daemon worker
        thread and direct_mode_dx12.exe are not left running."""
        mygui.vid_stop()
        wait_then_close()

    tkroot.protocol("WM_DELETE_WINDOW", on_close)
    tkroot.mainloop()