import enum
import queue
import struct
import sys
import time

from PySide6.QtWidgets import QApplication, QMainWindow, QCheckBox, QDialog, QTextEdit, QGridLayout, QColorDialog, QMessageBox
from PySide6.QtCore import QFile, QTimer, QThread, Signal, Slot, Qt
from PySide6.QtGui import QColor, QPalette, QFont

import hid
from hid import HIDException

import prox_calibration_ui

# HID identifiers of the headset; the same values are opened via hid.Device(vid=..., pid=...).
BIGSCREEN_VID: int = 0x35BD  # vendor ID
BEYOND_PID: int = 0x0101  # product ID
# Size in bytes of the HMD user-signature (configuration) region.
# The UI transfers it as 16 pages of 32 bytes.
USER_SIG_LENGTH: int = 512
# Parameters for crc8(): initial register value and generator polynomial.
CRC_INIT_VAL: int = 0xFF
CRC_POLY: int = 0x07

class SigTag(enum.IntEnum):
    """Tag byte identifying each TLVC record stored in the user signature."""
    # End-of-data marker: parse_sig stops here, and create_signature pads unused space with 0xFF.
    Invalid = 0xFF
    # Variable-length serial number (only written when non-empty).
    Serial = 0x01
    # 3-byte value (parse_sig enforces length 3).
    RGB_Color = 0x02
    # 1-byte value.
    Fan_Speed = 0x03
    # 1-byte value.
    Prox_Disable = 0x04
    # 1-byte value.
    Linkbox_v1 = 0x05
    # 2-byte proximity calibration; the UI packs it little-endian ('<H').
    Prox_Cal = 0x06
    # 1-byte value.
    FATP_Mode = 0x07
    # Variable-length HMD serial number (only written when non-empty).
    HMD_Serial = 0x08
    # Variable-length tracking serial number (only written when non-empty).
    Tracking_Serial = 0x09

def crc8(input_data: bytes) -> int:
    """Return the 8-bit CRC of *input_data*.

    MSB-first CRC with generator CRC_POLY and initial value CRC_INIT_VAL,
    no reflection and no final XOR. This is the checksum stored at the end
    of every signature record.
    """
    register = CRC_INIT_VAL
    for octet in input_data:
        # fold the next byte into the register, then clock out its 8 bits
        register ^= octet
        for _bit in range(8):
            carry = register & 0x80
            register = (register << 1) & 0xFF
            if carry:
                register ^= CRC_POLY
    return register

def create_field(tag: SigTag, data: bytes) -> bytes:
    """Encode one TLVC record: tag byte, length byte, *data*, then a CRC byte.

    The CRC is crc8() over the tag, length and data bytes. bytes() raises
    ValueError if the tag or len(data) does not fit in one byte.
    """
    record = bytes((tag, len(data))) + data
    checksum = crc8(record)
    return record + bytes((checksum,))

