import tkinter as tk
import tkinter.ttk as ttk
import hid
import struct
import time
import threading

# Per-read timeout, in milliseconds, used when waiting for a HID message.
DEFAULT_HID_TIMEOUT = 10

# USB vendor (Bigscreen) and product (Beyond) IDs used to open the headset.
BIGSCREEN_VID = 0x35bd
BEYOND_PID = 0x0101

# Single-byte command codes. A command is sent as the byte after report ID 0
# in a feature report; the reply carries the same code in its first byte, or
# HMD_ERROR on failure. Each code is an ASCII character, noted on the right.
HMD_GET_SW_VER = 0x2A        # '*'
HMD_GET_BOARD_SERIAL = 0x25  # '%'
HMD_GET_VXR_FIRMWARE = 0x4E  # 'N'
HMD_GET_OLED_ID = 0x5E       # '^'
HMD_GET_USAGE_TIMER = 0x5A   # 'Z'
HMD_GET_STACK_LEVELS = 0x73  # 's'
HMD_ERROR = 0x45             # 'E'

def wait_for_response(hmd_device: hid.Device, message_types:list, timeout_ms: int = 1000) -> bytes:
    """Read HID reports until one with a requested message type arrives.

    The first byte of every HID report (response packet) from the HMD is its
    message type. Reports whose type is not in ``message_types`` are read and
    discarded.

    Args:
        hmd_device: A ``hid.Device`` that is already connected to the HMD.
        message_types: Message-type codes (ints) to accept; the first report
            whose type matches any of them is returned.
        timeout_ms: Overall time budget in milliseconds. Default 1000.

    Returns:
        The first matching report as ``bytes``, or ``b''`` if the budget
        runs out before a matching report is received.
    """
    deadline_ns = time.monotonic_ns() + timeout_ms * 1000000
    while time.monotonic_ns() < deadline_ns:
        # Each read waits at most DEFAULT_HID_TIMEOUT ms, so the overall
        # deadline is re-checked regularly even when the device is quiet.
        report = hmd_device.read(65, timeout=DEFAULT_HID_TIMEOUT)
        if report and report[0] in message_types:
            return report
    return b''

def get_hmd_sw_ver(hmd_device: hid.Device) -> str:
    """Query the microcontroller firmware version string.

    Args:
        hmd_device: A ``hid.Device`` that is already connected to the HMD.

    Returns:
        The version as a str, or the string "error" if the HMD replied with
        HMD_ERROR, did not reply in time, or sent a payload that is not ASCII.
    """
    # Report ID 0 followed by the command byte.
    hmd_device.send_feature_report(bytes([0, HMD_GET_SW_VER]))

    response = wait_for_response(hmd_device, [HMD_GET_SW_VER, HMD_ERROR])
    if len(response) == 0 or response[0] == HMD_ERROR:
        return "error"

    # The text follows the message-type byte. Treat it as NUL-terminated so
    # that padding, or any bytes left after the terminator, are not displayed.
    payload = response[1:].split(b'\x00', 1)[0]
    try:
        return payload.decode('ascii')
    except UnicodeDecodeError:
        # A corrupt reply must not raise into the polling thread; report it
        # the same way as any other failure.
        return "error"

def get_stack_levels(hmd_device: hid.Device) -> list[int]:
    """Query the task stack high-watermark values.

    Each value is the number of words still unused in one task's stack;
    values closer to zero mean the task is closer to overflowing.

    Args:
        hmd_device: A ``hid.Device`` that is already connected to the HMD.

    Returns:
        A list of little-endian uint32 values in the order the HMD reports
        them, or ``[]`` on error or timeout.
    """
    # Report ID 0 followed by the command byte.
    hmd_device.send_feature_report(bytes([0, HMD_GET_STACK_LEVELS]))

    response = wait_for_response(hmd_device, [HMD_GET_STACK_LEVELS, HMD_ERROR])
    # Need at least the type byte and the length byte.
    if len(response) < 2 or response[0] == HMD_ERROR:
        return []

    # Byte 1 holds the payload length in bytes, and the payload starts at
    # byte 2. Clamp the length to what was actually received and drop any
    # trailing partial word, so a malformed length cannot make struct raise.
    lenbytes = min(response[1], len(response) - 2)
    lenwords = lenbytes // 4
    return list(struct.unpack_from('<' + 'I' * lenwords, response, 2))



