import collections
import ctypes
import enum
import os
import queue
import statistics
import struct
import sys
import time
from ctypes.wintypes import HANDLE, UINT, WPARAM, LPARAM, DWORD, POINT

from PySide6.QtWidgets import QApplication, QMainWindow, QMessageBox, QFileDialog
from PySide6.QtCore import QTimer, QThread, Signal, Qt
from PySide6.QtCharts import QLineSeries, QChart

import hid
from hid import HIDException

import prox_config_ui
# Windows message id for device-change notifications (WinUser.h); the main
# window watches for it to retry the HMD connection on hot-plug.
WM_DEVICECHANGE = 0x219
class MSG(ctypes.Structure):
    '''
    WinUser MSG - Contains message information from a thread's message queue.
    typedef struct tagMSG {
        HWND   hwnd;
        UINT   message;
        WPARAM wParam;
        LPARAM lParam;
        DWORD  time;
        POINT  pt;
        DWORD  lPrivate;
    } MSG, *PMSG, *NPMSG, *LPMSG;

    The native structure uses the compiler's default (natural) alignment, so
    no ``_pack_`` is set here. Packing to 1 byte would shift ``wparam`` and
    every later field on 64-bit Windows (``wparam`` would sit at offset 12
    instead of 16), making those fields read garbage.

    NOTE(review): ``lPrivate`` follows the MSDN listing; confirm it is present
    in the SDK header for the target before reading it. Only ``message`` is
    read by this tool.
    '''
    _fields_ = [('hwnd', HANDLE),
                ('message', UINT),
                ('wparam', WPARAM),
                ('lparam', LPARAM),
                ('time', DWORD),
                ('pt', POINT),
                ('lPrivate', DWORD)]

# USB identifiers passed to hid.Device(vid=..., pid=...) to open the HMD.
BIGSCREEN_VID = 0x35bd  # vendor id
BEYOND_PID = 0x0101     # product id

class hidreader(QThread):
    """Background thread that services one connected HID device.

    The loop in ``run()`` reads input reports with a short timeout and
    forwards each non-empty report to the UI thread via ``data_received``.
    Feature reports handed to ``write()`` are queued and sent from this
    thread (at most one per loop iteration), so the device is only used from
    one thread. A HID/OS error emits ``hid_disconnected`` and ends the loop.
    The device handle is closed when the thread finishes.
    """
    HID_INPUT_REPORT_SIZE = 64
    HID_FEATURE_REPORT_SIZE = 64
    HID_TIMEOUT_MS = 10

    data_received = Signal(bytes)
    hid_disconnected = Signal()

    def __init__(self, hid_device: hid.Device):
        super().__init__()
        self.dev = hid_device
        self.connected = True
        # Initialised here rather than in run(): exit_now()/write() calls made
        # before the thread has actually started must not be discarded.
        self.time_to_quit = False
        self._pending_writes = queue.SimpleQueue()

    def exit_now(self):
        """Ask the loop in run() to stop after its current iteration."""
        self.time_to_quit = True

    def write(self, wbytes):
        """Queue ``wbytes`` to be sent to the device as a feature report.

        May be called from the UI thread; successive calls are all sent, in
        order, instead of overwriting each other.
        """
        self._pending_writes.put(wbytes)

    def run(self):
        """Poll the device until exit_now() is called or the device fails."""
        try:
            while not self.time_to_quit:
                try:
                    result = self.dev.read(self.HID_INPUT_REPORT_SIZE,
                                           self.HID_TIMEOUT_MS)
                except (HIDException, OSError):
                    # Do not fall through: `result` is unset/stale here.
                    self._report_disconnect()
                    break
                if result:
                    # We got something, send it to the UI thread
                    self.data_received.emit(result)
                if not self._send_pending_write():
                    break
        finally:
            self._close_device()

    def _send_pending_write(self):
        """Send at most one queued feature report; return False on failure."""
        try:
            wbytes = self._pending_writes.get_nowait()
        except queue.Empty:
            return True
        try:
            self.dev.send_feature_report(wbytes)
        except (HIDException, OSError):
            self._report_disconnect()
            return False
        return True

    def _report_disconnect(self):
        """Mark the device as gone, stop the loop and notify the UI thread."""
        self.connected = False
        self.time_to_quit = True
        self.hid_disconnected.emit()

    def _close_device(self):
        """Best-effort release of the HID handle once polling has ended."""
        try:
            self.dev.close()
        except (HIDException, OSError):
            # The device may already be unplugged; nothing more to release.
            pass