def create_signature(sig_fields: dict) -> bytes:
    """Serialize *sig_fields* into a USER_SIG_LENGTH-byte user-signature block.

    Records are emitted in a fixed order:
    Serial, HMD_Serial, Tracking_Serial, Fan_Speed, Prox_Cal, RGB_Color,
    Prox_Disable, Linkbox_v1, FATP_Mode.
    Each record is encoded with create_field(). Keys that are not one of
    these tags (e.g. the 'error' key added by parse_sig) are ignored.
    The serial tags are skipped when their value is empty. Unused space is
    filled with 0xFF, which parse_sig treats as end-of-data.

    Args:
        sig_fields: mapping of SigTag -> value. A value may be an int (stored
            as one byte) or any bytes-like object.

    Returns:
        bytes of exactly USER_SIG_LENGTH length.

    Raises:
        ValueError: if the encoded records do not fit in USER_SIG_LENGTH
            bytes, or (from create_field) a value is longer than 255 bytes.
    """
    def as_bytes(value) -> bytes:
        # A bare int (including bool / IntEnum) is a single byte; bytes(n) would
        # instead produce n zero bytes.
        if isinstance(value, int):
            return bytes([value])
        return bytes(value)

    serial_tags = (SigTag.Serial, SigTag.HMD_Serial, SigTag.Tracking_Serial)
    field_order = serial_tags + (
        SigTag.Fan_Speed,
        SigTag.Prox_Cal,
        SigTag.RGB_Color,
        SigTag.Prox_Disable,
        SigTag.Linkbox_v1,
        SigTag.FATP_Mode,
    )

    encoded = bytearray()
    for tag in field_order:
        if tag not in sig_fields:
            continue
        value = as_bytes(sig_fields[tag])
        if tag in serial_tags and not value:
            # empty serial numbers are not stored at all
            continue
        encoded += create_field(tag, value)

    if len(encoded) > USER_SIG_LENGTH:
        # Writing this would silently drop the tail (including a CRC byte)
        # and leave a corrupt signature on the device.
        raise ValueError('signature data ({} bytes) exceeds USER_SIG_LENGTH ({} bytes)'.format(
            len(encoded), USER_SIG_LENGTH))
    return bytes(encoded.ljust(USER_SIG_LENGTH, b'\xff'))

def parse_sig(sig_bytes: bytes) -> dict:
    """Decode a user-signature block into a ``{SigTag: value}`` dict.

    The block is a sequence of TLVC (Tag, Length, Value, CRC) records:
      Byte 0                Tag:    which field is stored (see SigTag)
      Byte 1                Length: number of bytes in Value
      Byte 2..2+Length-1    Value:  data saved for this field
      Byte 2+Length         CRC:    crc8() (poly 0x07, init 0xFF) over Tag,
                                    Length and Value
    A tag byte of 0xFF (SigTag.Invalid) marks the end of stored data, so an
    all-0xFF block is a valid, empty signature.

    Args:
        sig_bytes: raw signature data. At most USER_SIG_LENGTH bytes are
            examined, and never more than len(sig_bytes).

    Returns:
        dict mapping each recognised SigTag to its raw value slice (same type
        as a slice of *sig_bytes*). Unknown tags are skipped.
        On a malformed record, parsing stops. The fields decoded so far are
        returned with an extra ``'error'`` key: 'overrun', 'crc mismatch', or
        'tag length mismatch (...)'.
    """
    # Tags whose value must have an exact length, and the error reported otherwise.
    fixed_length_tags = {
        SigTag.Fan_Speed: (1, 'tag length mismatch (Fan speed)'),
        SigTag.Prox_Cal: (2, 'tag length mismatch (prox cal)'),
        SigTag.RGB_Color: (3, 'tag length mismatch (rgb color)'),
        SigTag.Prox_Disable: (1, 'tag length mismatch (prox disable)'),
        SigTag.Linkbox_v1: (1, 'tag length mismatch (linkbox v1)'),
        SigTag.FATP_Mode: (1, 'tag length mismatch (FATP mode)'),
    }
    variable_length_tags = (SigTag.Serial, SigTag.HMD_Serial, SigTag.Tracking_Serial)
    # Never index past the supplied buffer, even if it is shorter than the region.
    limit = min(len(sig_bytes), USER_SIG_LENGTH)

    sig_ptr = 0
    sig_fields = {}
    while sig_ptr < limit:
        tag = sig_bytes[sig_ptr]
        if tag == SigTag.Invalid:
            # end of saved data; everything so far was valid
            return sig_fields
        # Every record needs at least tag, length and CRC (3 bytes; a zero-length
        # record ending exactly at the region end is valid).
        if sig_ptr + 3 > limit:
            sig_fields['error'] = 'overrun'
            return sig_fields
        taglen = sig_bytes[sig_ptr + 1]
        record_end = sig_ptr + 3 + taglen
        if record_end > limit:
            # the declared length runs past the end of the region
            sig_fields['error'] = 'overrun'
            return sig_fields
        tagval = sig_bytes[sig_ptr + 2:sig_ptr + 2 + taglen]
        tagcrc = sig_bytes[sig_ptr + 2 + taglen]
        if tagcrc != crc8(bytes([tag, taglen]) + bytes(tagval)):
            sig_fields['error'] = 'crc mismatch'
            return sig_fields
        if tag in fixed_length_tags:
            expected_len, mismatch_error = fixed_length_tags[tag]
            if taglen != expected_len:
                sig_fields['error'] = mismatch_error
                return sig_fields
            sig_fields[SigTag(tag)] = tagval
        elif tag in variable_length_tags:
            sig_fields[SigTag(tag)] = tagval
        # Any other tag is ignored: the region is still valid, the field is just
        # a type this tool does not know about yet.
        sig_ptr = record_end
    return sig_fields