# Module-level state shared by the GUI thread (main) and the HID polling
# thread (thread_worker).
window = tk.Tk()

# Configured stack size of each monitored task, in the same row order as the
# task names shown by main().
initial_stack_sizes = [
    1024, 512, 512, 512, 512, 512,
    512, 512, 512, 512, 280, 2800,
]

# One Tk integer variable per task, seeded with that task's initial size and
# later overwritten with live readings by the polling thread.
stack_levels = [tk.IntVar(master=window, value=size) for size in initial_stack_sizes]
sw_ver = tk.StringVar(master=window, value="Software Version:")

# Set to True by main() to ask the polling thread to exit.
quit_thread = False

def thread_worker():
    """Poll the HMD about twice a second and publish results to the GUI.

    Runs on the background thread started by main(). The loop exits when
    ``quit_thread`` becomes True, or when the Tk variables can no longer be
    updated because the tkinter main loop has stopped. The HID device is
    closed on every exit path.
    """
    bynd = hid.Device(vid=BIGSCREEN_VID, pid=BEYOND_PID)
    try:
        while not quit_thread:
            time.sleep(0.5)
            software_version = get_hmd_sw_ver(bynd)
            stack_level_list = get_stack_levels(bynd)
            try:
                sw_ver.set("Software Version: " + software_version)
                # zip() stops at the shorter sequence. A failed or short read
                # (e.g. []) therefore just skips this update; with manual
                # indexing it would raise IndexError and end monitoring.
                for level_var, words in zip(stack_levels, stack_level_list):
                    # 4x because the reported value is in words, not bytes.
                    level_var.set(4 * words)
            except (RuntimeError, tk.TclError):
                # tkinter raises these when the main loop is no longer
                # running or the window has been destroyed.
                return
    finally:
        bynd.close()

def main():
    """Build the stack-monitor window, start the HID polling thread, run the UI.

    Blocks until the tkinter main loop returns. Then it signals the polling
    thread to stop and waits for it to finish.
    """
    global quit_thread

    # Display names, in the same order as initial_stack_sizes / stack_levels.
    task_names = [  "USBLoad",
                    "HID Commands",
                    "HID Report",
                    "I2C",
                    "TIMED_START",
                    "ADC",
                    "LED",
                    "FAN",
                    "video_proc",
                    "USART",
                    "IDLE",
                    "Tmr Svc"]

    header_font = ('Helvetica',10,'bold')
    headers = ("Task Name", "Initial Stack Size", "Current Stack Remaining")
    for column, title in enumerate(headers):
        ttk.Label(master=window, text=title, font=header_font).grid(row=0, column=column)

    # One row per task, below the header: name, configured size, and the live
    # reading in a sunken frame.
    rows = zip(task_names, initial_stack_sizes, stack_levels)
    for row, (name, initial_size, level_var) in enumerate(rows, start=1):
        ttk.Label(master=window, text=name).grid(row=row, column=0)
        ttk.Label(master=window, text=str(initial_size)).grid(row=row, column=1)
        innerframe = ttk.Frame(master=window, relief=tk.SUNKEN, borderwidth=1)
        innerframe.grid(row=row, column=2)
        ttk.Label(master=innerframe, textvariable=level_var).pack()

    # A blank spacer row, then the software version line.
    footer_row = len(task_names) + 1
    ttk.Label(master=window, text="").grid(row=footer_row, column=0)
    swverlabel = ttk.Label(master=window, textvariable=sw_ver)
    swverlabel.grid(row=footer_row + 1, column=0, columnspan=2)

    quit_thread = False
    mthread = threading.Thread(target=thread_worker)
    mthread.start()

    try:
        window.mainloop()
    finally:
        # Signal and join the worker even if mainloop exits via an exception
        # (e.g. KeyboardInterrupt), so it is stopped promptly.
        quit_thread = True
        mthread.join()

if __name__ == '__main__':
    # Launch the GUI only when run as a script. Note that importing this
    # module still creates the Tk root window (module-level `window = tk.Tk()`).
    main()