from PySide6.QtWidgets import QApplication, QWidget, QDialog, QColorDialog, QMessageBox
from PySide6.QtGui import QColor, QPalette
from PySide6.QtCore import QFile, QTimer, QThread, Signal, Slot
import serial
import serial.tools.list_ports as serial_lp
import time
import struct
import sys

from uart_testgui_form import Ui_Widget
import uart_packet as pkt

DEFAULT_TIMEOUT_NS = 500_000_000  # default reply timeout for one UART exchange: 0.5 s, in nanoseconds

class uartthread(QThread):
    """Worker thread that sends one command packet and waits for its reply.

    The outcome is reported exactly once through the ``complete`` signal as
    ``(error_code, packet_type, packet_data)``:

    * valid reply: ``PKRES_SUCCESS`` with the reply's type and payload;
    * malformed reply: the error code returned by
      ``DataPacket.extract_packet``, with type ``0`` and empty data;
    * no reply before the timeout, or a serial I/O error:
      ``PKRES_NO_PACKET`` with type ``0`` and empty data.
    """
    complete = Signal(int, int, bytes)

    # Pause between polls of an idle port so the wait loop does not spin a
    # CPU core (and compete with the GUI thread for the GIL).
    POLL_INTERVAL_S = 0.001

    def __init__(self, comport, command_byte, command_data = None, timeout_ns = DEFAULT_TIMEOUT_NS, verbose = False):
        """Prepare (but do not yet send) a command.

        comport:      serial port object (e.g. serial.Serial); its write(),
                      read() and in_waiting are used.
        command_byte: packet type of the command, e.g. ord('*').
        command_data: payload for the command; None means an empty payload.
        timeout_ns:   how long to wait for a reply, in nanoseconds.
        verbose:      print the receive buffer and the outcome to stdout.
        """
        super().__init__()
        # None instead of a shared mutable [] default argument
        if command_data is None:
            command_data = []
        self.com = comport
        self.command_packet = pkt.DataPacket(command_byte, command_data)
        self.timeout_ns = timeout_ns
        self.verbose = verbose

    def run(self):
        """Send the command, wait for a reply and emit ``complete`` once."""
        self.got_packet = False
        try:
            result = self._exchange()
        except (serial.SerialException, OSError) as err:
            # A closed/unplugged port must still produce a ``complete``
            # signal; otherwise anything waiting for it never returns.
            if self.verbose:
                print('serial error: {}'.format(err))
            result = None
        else:
            if result is None and self.verbose:
                print('packet timeout')
        if result is None:
            result = (pkt.DataPacket.PKRES_NO_PACKET, 0, b'')
        self.complete.emit(*result)

    def _exchange(self):
        """Write the command and poll the port until a packet, an error or the timeout.

        Returns ``(errcode, packet_type, packet_data)`` when a packet (or a
        packet error) is found, or ``None`` on timeout.
        """
        # monotonic clock: immune to wall-clock adjustments during the wait
        deadline = time.monotonic_ns() + self.timeout_ns
        self.com.write(self.command_packet.create_packet())
        rec_buf = b''
        while time.monotonic_ns() < deadline:
            # Grab all bytes currently at the comport
            avail_bytes = self.com.in_waiting
            if avail_bytes <= 0:
                time.sleep(self.POLL_INTERVAL_S)
                continue
            rec_buf += self.com.read(avail_bytes)
            if self.verbose:
                print('currently in the buffer: {}'.format(rec_buf))
            # Check the buffer for a valid packet
            (errcode, startpos, retpkt) = pkt.DataPacket.extract_packet(rec_buf)
            if errcode == pkt.DataPacket.PKRES_SUCCESS:
                if self.verbose:
                    print('got a packet')
                self.got_packet = True
                return (errcode, retpkt.PktType, bytes(retpkt.PktData))
            if errcode == pkt.DataPacket.PKRES_NO_PACKET:
                # keep waiting; drop bytes before the reported start position
                if startpos != -1:
                    rec_buf = rec_buf[startpos:]
                continue
            # some other packet error: stop and report it to the caller
            self.got_packet = True
            return (errcode, 0, b'')
        return None

