import hid
import json
import struct
import tkinter as tk
#import tkinter.ttk as ttk
import tkinter.filedialog
from imgui_bundle import imgui, immapp
from enum import IntEnum
import time
#from threading import Thread
#from queue import Queue

# USB vendor / product IDs used to open the headset's HID interface
BIGSCREEN_VID = 0x35BD
BEYOND_PID = 0x0101
# size in bytes of the user signature (config) block: 16 pages of 32 bytes
USER_SIG_LENGTH = 512
# CRC-8 parameters used for the per-entry CRC byte in the signature block
CRC_INIT_VAL = 0xFF
CRC_POLY = 0x07
# value passed as `timeout=` to each hid.Device.read poll
# (presumably milliseconds — confirm against the hid package docs)
DEFAULT_HID_TIMEOUT = 10

# First byte of the feature reports exchanged with the headset.
CMD_GET_CONFIG = ord('U') # followed by a number 0 to 15 for the 32-byte page
# sent with a page number followed by that page's 32 bytes of data
CMD_WRITE_CONFIG = ord('W')
# sent alone after all pages have been written
CMD_SAVE_CONFIG = ord('V')
# reply types: the headset rejected a command / acknowledged it
CMD_ERROR = ord('E')
CMD_ACK = ord('$')


class SigTag(IntEnum):
    """Tag byte identifying each field stored in the signature block."""
    # 0xFF marks the end of the stored entries (the unused tail of the block
    # is padded with 0xFF)
    Invalid = 0xFF
    # PCB serial number; a valid backup's value starts with b'XCNL'
    Serial = 0x01
    RGB_Color = 0x02
    Fan_Speed = 0x03
    Prox_Disable = 0x04
    Linkbox_v1 = 0x05
    Prox_Cal = 0x06
    FATP_Mode = 0x07
    # shown to users as 'Serial Number'
    HMD_Serial = 0x08
    Tracking_Serial = 0x09
    Brightness = 0x0A
    Prox_Thresh = 0x0B
    Prox_Hyst = 0x0C
    # shown to users as 'Framerate Setting'
    EDID_Switch = 0x0D

def tag_to_string(tagval: SigTag) -> str:
    """Return the user-facing label for a signature tag.

    Only the tags that should be shown to users have a label.  Any other tag
    value maps to 'unknown', which callers use to hide that field.
    """
    labels = {
        SigTag.HMD_Serial: 'Serial Number',
        SigTag.RGB_Color: 'RGB Color',
        SigTag.Fan_Speed: 'Fan Speed',
        SigTag.Brightness: 'Display Brightness',
        SigTag.Prox_Cal: 'Proximity Zero Cal',
        SigTag.Prox_Thresh: 'Proximity Threshold',
        SigTag.EDID_Switch: 'Framerate Setting',
    }
    return labels.get(tagval, 'unknown')

def value_to_string(tagval: SigTag, val: bytes) -> str:
    """Format the raw value bytes of a signature tag for display.

    Args:
        tagval: the tag (a SigTag member or its int value) the bytes belong to.
        val: the raw value bytes stored for that tag.

    Returns:
        A human-readable string.  Returns 'invalid' if ``val`` is too short or
        otherwise cannot be decoded for that tag.  Returns 'unknown' for tags
        this function does not format and for unrecognised framerate settings.
    """
    try:
        if tagval == SigTag.HMD_Serial:
            return val.decode('utf-8', errors='ignore')

        if tagval == SigTag.RGB_Color:
            r, g, b = struct.unpack('<BBB', val[0:3])
            return f'R:{r}, G:{g}, B:{b}'

        if tagval == SigTag.Fan_Speed:
            (fanspeed,) = struct.unpack('<B', val[0:1])
            return f'{fanspeed}%'

        if tagval == SigTag.Brightness:
            (rawbright,) = struct.unpack('<H', val[0:2])
            brightpct = get_brightness_percent_from_raw(rawbright)
            return str(int(round(brightpct, 0)))

        if tagval == SigTag.Prox_Cal:
            (proxcal,) = struct.unpack('<H', val[0:2])
            return str(proxcal)

        if tagval == SigTag.Prox_Thresh:
            (proxthresh,) = struct.unpack('<H', val[0:2])
            return str(proxthresh)

        if tagval == SigTag.EDID_Switch:
            (swval,) = struct.unpack('<B', val[0:1])
            # both 0 and 2 select 75Hz
            if swval in (0, 2):
                return '75Hz'
            if swval == 1:
                return '90Hz'
            return 'unknown'

    except (struct.error, TypeError, ValueError, AttributeError):
        # struct.error: value shorter than the field; the others: val is not bytes
        return 'invalid'

    # tag not handled above: keep the `-> str` contract instead of returning None
    return 'unknown'
    
