import hid
import time
from typing import NamedTuple
import struct
from enum import Enum, auto, IntEnum

# NOTE(review): not referenced in this chunk; units unconfirmed (ms? s?).
DEFAULT_HID_TIMEOUT = 10

# USB identifiers used by find_beyond_device() to locate the headset.
BIGSCREEN_VID = 0x35BD
BEYOND_PID = 0x0101

# The board's analog sensing uses this reference voltage and full-scale count.
ADC_REF_VOLTS = 3.3
ADC_MAX_COUNTS = 4095


def _volts_to_adc_counts(volts: float) -> float:
    """Convert a pin voltage to the equivalent raw ADC reading (in counts)."""
    return volts / ADC_REF_VOLTS * ADC_MAX_COUNTS


# Inactive: the pin voltage stays close to zero when the CC wire is not
# connected to this pin (threshold: 0.05 V, about 62 counts).
CC_INACTIVE_THRESH = _volts_to_adc_counts(0.05)
# Wrong way: the pin voltage is too high when the linkbox end of the cable
# is plugged in flipped (threshold: 3.0 V, about 3723 counts).
CC_WRONG_WAY_THRESH = _volts_to_adc_counts(3.0)

class USBC_Flip_State(Enum):
    """USB-C cable orientation as classified by ``get_usbc_flip``.

    "Unflipped"/"Flipped" are this module's own names for which CC line
    carries the higher reading; they are not USB-C specification terms.
    """

    Unflipped = 1  # CC1 reads higher than CC2
    Flipped = 2    # CC2 reads higher than CC1
    Invalid = 3    # orientation could not be determined

class Hid_Message(IntEnum):
    """Leading-byte message codes used in the Beyond HID protocol.

    Each value is the ASCII code of a mnemonic character.
    """

    FAN_SPEED = 0x46     # 'F'
    REPORT_RATE = 0x52   # 'R' - sent in a feature report to change the data rate
    DATA_MESSAGE = 0x23  # '#' - first byte of a periodic data report
    SUCCESS = 0x24       # '$'
    ERROR = 0x45         # 'E'

class BeyondDataPacket(NamedTuple):
    """One decoded periodic data report from the Beyond.

    Field order matches the order in which get_periodic_data_packet
    unpacks the report: four integers followed by three floats.
    """

    fan_speed: int      # raw value; units not established here
    prox_distance: int  # raw proximity reading; units not established here
    cc1_val: int        # CC1 pin reading in ADC counts
    cc2_val: int        # CC2 pin reading in ADC counts
    board_temp: float   # degrees Celsius (converted from Kelvin on decode)
    disp1_temp: float   # display 1 temperature, degrees Celsius
    disp2_temp: float   # display 2 temperature, degrees Celsius

def find_beyond_device():
    """Return the HID path of the first attached Beyond headset.

    A device matches when both its vendor ID and its product ID match
    (BIGSCREEN_VID / BEYOND_PID). Returns ``None`` when nothing matches.
    """
    matching_paths = (
        info['path']
        for info in hid.enumerate()
        if info['vendor_id'] == BIGSCREEN_VID and info['product_id'] == BEYOND_PID
    )
    return next(matching_paths, None)