class testrunner(QThread):
    """Thread that runs the selected board self-tests one after another.

    Each enabled test sends a 'J' command whose single data byte is the
    test's key from ``test_keys``. The test passes when the reply is a 'J'
    packet whose first data byte is 0x01; anything else is a failure.
    """
    test_list = ['USB Hub', 'RGB LED', 'USB-C Switch', 'VXR Bridge', 'Proximity', 'OLED', 'Fan']
    test_keys = {'USB Hub':ord('H'), 
                 'RGB LED':ord('L'),
                 'USB-C Switch':ord('S'),
                 'VXR Bridge':ord('V'),
                 'Proximity':ord('P'),
                 'OLED':ord('O'),
                 'Fan':ord('F') }
    complete = Signal(dict) # emitted when tests are finished. Keys are the testrunner.test_list.
                            # Results are True if pass, False if failed, None if skipped

    # emitted after every test (run or skipped) as {test_name: True/False/None}
    one_test_done = Signal(dict)

    # Sleep between checks for the UART reply so waiting does not spin a core
    POLL_INTERVAL_S = 0.001

    # tests: dict with testrunner.test_list as the keys. True if test should be run, False to skip
    def __init__(self, comport, tests, timeout_ns = DEFAULT_TIMEOUT_NS):
        super().__init__()
        self.com = comport
        self.tests = tests
        self.timeout = timeout_ns
        self.pkt_err = None
        self.pkt_type = None
        self.pkt_data = None

    def run(self):
        """Run every test in ``self.tests`` in order, reporting each and then all results."""
        self.test_results = {}
        for test, enabled in self.tests.items():
            if enabled:
                result = self._run_one_test(test)
            else:
                result = None  # skipped
            self.test_results[test] = result
            self.one_test_done.emit({test: result})
        self.complete.emit(self.test_results)

    def _run_one_test(self, test):
        """Run one test over the UART; return True on pass, False otherwise."""
        my_uart_thread = uartthread(self.com, ord('J'), bytes([self.test_keys[test]]), self.timeout)
        my_uart_thread.complete.connect(self.uart_thread_complete)
        self.pkt_err = None  # "reply arrived" flag, set by uart_thread_complete
        my_uart_thread.start()
        # The slot is delivered through Qt's signal machinery (presumably
        # queued to the thread that created this object), so poll the flag
        # rather than relying on wait() alone.
        while self.pkt_err is None:
            time.sleep(self.POLL_INTERVAL_S)
        # Make sure the worker has fully exited before it goes out of scope
        my_uart_thread.wait()
        if self.pkt_err != pkt.DataPacket.PKRES_SUCCESS:
            return False  # no packet / bad packet
        if self.pkt_type != ord('J'):
            return False  # unexpected reply type
        # An empty payload counts as a failure
        return len(self.pkt_data) >= 1 and self.pkt_data[0] == 0x01

    def uart_thread_complete(self, pkt_err, pkt_type, pkt_data):
        """Store a uartthread result for the running test to pick up."""
        self.pkt_type = pkt_type
        self.pkt_data = pkt_data
        # pkt_err is the flag polled by _run_one_test in another thread:
        # assign it last so type/data are never read before they are set.
        self.pkt_err = pkt_err