def crc8(input_data: bytes, init=None, poly=None) -> int:
    """Compute an 8-bit CRC over ``input_data``.

    The algorithm processes the most significant bit first, with no
    input/output reflection and no final XOR.

    Args:
        input_data: bytes to checksum (any iterable of ints 0-255 works).
        init: starting CRC value; defaults to CRC_INIT_VAL (0xFF).
        poly: generator polynomial without the implicit x^8 term; defaults to
            CRC_POLY (0x07).

    Returns:
        The CRC as an int in the range 0-255.
    """
    crc = (CRC_INIT_VAL if init is None else init) & 0xFF
    poly = CRC_POLY if poly is None else poly
    for byte in input_data:
        # fold the next byte into the top of the register, then run 8 shift steps
        crc ^= byte
        for _ in range(8):
            # if a 1 is shifted out of bit 7, xor in the polynomial
            if crc & 0x80:
                crc = ((crc << 1) & 0xFF) ^ poly
            else:
                crc = (crc << 1) & 0xFF
    return crc

def wait_for_response(hmd_device: hid.Device, message_types:list, timeout_ms: int = 1000) -> bytes:
    """Poll the headset until a report of an expected type arrives.

    Args:
        hmd_device: an open HID device to read from.
        message_types: accepted values for the first byte of a report.
        timeout_ms: total time to keep polling, in milliseconds.

    Returns:
        The first report whose first byte is in ``message_types``, or b'' if
        none arrives before the deadline.  Reports of other types are
        discarded.
    """
    deadline_ns = time.monotonic_ns() + timeout_ms * 1_000_000
    while time.monotonic_ns() < deadline_ns:
        report = hmd_device.read(65, timeout=DEFAULT_HID_TIMEOUT)
        if len(report) > 0 and report[0] in message_types:
            return report
    return b''

def load_raw_config_from_beyond() -> bytes:
    """Read the user signature (config) block from the headset.

    The block is requested one page at a time (pages 0-15).  Each reply
    starts with CMD_GET_CONFIG, followed by a length byte and that many data
    bytes.

    Returns:
        The concatenated page data, or b'' if any page request is answered
        with an error, a malformed reply, or nothing before the timeout.

    Raises:
        Any exception raised while opening or talking to the HID device is
        propagated to the caller.
    """
    # the context manager closes the HID handle on every exit path
    with hid.Device(vid=BIGSCREEN_VID, pid=BEYOND_PID) as bynd:
        fullbytes = bytearray()
        for page in range(16):
            # leading 0 is presumably the HID report ID
            bynd.send_feature_report(bytes([0, CMD_GET_CONFIG, page]))
            sig_page = wait_for_response(bynd, [CMD_GET_CONFIG, CMD_ERROR])
            # need at least the command byte and the length byte
            if len(sig_page) < 2 or sig_page[0] != CMD_GET_CONFIG:
                return b''
            sig_len = sig_page[1]
            fullbytes.extend(sig_page[2:2 + sig_len])
    return bytes(fullbytes)

