import hid
import time
import typing
from enum import IntEnum

# USB identifiers used to pick the headset out of the enumerated HID devices.
BIGSCREEN_VID: typing.Final[int] = 0x35BD  # vendor ID matched against 'vendor_id'
BEYOND_PID: typing.Final[int] = 0x0101  # product ID matched against 'product_id'

# NOTE(review): not referenced by the query functions shown here (they default to
# timeout_ms=1000) and its unit is not evident from this file — confirm before use.
DEFAULT_HID_TIMEOUT: typing.Final[int] = 10

class Hid_Message(IntEnum):
    """One-byte message identifiers exchanged with the headset.

    A request is sent as a feature report whose first payload byte is one of
    these values; the matching reply report starts with the same byte, or
    with ``ERROR`` when the headset rejects the request.
    """

    SW_VER = ord("*")           # 0x2A: software/firmware version string
    SERIAL_NUM = ord("%")       # 0x25: board serial number
    HMD_SERIAL_NUM = ord("&")   # 0x26: headset serial number
    OLED_SERIAL_NUM = ord("^")  # 0x5E: display serial; followed by a panel index byte
    SUCCESS = ord("$")          # 0x24: not used by the queries in this file
    ERROR = ord("E")            # 0x45: reply type that signals a failed request

def find_beyond_device():
    """Return the hidapi path of the first attached Bigscreen Beyond, or None.

    Every HID device reported by ``hid.enumerate()`` is checked against the
    Bigscreen vendor ID and the Beyond product ID. When several entries match,
    the first one in enumeration order is used.
    """
    matching = [
        info
        for info in hid.enumerate()
        if info['vendor_id'] == BIGSCREEN_VID and info['product_id'] == BEYOND_PID
    ]
    # NOTE(review): assumes the first matching entry is the interface that
    # answers the feature-report commands — confirm if the headset exposes several.
    return matching[0]['path'] if matching else None

def wait_for_response(beyond: "hid.device", message_types: list[int], timeout_ms: int = 1000) -> bytes:
    """Read input reports until one of the wanted type arrives or time runs out.

    Args:
        beyond: An opened hidapi device handle.
        message_types: Accepted values for the first byte of a report.
        timeout_ms: Total time budget in milliseconds; ``<= 0`` returns at once.

    Returns:
        The first matching report as ``bytes``, or ``b''`` if none arrived
        before the deadline. Reports of other types are discarded.
    """
    deadline_ns = time.monotonic_ns() + timeout_ms * 1_000_000

    while True:
        remaining_ns = deadline_ns - time.monotonic_ns()
        if remaining_ns <= 0:
            return b''

        # hidapi's read() treats timeout_ms=0 (its default) as "block until data",
        # which would ignore our deadline entirely. Always pass a positive timeout
        # rounded up to whole milliseconds so each read ends by the deadline.
        remaining_ms = (remaining_ns + 999_999) // 1_000_000
        report = beyond.read(65, timeout_ms=remaining_ms)

        if report and report[0] in message_types:
            # hidapi returns a list of ints; convert to match the declared return type.
            return bytes(report)

def _query_string(command: int, timeout_ms: int, args: tuple[int, ...] = ()) -> str:
    """Send one command to the headset and decode its ASCII string reply.

    Opens the first Bigscreen Beyond found and sends ``command``, followed by
    any ``args`` bytes, as a feature report. It then waits for a reply of the
    same type or ``Hid_Message.ERROR``. The device handle is always closed again.

    Args:
        command: Command byte; a successful reply starts with the same byte.
        timeout_ms: How long to wait for the reply, in milliseconds.
        args: Extra bytes appended after the command byte.

    Returns:
        One of three strings:
        - the decoded reply with trailing NUL padding removed;
        - ``'error'`` when the device answered ERROR or nothing in time;
        - ``'error: <reason>'`` when the device is missing or any exception
          was raised.
    """
    try:
        device_path = find_beyond_device()
        if not device_path:
            return 'error: device not found'

        beyond = hid.device()
        beyond.open_path(device_path)
        try:
            # The leading 0 is the report-ID byte hidapi expects first in a
            # feature report buffer.
            beyond.send_feature_report(bytes([0, command, *args]))
            hid_reply = wait_for_response(beyond, [command, Hid_Message.ERROR], timeout_ms)
        finally:
            # Release the handle even when sending or reading raises.
            beyond.close()

        if hid_reply and hid_reply[0] == command:
            return bytes(hid_reply[1:]).rstrip(b'\x00').decode('ascii')
        return 'error'
    except Exception as e:
        # Public API reports every failure as a string instead of raising.
        return f'error: {e}'

def get_software_version(timeout_ms: int = 1000) -> str:
    """Return the headset's software version string, or an ``'error…'`` string."""
    return _query_string(Hid_Message.SW_VER, timeout_ms)

def get_board_serial(timeout_ms: int = 1000) -> str:
    """Return the board serial number, or an ``'error…'`` string."""
    return _query_string(Hid_Message.SERIAL_NUM, timeout_ms)

def get_hmd_serial(timeout_ms: int = 1000) -> str:
    """Return the HMD serial number, or an ``'error…'`` string."""
    return _query_string(Hid_Message.HMD_SERIAL_NUM, timeout_ms)

def get_left_oled_serial(timeout_ms: int = 1000) -> str:
    """Return the left OLED panel's serial number, or an ``'error…'`` string."""
    # The argument byte selects the panel: 0 for left (1 is used for right).
    return _query_string(Hid_Message.OLED_SERIAL_NUM, timeout_ms, (0,))

def get_right_oled_serial(timeout_ms: int = 1000) -> str:
    """Return the right OLED panel's serial number, or an ``'error…'`` string."""
    # The argument byte selects the panel: 1 for right (0 is used for left).
    return _query_string(Hid_Message.OLED_SERIAL_NUM, timeout_ms, (1,))

# Manual smoke test: query each identifier from the attached headset and print it.
if __name__ == "__main__":
    _smoke_queries = (
        ("Software Version:", get_software_version),
        ("Board Serial:", get_board_serial),
        ("HMD Serial:", get_hmd_serial),
        ("Left OLED Serial:", get_left_oled_serial),
        ("Right OLED Serial:", get_right_oled_serial),
    )
    for _label, _query in _smoke_queries:
        print(_label, _query())