class mainwin(QMainWindow):
    """Main window of the proximity sensor configuration utility.

    Connects to the HMD over HID (retrying whenever Windows reports a device
    change), plots the streamed proximity value, shows the mean and standard
    deviation of the most recent samples, optionally appends samples to a CSV
    log, and sends the proximity-sensor configuration chosen in the UI.
    """
    PROX_CHART_LENGTH = 100
    MAX_PROX_MOVING_AVERAGE_LENGTH = 100
    # Combo box values
    PGAINs = ['1x','2x','4x','8x']
    PPULSE_LENs = ['1us','2us','4us','8us','12us','16us','24us','32us']
    PMAVGs = ['None','2x','4x','8x']
    PROX_AVGs = ['None','2x','4x','8x','16x','32x','64x','128x']
    # Conversions used to label the PRATE / PWTIME spin boxes
    PRATE_STEP_US = 88
    PWTIME_STEP_MS = 2.78
    PWLONG_FACTOR = 12

    def __init__(self):
        super().__init__()
        self.load_ui()
        # Handler state is taken from the widgets before any signal is
        # connected, so what the UI shows and what the handlers use agree.
        self.beyond = None
        self.logging = False
        self.averaging_paused = False
        self.pwlong = self.widg.chkPWLONG.isChecked()
        self.prox_averaging_length = self._clamp_averaging_length(
            self.widg.spinAveragingLength.value())
        # Rolling window of the most recent proximity samples
        self.averages = collections.deque(maxlen=self.prox_averaging_length)
        self.populate_combo_boxes()
        self.create_prox_chart()
        self.connect_buttons_and_actions()
        self.attempt_connect()

    def load_ui(self):
        """Build the generated UI into this window and set its title."""
        self.widg = prox_config_ui.Ui_MainWindow()
        self.widg.setupUi(self)
        self.setWindowTitle("Proximity Sensor Configuration Utility")

    def create_prox_chart(self):
        """Create the rolling proximity plot, initially PROX_CHART_LENGTH zeros."""
        self.lineseries = QLineSeries()
        self.proxdata = [0] * self.PROX_CHART_LENGTH
        for x, y in enumerate(self.proxdata):
            self.lineseries.append(x, y)
        self.chart = QChart()
        self.chart.legend().hide()
        self.chart.addSeries(self.lineseries)
        self.chart.createDefaultAxes()
        y_axis = self.chart.axes(orientation=Qt.Vertical)[0]
        y_axis.setMax(17000)
        y_axis.setMin(0)
        self.widg.proxChart.setChart(self.chart)

    def populate_combo_boxes(self):
        """Fill the register combo boxes and select the default gain/pulse length."""
        self.widg.comboPGAIN.addItems(self.PGAINs)
        self.widg.comboPPULSELEN.addItems(self.PPULSE_LENs)
        self.widg.comboPMAVG.addItems(self.PMAVGs)
        self.widg.comboPROXAVG.addItems(self.PROX_AVGs)
        # set defaults
        self.widg.comboPGAIN.setCurrentIndex(2)
        self.widg.comboPPULSELEN.setCurrentIndex(3)

    def connect_buttons_and_actions(self):
        """Wire widget signals to their handlers."""
        self.widg.chkPWLONG.clicked.connect(self.pwlong_clicked)
        self.widg.spinPRATE.valueChanged.connect(self.prate_changed)
        self.widg.spinPWTIME.valueChanged.connect(self.pwtime_changed)
        self.widg.btnSend.clicked.connect(self.send_to_hmd)
        self.widg.btnPause.clicked.connect(self.pause_averaging)
        self.widg.spinAveragingLength.valueChanged.connect(self.averaging_length_changed)
        self.widg.btnLogging.clicked.connect(self.start_stop_logging)

    def _clamp_averaging_length(self, length):
        """Limit an averaging window to 1..MAX_PROX_MOVING_AVERAGE_LENGTH samples."""
        return max(1, min(int(length), self.MAX_PROX_MOVING_AVERAGE_LENGTH))

    def averaging_length_changed(self):
        """Restart averaging with the window length now shown in the spin box."""
        self.prox_averaging_length = self._clamp_averaging_length(
            self.widg.spinAveragingLength.value())
        self.averages = collections.deque(maxlen=self.prox_averaging_length)

    def _update_pwtime_label(self, pwtime=None):
        """Show the PWTIME wait time in ms (or s when >= 1000 ms)."""
        if pwtime is None:
            pwtime = self.widg.spinPWTIME.value()
        wait_ms = (pwtime + 1) * self.PWTIME_STEP_MS
        if self.pwlong:
            wait_ms = wait_ms * self.PWLONG_FACTOR
        if wait_ms < 1000:
            text = '{:.2f} ms'.format(wait_ms)
        else:
            text = '{:.2f} s'.format(wait_ms / 1000.0)
        self.widg.lblPWTIME.setText(text)

    def pwlong_clicked(self, is_clicked:bool):
        """Track the PWLONG checkbox and refresh the wait-time label."""
        self.pwlong = is_clicked
        self._update_pwtime_label()

    def prate_changed(self, new_val: int):
        """Show the time represented by the PRATE spin box value."""
        prox_duration = (new_val + 1) * self.PRATE_STEP_US
        if prox_duration < 1000:
            self.widg.lblPRATE.setText('{} us'.format(prox_duration))
        else:
            self.widg.lblPRATE.setText('{:.2f} ms'.format(float(prox_duration)/1000.0))

    def pwtime_changed(self, new_val: int):
        """Refresh the wait-time label for a new PWTIME spin box value."""
        self._update_pwtime_label(new_val)

    def _stop_hid_thread(self):
        """Stop the HID reader thread, if one exists, and wait for it to end."""
        thread = getattr(self, 'hmd_thread', None)
        if thread is None:
            return
        thread.exit_now()
        # Re-request the exit on every poll: a reader whose run() had not yet
        # started could otherwise reset its quit flag and keep running.
        while not thread.wait(100):
            thread.exit_now()

    def attempt_connect(self):
        """Open the HMD over HID and start the reader thread, if not connected."""
        if self.beyond is not None:
            return
        try:
            device = hid.Device(vid = BIGSCREEN_VID, pid = BEYOND_PID)
        except hid.HIDException:
            # Could not connect, probably means not plugged in
            self.statusBar().showMessage('HMD not connected')
            return
        # A previous reader (from before a disconnect) must be finished
        # before its QThread object is replaced and garbage collected.
        self._stop_hid_thread()
        self.beyond = device
        self.statusBar().showMessage('Connected to HMD')
        self.hmd_thread = hidreader(self.beyond)
        self.hmd_thread.data_received.connect(self.new_hid_data)
        self.hmd_thread.hid_disconnected.connect(self.hid_detached)
        self.hmd_thread.start()
        # Start a little one shot timer before sending data to the HMD
        # This allows the hidthread to start properly
        self.startup_timer = QTimer()
        self.startup_timer.setSingleShot(True)
        self.startup_timer.timeout.connect(self.set_hid_rate)
        self.startup_timer.start(100)

    def pause_averaging(self):
        """Toggle whether the average/stdev fields are refreshed."""
        self.averaging_paused = not self.averaging_paused
        if self.averaging_paused:
            self.widg.btnPause.setText('Resume Averaging')
        else:
            self.widg.btnPause.setText('Pause Averaging')

    def _stop_logging(self):
        """Close the log file (if open) and reset the logging button."""
        handle = getattr(self, 'log_file_handle', None)
        if handle is not None and not handle.closed:
            handle.close()
        self.logging = False
        self.widg.btnLogging.setText("Start Logging")

    def start_stop_logging(self):
        """Toggle logging of proximity samples to a user-chosen CSV file."""
        if self.logging:
            self._stop_logging()
            return
        # Popup for file selection
        file_name, _selected_filter = QFileDialog.getSaveFileName(self, "Select Log File", os.getcwd(), filter=r"CSV files (*.csv)")
        if not file_name:
            return  # dialog cancelled
        try:
            handle = open(file_name, 'a')
        except OSError as err:
            # Only enable logging once the file is really open, otherwise
            # every incoming packet would fail on the missing handle.
            QMessageBox.warning(self, 'Logging error', 'Could not open log file:\n{}'.format(err),
                QMessageBox.Ok, QMessageBox.NoButton)
            return
        self.log_file_name = file_name
        self.log_file_handle = handle
        self.log_file_handle.write(time.ctime() + ", \n")
        self.logging = True
        self.widg.btnLogging.setText("Stop Logging")

    def _update_average_display(self):
        """Show the mean and sample standard deviation of the buffered samples."""
        samples = list(self.averages)
        if not samples:
            return
        self.widg.leAvgVal.setText(str(int(statistics.mean(samples))))
        # statistics.stdev needs at least two samples; clear the field until then
        if len(samples) > 1:
            self.widg.leStdDev.setText(str(statistics.stdev(samples)))
        else:
            self.widg.leStdDev.setText('')

    def new_hid_data(self, newbytes:bytes):
        """Handle an input report from the reader thread.

        Only streaming packets (header '#') are used; from those the
        proximity value is unpacked, plotted, averaged and optionally logged.
        Packet format:
        Byte -    Description
        0         header ('#')
        1         length of data (should be 20)
        2,3       Fan speed (MSB first for all of these uint16 values)
        4,5       Proximity distance
        6,7       CC1 ADC value
        8,9       CC2 ADC value
        10-13     board temperature (float, little-endian)
        14-17     left oled temperature (float, little-endian)
        18-21     right oled temperature (float, little-endian)
        """
        # Need at least the header and the proximity field at offsets 4-5
        if len(newbytes) < 6 or newbytes[0] != ord('#'):
            return
        prox_value = struct.unpack_from('>H', newbytes, 4)[0]
        self.proxdata = self.proxdata[1:] + [prox_value]
        self.lineseries.removePoints(0, len(self.proxdata))
        for x, y in enumerate(self.proxdata):
            self.lineseries.append(x, y)

        # The bounded deque keeps only the last prox_averaging_length
        # samples, so the statistics cover exactly the samples received so
        # far (no zero padding before the window fills). Samples are still
        # collected while the display is paused.
        self.averages.append(prox_value)
        if not self.averaging_paused:
            self._update_average_display()
        if self.logging:
            self.log_file_handle.write(str(prox_value)+", \n")

    def hid_detached(self):
        """Called when the reader thread reports that the HMD went away."""
        self.statusBar().showMessage('HMD not connected')
        self.beyond = None

    def set_hid_rate(self):
        """Ask the HMD to stream data every 100 ms (runs after the startup timer)."""
        # Check if still connected. The disconnect callback could have occurred
        if self.beyond is not None:
            self.hmd_thread.write(bytes([0, ord('R')])+struct.pack('>H', 100)) # 100ms rate

    def send_to_hmd(self):
        """Pack the proximity-sensor settings from the UI and send them to the HMD.

        Byte -    Description
        1         PWEN (bool) wait time enable
        2         PRATE (byte) time between each group of pulses when hardware averaging enabled
        3         PWLONG (bool) wait time is increased by 12x
        4         PGAIN (2 bits) photodiode amplifier gain setting, allowed 0 to 3 (maps to 1x, 2x, 4x, or 8x)
        5         PPULSE (6 bits) maximum number of pulses in one measurement
        6         PPULSELEN (3 bits) proximity pulse length
        7         PLDRIVE (4 bits) laser drive current, allowed values 0 to 8, which means 2mA to 10mA
        8         PWTIME (byte) wait time between measurement cycles
        9         PDSELECT (2 bits) choice of near, far, or both photodiodes
        10        PMAVG (2 bits) moving average filter, applied after each measurement cycle
        11        PROX_AVG (3 bits) hardware averaging, uses multiple groups of pulses in each measurement cycle
        """
        if self.beyond is None:
            # No HMD connected, can't send. Show warning
            QMessageBox.warning(self, 'Not connected', 'HMD is not connected. Cannot send proximity configuration.',
                QMessageBox.Ok, QMessageBox.NoButton)
            return
        if self.widg.rbNear.isChecked():
            pd_select = 2
        elif self.widg.rbFar.isChecked():
            pd_select = 1
        else:
            pd_select = 3  # neither near nor far alone: both photodiodes
        newbytes = struct.pack('?B?BBBBBBBB',
            self.widg.chkPWEN.isChecked(),
            self.widg.spinPRATE.value(),
            self.widg.chkPWLONG.isChecked(),
            self.widg.comboPGAIN.currentIndex(),
            self.widg.spinPPULSE.value() - 1, # spinbox is 1 to 64, value sent to HMD is 0 to 63
            self.widg.comboPPULSELEN.currentIndex(),
            self.widg.spinPLDRIVE.value() - 2, # spinbox is 2 to 10, value sent to HMD is 0 to 8
            self.widg.spinPWTIME.value(),
            pd_select,
            self.widg.comboPMAVG.currentIndex(),
            self.widg.comboPROXAVG.currentIndex())
        self.hmd_thread.write(bytes([0, ord('M')]) + newbytes)

    @staticmethod
    def _native_message_address(message):
        """Return the address held by a shiboken VoidPtr, or 0 if unknown.

        ``int()`` is tried first; parsing the repr is kept as a fallback.
        Example text: 'shiboken6.shiboken6.VoidPtr(Address 0x00000061BCFEACB0, Size 0, isWritable False)'
        """
        try:
            return int(message)
        except (TypeError, ValueError):
            pass
        text = repr(message)
        for prefix in ('shiboken6.shiboken6.VoidPtr(', 'shiboken6.Shiboken.VoidPtr('):
            _, found, rest = text.partition(prefix)
            if found:
                # First field is the address, optionally preceded by 'Address '
                field = rest.split(',', 1)[0].split()
                try:
                    return int(field[-1], 16)
                except (IndexError, ValueError):
                    return 0
        return 0

    # Overrides the window's native event handler, which can capture device
    # change notifications from the OS. Used to retry the HMD connection
    # whenever a device is added or removed.
    def nativeEvent(self, eventType, message):
        # The MSG layout only applies to Windows native events
        if sys.platform == 'win32':
            address = self._native_message_address(message)
            if address:
                msg_struct = MSG.from_address(address)
                if msg_struct.message == WM_DEVICECHANGE:
                    self.attempt_connect()
        return super().nativeEvent(eventType, message)

    # Override close event to make sure the HID thread is safely stopped
    def closeEvent(self, event):
        # Blocking wait instead of a busy loop; see _stop_hid_thread
        self._stop_hid_thread()
        self._stop_logging()
        return super().closeEvent(event)

if __name__ == "__main__":
    # Create the Qt application, show the main window, then run the event
    # loop and pass its exit code to the interpreter.
    app = QApplication([])
    main_window = mainwin()
    main_window.show()
    exit_code = app.exec()
    sys.exit(exit_code)