def save_raw_config_to_beyond(newconfig: bytes) -> bool:
    """Write a complete config block to the headset, page by page.

    Each 32-byte page is sent with CMD_WRITE_CONFIG and must be acknowledged.
    CMD_SAVE_CONFIG is then sent, presumably to make the headset persist the
    written pages.

    Args:
        newconfig: the full block; must be exactly USER_SIG_LENGTH bytes.

    Returns:
        True if every page write and the final save were acknowledged with
        CMD_ACK.  False if ``newconfig`` has the wrong length, or if the
        headset replied with CMD_ERROR or did not reply in time.

    Raises:
        Any exception raised while opening or talking to the HID device is
        propagated to the caller.
    """
    if len(newconfig) != USER_SIG_LENGTH:
        return False

    def acked(reply: bytes) -> bool:
        # True only for a non-empty reply whose type byte is CMD_ACK
        return len(reply) > 0 and reply[0] == CMD_ACK

    # the context manager closes the HID handle on every exit path
    with hid.Device(vid=BIGSCREEN_VID, pid=BEYOND_PID) as bynd:
        for page in range(16):
            byte_chunk = newconfig[32 * page:32 * page + 32]
            bynd.send_feature_report(bytes([0, CMD_WRITE_CONFIG, page]) + byte_chunk)
            if not acked(wait_for_response(bynd, [CMD_ACK, CMD_ERROR])):
                return False

        bynd.send_feature_report(bytes([0, CMD_SAVE_CONFIG]))
        return acked(wait_for_response(bynd, [CMD_ACK, CMD_ERROR]))

def show_raw_config(cnfg: bytes):
    """Render a raw config block as a hex dump, 32 bytes per line.

    The marker b'error' shows an error line instead.  Input of any length
    other than 512 bytes draws nothing.
    """
    if cnfg == b'error':
        imgui.text('Could not read from Beyond')
    elif len(cnfg) == 512:
        for start in range(0, 512, 32):
            row = cnfg[start:start + 32]
            imgui.text(' '.join('0x%02X' % value for value in row))

def parse_config(cnfg: bytes) -> dict:
    """Decode the tag entries stored in a raw signature (config) block.

    The block holds a sequence of TLVC entries:
        Byte 0               Tag    - which field is stored (a SigTag value)
        Byte 1               Length - number of bytes in the Value
        Bytes 2..2+Length-1  Value  - arbitrary data for this field
        Byte 2+Length        CRC    - crc8 (poly 0x07, init 0xFF) over the
                                      Tag, Length and Value bytes
    A Tag byte equal to SigTag.Invalid (0xFF) marks the end of the entries.

    Args:
        cnfg: the raw block.  Anything other than USER_SIG_LENGTH bytes
            decodes to an empty dict.

    Returns:
        A dict mapping tag (int) to value bytes for every entry decoded.
        Decoding stops at the end marker, at the end of the block, at an
        entry starting fewer than 4 bytes from the end, at an entry that
        would overrun the block, or at the first CRC mismatch.  If a tag
        repeats, the last value wins.
    """
    decoded = {}
    if len(cnfg) != USER_SIG_LENGTH:
        return decoded

    pos = 0
    while pos < USER_SIG_LENGTH:
        tag = cnfg[pos]
        if tag == SigTag.Invalid:
            break
        # an entry needs tag + length + at least one value byte + CRC
        if pos + 4 > USER_SIG_LENGTH:
            break
        length = cnfg[pos + 1]
        crc_index = pos + 2 + length
        # the CRC byte of this entry must still lie inside the block
        if crc_index + 1 > USER_SIG_LENGTH:
            break
        if cnfg[crc_index] != crc8(cnfg[pos:crc_index]):
            break
        decoded[tag] = cnfg[pos + 2:crc_index]
        pos = crc_index + 1

    return decoded

def create_config(vals: dict) -> bytes:
    """Serialize tag/value pairs into a signature (config) block.

    Each entry is written as Tag, Length, Value, then a CRC8 over those
    bytes, in the dict's iteration order.  The rest of the block is padded
    with 0xFF (the SigTag.Invalid end marker) up to USER_SIG_LENGTH bytes.

    Args:
        vals: maps tag number (0-255) to its raw value bytes (at most 255
            bytes each).

    Returns:
        The serialized block as bytes.  It is exactly USER_SIG_LENGTH bytes
        unless the entries themselves overflow it, in which case it is longer
        and save_raw_config_to_beyond will refuse it.

    Raises:
        ValueError: if a tag number or value length does not fit in one byte.
    """
    retval = bytearray()
    for tag, value in vals.items():
        tag_len_val = bytes([tag, len(value)]) + value
        retval += tag_len_val
        retval.append(crc8(tag_len_val))

    if len(retval) < USER_SIG_LENGTH:
        retval += b'\xff' * (USER_SIG_LENGTH - len(retval))

    # return immutable bytes, as the annotation promises
    return bytes(retval)