def wait_for_response(beyond: hid.device, message_types: list[int], timeout_ms: int = 100) -> bytes:
    """Read HID reports until one starts with a wanted message type.

    Args:
        beyond: HID device handle to read from.
        message_types: Accepted values for the report's first byte
            (for example ``Hid_Message`` members).
        timeout_ms: Overall time budget in milliseconds.

    Returns:
        The first matching report as ``bytes``. Returns ``b''`` when
        ``beyond`` is falsy or no matching report arrives before the
        deadline. Reports whose first byte is not in ``message_types``
        are discarded.
    """
    # Reject a missing handle (e.g. None); this does not check that it is open.
    if not beyond:
        return b''

    deadline_ns = time.monotonic_ns() + timeout_ms * 1_000_000
    while True:
        remaining_ns = deadline_ns - time.monotonic_ns()
        if remaining_ns <= 0:
            return b''
        # Bound each read by the remaining budget. A read without a timeout
        # blocks until a report arrives, so a silent device would never let
        # the deadline be checked. hidapi treats a timeout of 0 as "block",
        # so always pass at least 1 ms.
        report = beyond.read(65, max(1, remaining_ns // 1_000_000))
        if report and report[0] in message_types:
            return bytes(report)

def get_periodic_data_packet(beyond: hid.device, timeout_ms: int = 1200) -> BeyondDataPacket | None:
    """Wait for one periodic data report from the Beyond and decode it.

    The Beyond's periodic data rate at startup is once per second, which is
    why the default ``timeout_ms`` is a little over one second.

    Report layout as decoded here (byte offsets into the raw report):
        [0]       Hid_Message.DATA_MESSAGE
        [1]       not decoded
        [2:10]    four big-endian uint16: fan_speed, prox_distance,
                  cc1_val, cc2_val
        [10:22]   three little-endian float32: board temperature (Kelvin),
                  display 1 temperature, display 2 temperature (Celsius)

    Args:
        beyond: Open HID device handle.
        timeout_ms: How long to wait for a data report, in milliseconds.

    Returns:
        The decoded packet with ``board_temp`` converted to Celsius.
        Returns ``None`` if no data report arrived in time or the report
        is too short to decode.
    """
    rawdata = wait_for_response(beyond, [Hid_Message.DATA_MESSAGE], timeout_ms=timeout_ms)

    # 2 header bytes + 4 * uint16 + 3 * float32. Reject truncated reports
    # here instead of letting struct raise on a short buffer.
    min_report_len = 2 + struct.calcsize('>HHHH') + struct.calcsize('<fff')
    if len(rawdata) < min_report_len or rawdata[0] != Hid_Message.DATA_MESSAGE:
        return None

    # NOTE(review): the integers are decoded big-endian but the floats
    # little-endian; this mirrors the existing decoding. Confirm against the firmware.
    intdatavals = struct.unpack_from('>HHHH', rawdata, 2)
    board_temp_k, disp1_temp, disp2_temp = struct.unpack_from('<fff', rawdata, 10)

    # Board temp arrives in Kelvin, but the OLED temps are already in Celsius.
    board_temp = board_temp_k - 273.15

    return BeyondDataPacket(*intdatavals, board_temp, disp1_temp, disp2_temp)

def _classify_cc_levels(cc1val: float, cc2val: float) -> USBC_Flip_State:
    """Map averaged CC1/CC2 ADC readings to a USB-C flip state."""
    # Invalid: one pin reads at or above the wrong-way threshold while the
    # other is inactive (the linkbox end is plugged in flipped).
    if cc1val >= CC_WRONG_WAY_THRESH and cc2val <= CC_INACTIVE_THRESH:
        return USBC_Flip_State.Invalid
    if cc2val >= CC_WRONG_WAY_THRESH and cc1val <= CC_INACTIVE_THRESH:
        return USBC_Flip_State.Invalid

    # NOTE(review): two inactive pins (both <= CC_INACTIVE_THRESH) that differ
    # slightly still yield Unflipped/Flipped below. Confirm whether that
    # case should be Invalid.

    # Otherwise the active line is whichever reads the higher voltage.
    if cc1val > cc2val:
        return USBC_Flip_State.Unflipped
    if cc2val > cc1val:
        return USBC_Flip_State.Flipped

    # Identical readings: orientation cannot be determined.
    return USBC_Flip_State.Invalid


def get_usbc_flip(num_averages: int = 1) -> USBC_Flip_State:
    """Detect the USB-C cable orientation from the Beyond's CC readings.

    Opens the first attached Beyond HID device and averages the CC1/CC2
    readings from ``num_averages`` periodic data packets. The device is
    always closed before returning, and the averages are then classified
    by ``_classify_cc_levels``.

    Args:
        num_averages: Number of data packets to average. Values below 1
            are treated as 1.

    Returns:
        ``Unflipped`` when CC1 reads higher and ``Flipped`` when CC2 reads
        higher. ``Invalid`` when no device is found, the device cannot be
        opened or read, a packet times out, or the readings are implausible.
    """
    num_averages = max(1, num_averages)

    device_path = find_beyond_device()
    if not device_path:
        return USBC_Flip_State.Invalid

    # Best-effort probe: any failure maps to Invalid. Catch Exception rather
    # than using a bare except so Ctrl-C / SystemExit still propagate.
    try:
        bynd = hid.device()
        bynd.open_path(device_path)
    except Exception:
        return USBC_Flip_State.Invalid

    try:
        # If it is taking too long, try upping the Beyond's periodic data rate
        # uncomment the following 2 lines to change it to 50ms:
        # bynd.send_feature_report(bytes([0, Hid_Message.REPORT_RATE, 0, 50]))
        # hid_reply = wait_for_response(bynd, [Hid_Message.SUCCESS])

        # Perform data averaging.
        # The data should be pretty stable though, so this might be unnecessary.
        cc1_total = 0
        cc2_total = 0
        for _ in range(num_averages):
            pkt = get_periodic_data_packet(bynd)
            if pkt is None:
                return USBC_Flip_State.Invalid
            cc1_total += pkt.cc1_val
            cc2_total += pkt.cc2_val
    except Exception:
        return USBC_Flip_State.Invalid
    finally:
        # Single close point for every exit path above.
        bynd.close()

    return _classify_cc_levels(cc1_total / num_averages, cc2_total / num_averages)

if __name__ == '__main__':
    # Probe once (single sample) and report the detected orientation.
    flip_state = get_usbc_flip()
    print(flip_state)