class hidreader(QThread):
    """Worker thread that polls the HMD for input reports and sends feature reports.

    Input reports are forwarded to the UI thread through ``data_received``.
    Any HID/OS error emits ``hid_disconnected`` and ends the thread. Create a
    new instance per connection; a stopped reader is not meant to be restarted.
    """
    HID_INPUT_REPORT_SIZE = 64
    HID_FEATURE_REPORT_SIZE = 64
    HID_TIMEOUT_MS = 10

    data_received = Signal(bytes)
    hid_disconnected = Signal()

    def __init__(self, hid_device: hid.Device):
        super().__init__()
        self.dev = hid_device
        self.connected = True
        # Set up here (not in run()) so exit_now()/write() calls made before the
        # thread actually starts running are not discarded.
        self.time_to_quit = False
        # Thread-safe FIFO of feature reports queued by the UI thread.
        self._write_queue = queue.Queue()

    def exit_now(self):
        """Ask run() to stop after its current iteration."""
        self.time_to_quit = True

    def write(self, wbytes):
        """Queue *wbytes* to be sent as a feature report by the worker thread."""
        self._write_queue.put(bytes(wbytes))

    def run(self):
        """Read input reports (HID_TIMEOUT_MS timeout) and send queued writes until told to quit."""
        while not self.time_to_quit:
            try:
                result = self.dev.read(self.HID_INPUT_REPORT_SIZE,
                                       self.HID_TIMEOUT_MS)
            except (HIDException, OSError):
                self._device_failed()
                break
            if result:
                # We got something, send it to the UI thread
                self.data_received.emit(result)
            if not self._flush_writes():
                break

    def _flush_writes(self) -> bool:
        """Send every queued feature report; return False if the device failed."""
        while True:
            try:
                payload = self._write_queue.get_nowait()
            except queue.Empty:
                return True
            try:
                self.dev.send_feature_report(payload)
            except (HIDException, OSError):
                self._device_failed()
                return False

    def _device_failed(self):
        """Report the lost device to the UI thread and stop polling."""
        self.hid_disconnected.emit()
        self.time_to_quit = True