def get_brightness_percent_from_raw(raw: int) -> float:
    """Convert the raw value stored under the Brightness tag to a percentage.

    The mapping is piecewise linear:
      - raw 50 -> 0 %
      - raw 266 -> 100 %
      - above 266, a second slope reaching 200 % at raw 819 (266 + 553)

    Values outside that range are extrapolated, not clamped.

    Returns:
        The percentage as a float.  The previous annotation said int, but the
        arithmetic has always produced a float.
    """
    if raw <= 266:
        brightness = 100 * (float(raw) - 50.0) / 216.0
    else:
        brightness = 100 + 100 * (float(raw) - 266.0) / 553.0
    return brightness

def convert_json_string_to_bytes(config: dict):
    """Convert a loaded JSON backup into a {tag: value bytes} dict.

    This is the inverse of convert_to_json_printable.  JSON object keys are
    tag numbers written as strings.  In each value string, a character in
    U+0000-U+00FF stands for the byte with the same value.  The writer emits
    U+2028 / U+2029 escapes for the UTF-8 sequences E2 80 A8 / E2 80 A9, so
    characters above U+00FF are encoded back to UTF-8 to restore those bytes.

    Returns:
        A dict mapping int tag to bytes.  Returns {} if the input is not a
        mapping of integer-like keys to convertible strings.
    """
    retdict = {}
    try:
        for eachkey in config:
            retdict[int(eachkey)] = b''.join(
                bytes([ord(ch)]) if ord(ch) <= 0xFF else ch.encode('utf-8')
                for ch in config[eachkey]
            )
    except (ValueError, TypeError):
        # ValueError: non-integer key or unencodable char (e.g. lone surrogate)
        # TypeError: input is not a mapping of strings
        return {}

    return retdict

def convert_to_json_printable(config: dict):
    """Escape raw config values the way the json11 C++ library dumps strings.

    Each value (bytes) becomes a str that can be placed between double quotes
    in a JSON file:
      - backslash, double quote, \\b, \\f, \\n, \\r and \\t get their
        short escapes;
      - other bytes <= 0x1F become \\u00XX;
      - the UTF-8 sequences E2 80 A8 / E2 80 A9 become \\u2028 / \\u2029;
      - every other byte b becomes chr(b), so 0x80-0xFF map to
        U+0080-U+00FF.

    Args:
        config: maps tag to raw value bytes.

    Returns:
        A dict with the same keys and escaped string values.
    """
    short_escapes = {
        ord('\\'): '\\\\',
        ord('"'): '\\"',   # json11 escapes quotes; without this the JSON breaks
        ord('\b'): '\\b',
        ord('\f'): '\\f',
        ord('\n'): '\\n',
        ord('\r'): '\\r',
        ord('\t'): '\\t',
    }
    # third byte of E2 80 xx -> escape for that line/paragraph separator
    separator_escapes = {0xA8: '\\u2028', 0xA9: '\\u2029'}

    retdict = {}
    for eachkey, bytesval in config.items():
        parts = []
        count = len(bytesval)
        i = 0
        while i < count:
            ch = bytesval[i]
            if ch in short_escapes:
                parts.append(short_escapes[ch])
            elif ch <= 0x1F:
                parts.append(f'\\u{ch:04x}')
            elif (ch == 0xE2 and i + 2 < count and bytesval[i + 1] == 0x80
                  and bytesval[i + 2] in separator_escapes):
                parts.append(separator_escapes[bytesval[i + 2]])
                i += 2  # consume the rest of the 3-byte sequence
            else:
                # includes a lone 0xE2, which json11 emits unchanged
                parts.append(chr(ch))
            i += 1
        retdict[eachkey] = ''.join(parts)
    return retdict

