'''
config_lib.py

Routines for reading and editing the configuration memory on the Beyond's ATSAMG55 microcontroller.

There is a special 512B Flash page in the microcontroller, not for regular code storage, known as
the "user signature page". For the Beyond, this page is used for trimming and configuration constants.

The layout of the page is a list of TLV (Tag, Length, Value) entries. There is no fixed size for the
entries, and no fixed order either. The Tag (one byte) tells us what config value is stored in this entry,
and the Length (one byte) is the number of bytes in the data section. The variable-length data follows the
Length byte. After the data bytes there is a single CRC-8 byte computed over the ENTIRE entry, i.e. the Tag,
Length, and Value bytes, using generator polynomial 0x07, initial value 0xFF, MSB first, and no final XOR.

Example:
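   |Tag--|Length--|Value (8 bytes)--------------------------------|CRC8|
    0x01, 0x08,    0x48, 0x69, 0x54, 0x68, 0x65, 0x72, 0x65, 0x21, 0x9D
        # This is an 8-byte serial number: "HiThere!" #
        # NOTE(review): a hand calculation of crc8() over these ten bytes gives 0x11,
        # not 0x9D -- verify this example CRC against the firmware.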

Current valid tags are:
SigTag_Serial = 0x01            # Variable length, ASCII characters. Assigned to PCBA during factory test.
SigTag_RGB_Color = 0x02         # 3 bytes, Red, Green, and Blue color
SigTag_Fan_Speed = 0x03         # one byte, fan speed in percent
SigTag_Prox_Disable = 0x04      # one byte, boolean. Disables proximity sensor screen dimming for debugging
SigTag_Linkbox_v1 = 0x05        # one byte, boolean. Inverts the DisplayPort HotPlugDetect signal for linkbox v1
SigTag_Prox_Cal = 0x06          # two bytes, u16 value of proximity sensor signal with nothing present
SigTag_FATP_Mode = 0x07         # one byte, boolean. Forces FATP mode EDID at startup.
SigTag_HMD_Serial = 0x08        # Variable length, ASCII characters. Serial number of whole headset assembly
SigTag_Tracking_Serial = 0x09   # Variable length, ASCII characters. Serial number of the tracking flex in the HMD
SigTag_Display_Brightness = 0x0A    # two bytes, u16 value of default display brightness. 0 to 0x03FF
SigTag_Prox_Threshold = 0x0B    # two bytes, u16 value of prox sensor threshold
SigTag_Prox_Hysteresis = 0x0C   # two bytes, u16 value of prox sensor threshold hysteresis (eg turns on at thresh+hyst, turns off at thresh-hyst)
SigTag_EDID_Switch = 0x0D       # one byte: 0 = both 75Hz and 90Hz modes, 1 = 90Hz only, 2 = 75Hz only
SigTag_Prox_User_Trim = 0x0E    # two bytes, i16 proximity threshold adjustment (NOTE THIS IS A SIGNED VALUE!). SigTag.User_Prox_Trim in this module.
SigTag_VXR_Sleep_Enable = 0x0F  # one byte, boolean. False = no sleep, True = sleep. Default (not loaded) is no sleep (false).
SigTag_Invalid = 0xFF           # Any unprogrammed byte will be 0xFF. This marks the end of the valid tags.

'''

import enum
import hid # NOTE: Uses hidapi, not hid. https://pypi.org/project/hidapi/

import hid_lib

# Size of the ATSAMG55 user signature Flash page, in bytes.
USER_SIG_LENGTH = 0x200

# CRC-8 parameters applied to every TLV entry: seed value and generator polynomial.
CRC_INIT_VAL = 0xFF
CRC_POLY = 0x07

# Chunk size (bytes) used when moving the page over HID.
# 0x200 / 0x20 = 16 chunks cover the whole page.
READSIG_LENGTH = 0x20

@enum.unique  # a duplicated tag value would otherwise silently become an alias
class SigTag(enum.IntEnum):
    '''
    Tag byte values for the TLV entries stored in the user signature page.
    '''
    Invalid = 0xFF              # unprogrammed Flash byte; marks the end of the valid entries
    Serial = 0x01               # variable length ASCII, PCBA serial number
    RGB_Color = 0x02            # 3 bytes: red, green, blue
    Fan_Speed = 0x03            # 1 byte, fan speed in percent
    Prox_Disable = 0x04         # 1 byte boolean, disables proximity-sensor dimming
    Linkbox_v1 = 0x05           # 1 byte boolean, inverts DisplayPort HotPlugDetect for linkbox v1
    Prox_Cal = 0x06             # 2 bytes u16, proximity signal with nothing present
    FATP_Mode = 0x07            # 1 byte boolean, forces FATP mode EDID at startup
    HMD_Serial = 0x08           # variable length ASCII, headset assembly serial number
    Tracking_Serial = 0x09      # variable length ASCII, tracking flex serial number
    Display_Brightness = 0x0A   # 2 bytes u16, default display brightness (0 to 0x03FF)
    Prox_Threshold = 0x0B       # 2 bytes u16, proximity sensor threshold
    Prox_Hysteresis = 0x0C      # 2 bytes u16, proximity threshold hysteresis
    # No trailing comma: "0x0D," would make the raw value a tuple; it only
    # worked because IntEnum unpacks tuple values into int().
    EDID_Switch = 0x0D          # 1 byte: 0 = 75Hz+90Hz, 1 = 90Hz only, 2 = 75Hz only
    User_Prox_Trim = 0x0E       # 2 bytes i16 (SIGNED) proximity threshold adjustment
    VXR_Sleep_Enable = 0x0F     # 1 byte boolean, False (default when absent) = no sleep

class SigErrors(enum.IntEnum):
    '''
    Error markers that parse_sig() stores in place of an entry's value bytes.
    '''
    # The entry's trailing CRC byte did not match the CRC-8 computed over it.
    BadCRC = 1
    # The entry (its Length byte, data, or CRC) runs past the end of the page.
    Overrun = 2


def crc8(input_data: bytes) -> int:
    '''
    Returns the CRC-8 checksum of input_data.

    Bitwise, MSB-first CRC using generator polynomial CRC_POLY, seeded with
    CRC_INIT_VAL, with no final XOR. This is the checksum stored as the last
    byte of every config entry (computed over its Tag, Length and Value).
    '''
    crc = CRC_INIT_VAL
    for octet in input_data:
        # Fold the next byte into the register, then clock out 8 bits.
        crc ^= octet
        for _bit in range(8):
            top_bit_set = crc & 0x80
            crc = (crc << 1) & 0xFF
            if top_bit_set:
                crc ^= CRC_POLY
    return crc


def parse_sig(input_data: bytes) -> dict:
    '''
    Parses the TLV entries out of a user signature page.

    Expects input bytes of length 512 (the entire user signature page); any
    other length returns an empty dict. Walks the page entry by entry
    (Tag, Length, Value, CRC) until it reaches a 0xFF (SigTag.Invalid) tag
    byte or the end of the page.

    Returns a dict keyed by the raw tag byte (int). Each value is one of:
      - bytes: the entry's Value field, when its CRC byte matches crc8();
      - SigErrors.BadCRC: the CRC byte did not match;
      - SigErrors.Overrun: the entry's Length byte, data, or CRC byte would
        extend past the end of the page. Parsing stops at this entry.
    If the same tag appears more than once, the last occurrence wins.
    '''
    if len(input_data) != USER_SIG_LENGTH:
        return {}

    sig_fields = {}
    sig_ptr = 0
    while sig_ptr < USER_SIG_LENGTH:
        tag_byte = input_data[sig_ptr]
        if tag_byte == SigTag.Invalid:
            # Unprogrammed Flash: no more tags here
            break

        # The Length byte itself must be inside the page.
        if sig_ptr + 1 >= USER_SIG_LENGTH:
            sig_fields[tag_byte] = SigErrors.Overrun
            break
        length_byte = input_data[sig_ptr + 1]

        # Entry layout: tag(1) + length(1) + data(length) + crc(1)
        crc_index = sig_ptr + 2 + length_byte
        if crc_index >= USER_SIG_LENGTH:
            sig_fields[tag_byte] = SigErrors.Overrun
            break

        # Normalise to bytes so bytearray/memoryview input still yields the
        # documented value type (and round-trips through create_sig()).
        data_bytes = bytes(input_data[sig_ptr + 2:crc_index])
        crc_byte = input_data[crc_index]
        if crc_byte == crc8(bytes([tag_byte, length_byte]) + data_bytes):
            sig_fields[tag_byte] = data_bytes
        else:
            sig_fields[tag_byte] = SigErrors.BadCRC

        sig_ptr = crc_index + 1

    return sig_fields


def create_sig(entries: dict) -> bytes:
    '''
    Creates a 512 byte signature region by generating TLV+crc entries
    from the input dict.

    entries: maps tag (int or SigTag, 0x00-0xFE) -> value (bytes, at most
        255 bytes). Entries are written in the dict's iteration order.
    Returns exactly USER_SIG_LENGTH bytes. Unused space is padded with 0xFF,
    which parse_sig() treats as the end of the entries.

    Raises:
        ValueError: a value is not bytes or is longer than 255 bytes, or a
            tag is outside 0x00-0xFE (0xFF is reserved as the end marker).
        IndexError: the encoded entries do not fit in USER_SIG_LENGTH bytes.
    '''
    outbytes = bytearray()

    for key, val in entries.items():
        if not isinstance(val, bytes):
            raise ValueError(f"Value must be of type bytes. Got type {type(val)} instead")
        # A 0xFF tag would be read back as end-of-page, silently hiding this
        # entry and every entry after it.
        if not 0 <= key < SigTag.Invalid:
            raise ValueError(f"Tag {key!r} is outside the valid range 0x00-0xFE (0xFF marks the end of the entries).")
        # The Length field is a single byte.
        if len(val) > 0xFF:
            raise ValueError(f"Value for tag {key!r} is {len(val)} bytes; at most 255 bytes fit in the Length field.")

        entry_without_crc = bytes([key, len(val)]) + val
        outbytes += entry_without_crc
        outbytes.append(crc8(entry_without_crc))

        if len(outbytes) > USER_SIG_LENGTH:
            raise IndexError(f"User signature page exceeded {USER_SIG_LENGTH} bytes. Was {len(outbytes)} in length.")

    # Pad with 0xFF (the erased-Flash value), which parse_sig() stops at.
    outbytes += b'\xff' * (USER_SIG_LENGTH - len(outbytes))
    return bytes(outbytes)

def read_sig(beyond_hid: hid.device) -> bytes:
    '''
    Grabs the entire 512 byte signature (configuration) page from the Beyond.

    Sends one READ_SIG request per chunk index, with enough chunks of
    READSIG_LENGTH bytes to cover USER_SIG_LENGTH. Each successful reply is
    laid out as [READ_SIG, chunk_length, chunk bytes...].

    Returns the USER_SIG_LENGTH page bytes, or b'' if any request times out,
    the device replies with ERROR, or the assembled page is not exactly
    USER_SIG_LENGTH bytes long.
    '''
    # Derived from the constants rather than hard-coding 16.
    chunk_count = len(range(0, USER_SIG_LENGTH, READSIG_LENGTH))
    valid_replies = [hid_lib.HIDCommands.READ_SIG, hid_lib.HIDReplies.ERROR]
    config_bytes = bytearray()
    for chunk_index in range(chunk_count):
        # NOTE(review): leading 0 is presumably the HID report ID -- confirm in hid_lib.
        cmd = bytes([0, hid_lib.HIDCommands.READ_SIG, chunk_index])
        read_bytes = hid_lib.wait_for_response(beyond_hid, cmd, valid_replies)
        if len(read_bytes) == 0:
            # timeout
            return b''
        if read_bytes[0] == hid_lib.HIDReplies.ERROR:
            return b''
        if read_bytes[0] == hid_lib.HIDCommands.READ_SIG:
            chunk_len = read_bytes[1]
            config_bytes.extend(read_bytes[2:2 + chunk_len])

    if len(config_bytes) != USER_SIG_LENGTH:
        # A short or long reply yields a page that parse_sig() and save_sig()
        # both reject; report it like any other read failure.
        return b''
    return bytes(config_bytes)

def _send_sig_command(beyond_hid: hid.device, cmd: bytes, action: str):
    '''
    Sends one signature-page HID command and checks the reply.
    Raises RuntimeError on timeout or an ERROR reply; `action` ("writing",
    "saving") is used in the error message.
    '''
    valid_replies = [hid_lib.HIDReplies.SUCCESS, hid_lib.HIDReplies.ERROR]
    read_bytes = hid_lib.wait_for_response(beyond_hid, cmd, valid_replies)
    if len(read_bytes) == 0:
        raise RuntimeError(f'Timed out {action} signature page')
    if read_bytes[0] == hid_lib.HIDReplies.ERROR:
        raise RuntimeError(f'Error {action} signature page')


def save_sig(beyond_hid: hid.device, config_bytes: bytes):
    '''
    Overwrites the current signature page with the given bytes.

    Sends config_bytes in READSIG_LENGTH-byte chunks with WRITE_SIG (one
    request per chunk index), then sends SAVE_SIG so the device saves the
    fully written page.

    Raises:
        ValueError: config_bytes is not exactly USER_SIG_LENGTH bytes.
        RuntimeError: any request times out or the device replies with ERROR.
    '''
    if len(config_bytes) != USER_SIG_LENGTH:
        raise ValueError(f"User signature page size incorrect. Should be {USER_SIG_LENGTH} bytes. Was {len(config_bytes)} in length.")

    # Chunk offsets derived from the constants, so the whole page is always covered.
    for chunk_index, start in enumerate(range(0, USER_SIG_LENGTH, READSIG_LENGTH)):
        bytes_to_send = config_bytes[start:start + READSIG_LENGTH]
        cmd = bytes([0, hid_lib.HIDCommands.WRITE_SIG, chunk_index]) + bytes_to_send
        _send_sig_command(beyond_hid, cmd, 'writing')

    # after writing the full page, save it
    _send_sig_command(beyond_hid, bytes([0, hid_lib.HIDCommands.SAVE_SIG]), 'saving')