class mainwin(QWidget):
    """Main window of the Bigscreen displayboard tester.

    Lets the user open a serial port, reads the firmware version and serial
    number, sets the RGB LED colour and runs the board self-tests.
    """
    def __init__(self):
        super(mainwin, self).__init__()
        self.current_device_list = []
        self.comport = None
        # Background threads are kept referenced here. Dropping the last
        # reference to a QThread that is still running may destroy it
        # mid-run (a crash), so they are waited on / guarded before reuse.
        self.uart_thread = None
        self.testrunner_thread = None
        self.load_ui()
        # Helper lists, in the same order as testrunner.test_list
        self.test_checkboxes = [self.widg.chkTestHub, self.widg.chkTestRGB, self.widg.chkTestUSBC,
                        self.widg.chkTestVxr, self.widg.chkTestProx, self.widg.chkTestOled, self.widg.chkTestFan]
        self.test_results = [self.widg.lblResultHub, self.widg.lblResultRGB, self.widg.lblResultUSBC,
                        self.widg.lblResultVxr, self.widg.lblResultProx, self.widg.lblResultOled, self.widg.lblResultFan]
        self.test_checkboxes_dict = dict(zip(testrunner.test_list, self.test_checkboxes))
        self.test_results_dict = dict(zip(testrunner.test_list, self.test_results))
        self.link_actions()
        self.rescan_comports()
        self.reset_test_results()

    def load_ui(self):
        self.widg = Ui_Widget()
        self.widg.setupUi(self)
        self.setWindowTitle('Bigscreen Displayboard Tester')

    def link_actions(self):
        self.widg.btnConnect.clicked.connect(self.connect_comport)
        self.widg.btnRescan.clicked.connect(self.rescan_comports)
        self.widg.btnColorPick.clicked.connect(self.launch_color_picker)
        self.widg.btnSelectAll.clicked.connect(self.select_all_tests)
        self.widg.btnSelectNone.clicked.connect(self.unselect_all_tests)
        self.widg.btnRunTests.clicked.connect(self.run_tests)
        self.widg.btnResetTests.clicked.connect(self.reset_test_results)

    def launch_color_picker(self):
        """Open a non-modal colour dialog; the chosen colour goes to set_rgb_color."""
        self.color_dialog = QColorDialog()
        self.color_dialog.colorSelected.connect(self.set_rgb_color)
        self.color_dialog.open()
        # NOTE: open() returns immediately, so this is the dialog's initial
        # colour, not the one the user eventually picks.
        self.newcolor = self.color_dialog.currentColor()

    def set_all_checkboxes(self, checked):
        for chkbox in self.test_checkboxes:
            chkbox.setChecked(checked)

    def select_all_tests(self):
        self.set_all_checkboxes(True)

    def unselect_all_tests(self):
        self.set_all_checkboxes(False)

    # --- serial port / thread helpers -------------------------------------

    def _comport_open(self):
        """Return True when a serial port has been opened and is still open."""
        return (self.comport is not None) and self.comport.is_open

    def _tests_running(self):
        """Return True while a testrunner thread is executing."""
        return (self.testrunner_thread is not None) and self.testrunner_thread.isRunning()

    def _wait_for_uart_thread(self):
        """Block until the previous single-command uartthread (if any) has exited."""
        if self.uart_thread is not None:
            self.uart_thread.wait()

    def _close_comport(self):
        """Close the open serial port (if any) once any pending command has finished."""
        self._wait_for_uart_thread()
        if self._comport_open():
            self.comport.close()

    def _start_uart_command(self, command_byte, command_data, callback):
        """Send one command on a uartthread and deliver its reply to *callback*."""
        # Let the previous exchange finish first: it must not be destroyed
        # while running, and two threads must not talk on the port at once.
        self._wait_for_uart_thread()
        self.uart_thread = uartthread(self.comport, command_byte, command_data)
        self.uart_thread.complete.connect(callback)
        self.uart_thread.start()

    @staticmethod
    def _decode_text(packet_data):
        """Decode an ASCII text reply, dropping surrounding whitespace and NUL padding."""
        # errors='replace': a corrupted reply must not raise inside a Qt slot
        return bytes(packet_data).decode('ascii', errors='replace').strip().strip('\x00')

    # --- port selection ---------------------------------------------------

    def connect_comport(self):
        """Open the port selected in the list and start querying the HMD."""
        if not self.current_device_list:
            return
        selected_port = self.widg.listPorts.currentRow()
        if selected_port == -1:
            return  # nothing selected
        if self._tests_running():
            self.set_status('Tests are running')
            return
        # Release any previously opened port so its handle is not leaked
        # and the same port can be opened again.
        self._close_comport()
        try:
            self.comport = serial.Serial(self.current_device_list[selected_port].device, 115200, timeout=1)
        except serial.SerialException as err:
            self.comport = None
            self.set_status('Error: Could not open port: {}'.format(err))
            return
        self.get_software_version()

    def rescan_comports(self):
        """Refresh the port list, keeping the previously selected port selected."""
        all_ports = serial_lp.comports()
        # Remember the device of the currently selected row (-1 = no selection)
        selected_port = None
        if self.current_device_list:
            selected_id = self.widg.listPorts.currentRow()
            if selected_id != -1:
                selected_port = self.current_device_list[selected_id].device
        self.widg.listPorts.clear()
        for comport in all_ports:
            self.widg.listPorts.addItem('{}'.format(comport.description))
            # reselect the port that was selected before
            if (selected_port is not None) and (comport.device == selected_port):
                self.widg.listPorts.setCurrentRow(self.widg.listPorts.count()-1)
        self.current_device_list = all_ports

    # --- device queries ---------------------------------------------------

    def get_software_version(self):
        if self._comport_open():
            self._start_uart_command(ord('*'), [], self.get_software_version_callback)
        else:
            self.set_status('Com port not connected')

    def get_software_version_callback(self, error_code, packet_type, packet_data):
        if (error_code == pkt.DataPacket.PKRES_SUCCESS) and (packet_type == ord('*')):
            # Got the correct response
            self.swver = self._decode_text(packet_data)
            self.widg.lineSwVer.setText(self.swver)
            self.get_serial_number()
            return
        # Somehow it failed. Not connected.
        self.set_status('Error: Could not read software version.')
        # Close before the modal dialog so its nested event loop cannot let
        # a newly opened port be closed by mistake afterwards.
        self._close_comport()
        QMessageBox.critical(self, 'Communication failure','Could not read software version from HMD.', QMessageBox.Ok, QMessageBox.NoButton)

    def get_serial_number(self):
        if self._comport_open():
            self._start_uart_command(ord('%'), [], self.get_serial_number_callback)
        else:
            self.set_status('Com port not connected')

    def get_serial_number_callback(self, error_code, packet_type, packet_data):
        if error_code != pkt.DataPacket.PKRES_SUCCESS:
            # Something went wrong. No reply or a badly formed packet.
            self.set_status('Error: No reply to serial number query.')
            return
        if packet_type == ord('%'):
            self.serialnum = self._decode_text(packet_data)
            self.widg.lineSerial.setText(self.serialnum)
        elif packet_type == ord('E'):
            # Serial number not available, but at least the HMD is communicating
            self.serialnum = None
            self.widg.lineSerial.setText('--no serial number set--')
        self.set_status('Connected!')

    def set_rgb_color(self, newcolor):
        if not self._comport_open():
            self.set_status('Com port not connected')
            return
        if self._tests_running():
            # the test run owns the port; a parallel command would corrupt both
            self.set_status('Cannot set color while tests are running')
            return
        self._start_uart_command(ord('L'), [newcolor.red(), newcolor.green(), newcolor.blue()], self.set_rgb_color_callback)

    def set_rgb_color_callback(self, error_code, packet_type, packet_data):
        if (error_code == pkt.DataPacket.PKRES_SUCCESS) and (packet_type == ord('$')):
            self.set_status('new color set')
        else:
            self.set_status('failed to set color')

    # --- tests --------------------------------------------------------------

    def run_tests(self):
        """Start a testrunner for the checked tests on the open port."""
        if not self._comport_open():
            # Without a port the worker threads would fail and never report back
            self.set_status('Com port not connected')
            return
        if self._tests_running():
            self.set_status('Tests already running')
            return
        self.reset_test_results()
        # Let any in-flight single command finish before the tests use the port
        self._wait_for_uart_thread()
        test_run_dict = {name: chkbox.isChecked()
                         for name, chkbox in zip(testrunner.test_list, self.test_checkboxes)}
        self.testrunner_thread = testrunner(self.comport, test_run_dict, int(2e9)) # 2 second timeout, might need for fan test
        self.testrunner_thread.complete.connect(self.tests_complete_callback)
        self.testrunner_thread.one_test_done.connect(self.each_test_complete_callback)
        self.testrunner_thread.start()

    def each_test_complete_callback(self, test_result):
        """Show a {test_name: True/False/None} result in the matching label."""
        for key, result in test_result.items():
            label = self.test_results_dict[key]
            if result is None:
                self.result_show_not_run(label)
            elif result:
                self.result_show_pass(label)
            else:
                self.result_show_fail(label)

    def tests_complete_callback(self, test_results):
        # Per-test labels are already updated by each_test_complete_callback
        self.set_status("Tests completed")

    def reset_test_results(self):
        for test_result in self.test_results:
            self.result_show_not_run(test_result)
    def result_show_not_run(self, line_edit):
        line_edit.setStyleSheet("QLabel { background-color : rgb(192,192,192) }")
        line_edit.setText('-n/a-')
    def result_show_pass(self, line_edit):
        line_edit.setStyleSheet("QLabel { background-color : rgb(192,255,192) }")
        line_edit.setText('PASS')
    def result_show_fail(self, line_edit):
        line_edit.setStyleSheet("QLabel { background-color : rgb(255,192,192) }")
        line_edit.setText('FAIL')

    def set_status(self, status_str, temporary = False):
        if(not temporary):
            self.widg.statusbar.setText("Status: "+status_str)
        else:
            self.widg.statusbar.setText("Status(T): "+status_str)

if __name__ == "__main__":
    # Build the Qt application, show the tester window and exit with the
    # event loop's return code.
    qt_app = QApplication([])
    tester_window = mainwin()
    tester_window.show()
    exit_code = qt_app.exec()
    sys.exit(exit_code)