def create_json_file(params, file_obj):
    """Write ``params`` to ``file_obj`` as a flat JSON object of string values.

    Values are inserted verbatim between double quotes, so they must already
    be JSON-escaped (e.g. by convert_to_json_printable).  Members are
    separated by ',' and each uses '": "' between key and value.
    """
    members = ','.join(f'"{key}": "{value}"' for key, value in params.items())
    file_obj.write('{' + members + '}')

class my_bundle_gui:
    """Dear ImGui front end for backing up and restoring the Beyond config.

    The window has two columns.  The left one shows the configuration read
    from the headset and can save it to a JSON backup file.  The right one
    shows a loaded backup file and can write it to the headset.
    """

    # Values of self.status.  Any value other than IDLE makes run() open the
    # matching popup on the next frame.
    IDLE = 0
    SAVE_FAILED = 1      # backup file missing or invalid; nothing written
    SAVE_SUCCEEDED = 2   # file contents written to the headset
    WRITE_FAILED = 3     # file was valid but writing to the headset failed

    # popup title opened for each non-idle status
    _STATUS_POPUPS = {
        SAVE_FAILED: 'Invalid File',
        SAVE_SUCCEEDED: 'Complete',
        WRITE_FAILED: 'Write Failed',
    }

    # folder and filters the file dialogs start with
    BACKUP_DIR = r"C:\Program Files (x86)\Steam\steamapps\common\Bigscreen Beyond Driver\bin"
    BACKUP_FILETYPES = [("json", "*.json"), ("all files", "*")]
    # Backup values hold raw config bytes written as chr(byte).  latin-1 maps
    # U+0000-U+00FF to the byte with the same value, so every byte
    # round-trips through the file unchanged.  The platform default encoding
    # (e.g. cp1252) would reject or remap bytes 0x80-0x9F.
    BACKUP_ENCODING = 'latin-1'

    def __init__(self):
        self.rawconfig = b''    # raw block read from the headset, or b'error'
        self.json_obj = {}      # contents of the loaded backup file
        self.status = self.IDLE

    def run(self):
        """Draw the GUI; passed to immapp.run as gui_function, so it runs every frame."""
        popup_title = self._STATUS_POPUPS.get(self.status)
        if popup_title is not None:
            imgui.open_popup(popup_title)
            # the popup stays up until dismissed; clear the request so it is
            # opened only once and the user can close it
            self.status = self.IDLE

        if imgui.begin_popup('Invalid File'):
            imgui.text('invalid file contents, cannot load to Beyond')
            imgui.end_popup()
        if imgui.begin_popup('Write Failed'):
            imgui.text('could not write file contents to Beyond')
            imgui.end_popup()
        if imgui.begin_popup('Complete'):
            imgui.text('file contents written to Beyond')
            imgui.end_popup()

        if imgui.begin_table('##config_table', 2, flags=imgui.TableFlags_.borders):
            # left column: the headset's current configuration
            imgui.table_next_column()
            if imgui.button("load values from Beyond"):
                self.get_config()
            self.show_config()
            if imgui.button("save these values to a file"):
                self.save_to_json_file()
            # right column: a backup file
            imgui.table_next_column()
            if imgui.button("load from backup file"):
                self.load_config_backup_file()
            self.show_file_config()
            if imgui.button("save file contents to Beyond"):
                self.save_file_to_beyond()
            imgui.end_table()

    def get_config(self):
        """Read the config block from the headset into self.rawconfig.

        On any failure self.rawconfig is set to the marker b'error', which
        show_config reports to the user.
        """
        try:
            self.rawconfig = load_raw_config_from_beyond()
        except Exception:
            # GUI boundary: e.g. the headset is not connected
            self.rawconfig = b'error'

    def show_config(self):
        """Display the user-visible fields decoded from self.rawconfig."""
        if self.rawconfig == b'error':
            imgui.text('could not read from Beyond')
        elif len(self.rawconfig) == USER_SIG_LENGTH:
            all_values = parse_config(self.rawconfig)
            if len(all_values) == 0:
                imgui.text('no values read from Beyond')
            else:
                for eachkey in sorted(all_values):
                    text_label = tag_to_string(eachkey)
                    if text_label != 'unknown':
                        text_value = value_to_string(eachkey, all_values[eachkey])
                        imgui.text(text_label + ': ')
                        imgui.text('\t' + text_value)
        else:
            imgui.text('no values read from Beyond')

    def show_file_config(self):
        """Display the user-visible fields of the loaded backup file."""
        if (self.json_obj is None) or (len(self.json_obj) == 0):
            imgui.text('no values read from file')
        else:
            bytes_json_obj = convert_json_string_to_bytes(self.json_obj)
            for eachkey in sorted(bytes_json_obj):
                text_label = tag_to_string(eachkey)
                if text_label != 'unknown':
                    text_value = value_to_string(eachkey, bytes_json_obj[eachkey])
                    imgui.text(text_label + ': ')
                    imgui.text('\t' + text_value)

    def save_file_to_beyond(self):
        """Validate the loaded backup file and write it to the headset.

        Sets self.status so run() reports the outcome:
          - SAVE_FAILED if the file contents are missing or invalid;
          - WRITE_FAILED if the headset could not be written;
          - SAVE_SUCCEEDED otherwise.
        """
        if not self.json_obj:
            self.status = self.SAVE_FAILED
            return

        # The PCB serial number should always be there, and starts with XCNL.
        # examples:     XCNL2D2374000031
        #               XCNL2D232R000003
        bytes_json_obj = convert_json_string_to_bytes(self.json_obj)
        serial = bytes_json_obj.get(SigTag.Serial, b'')
        if not serial.startswith(b'XCNL'):
            self.status = self.SAVE_FAILED
            return

        try:
            rawbytes = create_config(bytes_json_obj)
        except (ValueError, TypeError):
            # a tag number or value length does not fit in one byte
            rawbytes = b''
        if len(rawbytes) != USER_SIG_LENGTH:
            # the entries do not fit in the signature block
            self.status = self.SAVE_FAILED
            return

        try:
            written = save_raw_config_to_beyond(rawbytes)
        except Exception:
            # GUI boundary: e.g. the headset is not connected or was unplugged
            written = False
        self.status = self.SAVE_SUCCEEDED if written else self.WRITE_FAILED

    def save_to_json_file(self):
        """Ask for a file name and save the headset's config there as JSON.

        Does nothing if no valid config has been read from the headset, so
        an existing backup is never overwritten with an empty object.
        """
        config = parse_config(self.rawconfig)
        if not config:
            return
        save_filename = tkinter.filedialog.asksaveasfilename(
            confirmoverwrite=True,
            initialdir=self.BACKUP_DIR,
            filetypes=self.BACKUP_FILETYPES,
        )
        if not save_filename:
            return  # dialog cancelled
        try:
            with open(save_filename, 'w', encoding=self.BACKUP_ENCODING) as fil:
                create_json_file(convert_to_json_printable(config), fil)
        except (OSError, ValueError):
            # best effort: unwritable location or unencodable text
            pass

    def load_config_backup_file(self):
        """Ask for a JSON backup file and load it into self.json_obj.

        If the dialog is cancelled, or the file cannot be read or parsed,
        the previously loaded contents are kept.
        """
        load_filename = tkinter.filedialog.askopenfilename(
            initialdir=self.BACKUP_DIR,
            filetypes=self.BACKUP_FILETYPES,
        )
        if not load_filename:
            return  # dialog cancelled
        try:
            # `with` closes the file (askopenfile's object was never closed)
            with open(load_filename, 'r', encoding=self.BACKUP_ENCODING) as fil:
                self.json_obj = json.load(fil)
        except (OSError, ValueError):
            # unreadable file or invalid JSON (JSONDecodeError is a ValueError)
            pass
    

if __name__ == '__main__':
    # One GUI object holds all state; immapp.run calls its run() method to
    # draw the window.
    app = my_bundle_gui()
    immapp.run(
        gui_function=app.run,
        window_title="Beyond Configuration Restore",
        window_size_auto=False,
        window_size=(400, 600),
    )