class mainwin(QMainWindow):
    """Proximity-calibration main window.

    Connect opens the HMD and reads its user signature one 32-byte page at a
    time ('U' requests). If no SigTag.Prox_Cal entry exists yet, the user can
    average the streamed proximity value for 10 seconds. The result is stored
    as SigTag.Prox_Cal. The signature is then written back page by page ('W'),
    saved to NVM ('V'), and the HMD is disconnected.
    """

    def __init__(self):
        super().__init__()
        self.loadui()
        self.connect_buttons()
        # initialise session flags and show the "not connected" state
        self.hmd_disconnected()

    def loadui(self):
        """Build the Qt Designer generated UI into this window."""
        self.widg = prox_calibration_ui.Ui_MainWindow()
        self.widg.setupUi(self)
        self.setWindowTitle('Bigscreen Proximity Calibration')

    def connect_buttons(self):
        """Wire the UI buttons to their handlers."""
        self.widg.btnConnect.clicked.connect(self.connect_hmd)
        self.widg.btnAvgAndSave.clicked.connect(self.average_and_save)

    def connect_hmd(self):
        """Open the HMD, start the HID polling thread and schedule the signature read."""
        existing = getattr(self, 'hid_thread', None)
        if existing is not None and existing.isRunning():
            # Already connected: a second handle/thread would race the first one.
            return
        # drop any handle left over from a session whose thread already ended
        self._release_hmd()
        try:
            self.hmd = hid.Device(vid=BIGSCREEN_VID, pid=BEYOND_PID)
        except HIDException:
            QMessageBox.critical(self, 'Could not connect','Unable to connect to HMD. Please check cable connections.', QMessageBox.Ok, QMessageBox.NoButton)
            return
        self.hid_thread = hidreader(self.hmd)
        self.hid_thread.hid_disconnected.connect(self.hmd_disconnected)
        self.hid_thread.data_received.connect(self.hmd_data_received)
        self.hid_thread.start()
        self.widg.lblStatus.setText('Status: Connecting to HMD...')
        self.read_user_sig()

    def disconnect_hmd(self):
        """Stop the HID thread, close the HMD and reset the UI to "not connected"."""
        self._release_hmd()
        if self.warn_power_cycle:
            QMessageBox.warning(self,'Power cycle needed','Calibration cleared. Power cycle HMD required before calibrating again!',
            QMessageBox.Ok, QMessageBox.NoButton)
        self.hmd_disconnected()

    def average_and_save(self):
        """Start a 10 s averaging run of the streamed proximity value (one tick per second)."""
        # prevent a second click from starting a parallel averaging run
        self.widg.btnAvgAndSave.setEnabled(False)
        self.averaging_time = 0
        self.averaging_list = []
        self.widg.lblStatus.setText('Status: Averaging proximity data ({}/10)'.format(self.averaging_time))
        self.averaging = True
        self.averaging_timer = QTimer()
        self.averaging_timer.timeout.connect(self.averaging_tick)
        self.averaging_timer.start(1000)

    def averaging_tick(self):
        """Advance the averaging countdown; after 10 ticks store the result and save it."""
        self.averaging_time = self.averaging_time + 1
        self.widg.lblStatus.setText('Status: Averaging proximity data ({}/10)'.format(self.averaging_time))
        if self.averaging_time < 10:
            return
        self.averaging = False
        self.averaging_timer.stop()
        if not self.averaging_list:
            # No streaming ('#') reports arrived: nothing to average (and no
            # division by zero); let the user try again.
            self.widg.lblStatus.setText('Status: No proximity data received')
            self.widg.btnAvgAndSave.setEnabled(True)
            return
        self.average = sum(self.averaging_list) / len(self.averaging_list)
        self.widg.leProxAverage.setText(str(int(self.average)))
        # NOTE(review): averages <= 300 are stored unchanged while larger ones have
        # 300 subtracted (300 -> 300, 301 -> 1). Confirm this is the intended convention.
        if int(self.average) > 300:
            self.average = int(self.average) - 300
        self.user_sig_fields[SigTag.Prox_Cal] = struct.pack("<H", int(self.average))
        self.begin_hmd_save()

    def hmd_data_received(self, newbytes: bytes):
        """Dispatch an input report from the HID thread by its first byte."""
        if getattr(self, 'hmd', None) is None:
            # Reports queued before the HMD was released are stale; they must not
            # drive the read/write state machines of a finished session.
            return
        kind = newbytes[0]
        if kind == ord('#'):
            self._on_stream_report(newbytes)
        elif kind == ord('U'):
            self._on_sig_page(newbytes)
        elif kind == ord('$'):
            self._on_ack()
        else:
            print('some other reply:')
            print(newbytes)

    def _on_stream_report(self, newbytes: bytes):
        """Show the streamed proximity value and collect it while averaging."""
        # bytes 4-5 carry the proximity sensor value, big-endian
        (hProxVal,) = struct.unpack('>H', newbytes[4:6])
        self.widg.leProxValue.setText(str(hProxVal))
        if self.averaging:
            self.averaging_list.append(hProxVal)

    def _on_sig_page(self, newbytes: bytes):
        """Store one 32-byte signature page, then request the next one or parse all 16."""
        self.sigreadtimeout.stop()
        user_sig_reply_len = newbytes[1]
        user_sig_reply = bytes(newbytes[2:(2 + user_sig_reply_len)])
        start = 32 * self.current_sig_page
        # Keep the page exactly 32 bytes. Assigning a shorter or longer slice to a
        # bytearray would resize it and shift every later page.
        self.user_sig_data[start:start + 32] = user_sig_reply[:32].ljust(32, b'\x00')
        if self.current_sig_page < 15:
            # get the next page
            self.current_sig_page = self.current_sig_page + 1
            self.hid_thread.write(bytes([0, ord('U'), self.current_sig_page]))
            self.sigreadtimeout.start(1000)
            return
        # finished reading the signature (all 16 blocks of 32 bytes); check it
        self.user_sig_fields = parse_sig(self.user_sig_data)
        if 'error' in self.user_sig_fields:
            # NOTE(review): records after the bad one are not parsed, so a later save
            # writes back only the fields decoded before the error.
            QMessageBox.warning(self,'Error in config data','Configuration data on HMD had an error: {}'.format(self.user_sig_fields['error']),
            QMessageBox.Ok, QMessageBox.NoButton)
        if SigTag.Prox_Cal in self.user_sig_fields:
            existing_prox_value = struct.unpack('<H',self.user_sig_fields[SigTag.Prox_Cal])[0]
            QMessageBox.warning(self,'Prox calibration done','Proximity is already calibrated (value = {})'.format(existing_prox_value),
            QMessageBox.Ok, QMessageBox.NoButton)
            self.disconnect_hmd()
            return
        self.widg.lblStatus.setText("Status: HMD Connected")
        self.widg.btnAvgAndSave.setEnabled(True)
        # increase the data response frequency: 100 ms report rate
        self.hid_thread.write(bytes([0, ord('R')])+struct.pack('>H', 100))

    def _on_ack(self):
        """Advance the signature write sequence on an ACK ('$') for the last packet."""
        if not self.writing_signature:
            return
        self.sigwritetimeout.stop()
        if self.current_sig_page < 15:
            # continue with the next page / block of 32 bytes
            self.current_sig_page = self.current_sig_page + 1
            self._send_sig_page()
            self.sigwritetimeout.start(1000)
        elif self.current_sig_page == 15:
            # all pages written, now ask the HMD to save them to NVM
            self.current_sig_page = self.current_sig_page + 1
            self.hid_thread.write(bytes([0, ord('V')]))
            self.sigwritetimeout.start(1000)
        else:  # current_sig_page >= 16: the NVM save was acknowledged
            self.writing_signature = False
            self.restart_timer = self._single_shot_timer(self.disconnect_hmd)
            if self.warn_power_cycle:
                self.widg.lblStatus.setText("Status: Calibration cleared")
                self.restart_timer.start(300)
            else:
                self.widg.lblStatus.setText("Status: Calibration complete!")
                self.restart_timer.start(1500)

    def hmd_disconnected(self):
        """Release the HMD, reset session state and clear the readouts after 300 ms.

        Also connected to hidreader.hid_disconnected, so it runs when the HID
        thread reports a device error. Safe to call repeatedly.
        """
        self._release_hmd()
        self._stop_pending_timers()
        self.warn_power_cycle = False
        self.writing_signature = False
        self.averaging = False
        self.widg.lblStatus.setText('Status: Not Connected')
        self.cleartimer = self._single_shot_timer(self.clear_fields)
        self.cleartimer.start(300)

    def clear_fields(self):
        """Reset the proximity readouts and disable the average/save button."""
        self.widg.leProxValue.setText('---')
        self.widg.leProxAverage.setText('---')
        self.widg.btnAvgAndSave.setEnabled(False)

    def read_user_sig(self):
        """Start the signature read after a short (100 ms) delay.

        The delay gives the freshly started hid_thread time to begin running
        before the first request is sent.
        """
        self.hmdreadstarttimer = self._single_shot_timer(self.begin_hmd_read)
        self.hmdreadstarttimer.start(100)

    def begin_hmd_read(self):
        """Request signature page 0 and arm the 1 s reply timeout."""
        self.user_sig_data = bytearray(USER_SIG_LENGTH)  # zero-filled
        self.current_sig_page = 0
        self.hid_thread.write(bytes([0, ord('U'), self.current_sig_page]))
        # also include a timeout in case we don't get a reply
        self.sigreadtimeout = self._single_shot_timer(self.sig_read_write_timed_out)
        self.sigreadtimeout.start(1000)

    def sig_read_write_timed_out(self):
        """Abort the session when the HMD does not answer a signature read/write in time."""
        # Tear everything down before showing the modal box: its nested event loop
        # would otherwise keep delivering replies to the read/write state machines.
        # Clearing writing_signature also stops a later ACK from triggering 'V'.
        self.writing_signature = False
        self.averaging = False
        self._stop_pending_timers()
        self._release_hmd()
        self.widg.lblStatus.setText("Status: HMD response timed out")
        QMessageBox.critical(self, 'Timed out', 'HMD timed out while reading or writing configuration data',
        QMessageBox.Ok, QMessageBox.NoButton)

    def begin_hmd_save(self):
        """Encode user_sig_fields and start writing it to the HMD, page 0 first."""
        self.new_user_sig_data = create_signature(self.user_sig_fields)
        self.current_sig_page = 0
        self._send_sig_page()
        self.writing_signature = True
        # start a timeout as well
        self.sigwritetimeout = self._single_shot_timer(self.sig_read_write_timed_out)
        self.sigwritetimeout.start(1000)

    def _send_sig_page(self):
        """Send page current_sig_page of new_user_sig_data with a 'W' command."""
        start = 32 * self.current_sig_page
        self.hid_thread.write(bytes([0, ord('W'), self.current_sig_page])
                              + self.new_user_sig_data[start:start + 32])

    def _release_hmd(self):
        """Stop the HID thread (blocking until it exits) and close the HMD handle, if any."""
        thread = getattr(self, 'hid_thread', None)
        if thread is not None:
            thread.exit_now()
            # QThread.wait() blocks until run() returns without spinning the CPU
            # the way a `while isRunning(): pass` loop does.
            thread.wait()
        hmd = getattr(self, 'hmd', None)
        if hmd is not None:
            hmd.close()
            # forget the handle so it is never closed twice
            self.hmd = None

    def _stop_pending_timers(self):
        """Stop the read-start, averaging and read/write timeout timers, if armed."""
        for name in ('hmdreadstarttimer', 'averaging_timer', 'sigreadtimeout', 'sigwritetimeout'):
            timer = getattr(self, name, None)
            if timer is not None:
                timer.stop()

    def _single_shot_timer(self, slot) -> QTimer:
        """Return a new, not yet started, single-shot QTimer connected to *slot*."""
        timer = QTimer()
        timer.setSingleShot(True)
        timer.timeout.connect(slot)
        return timer



if __name__ == '__main__':
    # Qt bootstrap: create the application object, show the calibration
    # window, and hand the event loop's exit status back to the shell.
    application = QApplication()
    main_window = mainwin()
    main_window.show()
    exit_status = application.exec()
    sys.exit(exit_status)