import enum
import time

import hid

BIGSCREEN_VID = 0x35BD  # vendor ID passed to hid.Device when opening the HMD
BEYOND_PID = 0x0101  # product ID passed to hid.Device when opening the HMD
USER_SIG_LENGTH = 512  # size in bytes of the TLVC config (user signature) region
CRC_INIT_VAL = 0xFF  # initial register value used by crc8()
CRC_POLY = 0x07  # CRC-8 polynomial x^8 + x^2 + x + 1, applied MSB-first by crc8()

class SigTag(enum.IntEnum):
    """Tag byte identifying each TLVC field stored in the config region.

    parse_sig stops at the first byte equal to Invalid. It skips any tag it
    does not recognise (after verifying that field's CRC).
    """
    Invalid = 0xFF  # end-of-data marker; unused space in the region is filled with 0xFF
    Serial = 0x01  # variable-length value; create_signature omits it when empty
    RGB_Color = 0x02  # parse_sig requires a 3-byte value
    Fan_Speed = 0x03  # parse_sig requires a 1-byte value
    Prox_Disable = 0x04  # parse_sig requires a 1-byte value
    Linkbox_v1 = 0x05  # parse_sig requires a 1-byte value
    Prox_Cal = 0x06  # parse_sig requires a 2-byte value
    FATP_Mode = 0x07  # parse_sig requires a 1-byte value
    HMD_Serial = 0x08  # variable-length value; create_signature omits it when empty
    Tracking_Serial = 0x09  # variable-length value; create_signature omits it when empty

def crc8(input_data: bytes) -> int:
    """Return the 8-bit CRC of *input_data*.

    The register starts at CRC_INIT_VAL and is shifted MSB-first using
    CRC_POLY. There is no input/output bit reflection and no final XOR. A
    consequence is that running crc8 over a message followed by its own CRC
    byte yields 0.
    """
    register = CRC_INIT_VAL
    for octet in input_data:
        # Fold the next byte into the register, then clock out 8 bits.
        register ^= octet
        for _bit in range(8):
            carry = register & 0x80
            register = (register << 1) & 0xFF
            if carry:
                # The bit shifted out of the top was set: reduce by the polynomial.
                register ^= CRC_POLY
    return register

def create_field(tag: SigTag, data: bytes) -> bytes:
    """Encode one TLVC field: tag byte, length byte, value bytes, CRC byte.

    The CRC is crc8() computed over the tag, length and value bytes.

    Args:
        tag: Field identifier (0..255).
        data: Field value. Any bytes-like object or iterable of ints in
            0..255 is accepted (e.g. ``b'\\x01\\x02\\x03'`` or ``[1, 2, 3]``).

    Returns:
        The encoded field, ``len(data) + 3`` bytes long.

    Raises:
        TypeError: *data* is an int (``bytes(n)`` would silently produce
            n zero bytes) or is not convertible to bytes.
        ValueError: *data* is longer than the 255 bytes the length byte can
            describe, or an element is outside 0..255.
    """
    if isinstance(data, int):
        raise TypeError('field data must be bytes-like, not int')
    value = bytes(data)
    if len(value) > 0xFF:
        raise ValueError(
            f'field data is {len(value)} bytes; the length byte allows at most 255'
        )
    header_and_value = bytes([tag, len(value)]) + value
    return header_and_value + bytes([crc8(header_and_value)])

# Fixed order in which create_signature lays fields out in the region.
_SIG_WRITE_ORDER = (
    SigTag.Serial,
    SigTag.HMD_Serial,
    SigTag.Tracking_Serial,
    SigTag.Fan_Speed,
    SigTag.Prox_Cal,
    SigTag.RGB_Color,
    SigTag.Prox_Disable,
    SigTag.Linkbox_v1,
    SigTag.FATP_Mode,
)
# Serial-number fields are left out entirely when their value is empty.
_SIG_OMIT_WHEN_EMPTY = frozenset(
    {SigTag.Serial, SigTag.HMD_Serial, SigTag.Tracking_Serial}
)
# Single-byte fields whose value may be given as a plain int instead of bytes.
_SIG_INT_ACCEPTING = frozenset(
    {SigTag.Fan_Speed, SigTag.Prox_Disable, SigTag.Linkbox_v1, SigTag.FATP_Mode}
)

def _sig_value_bytes(tag: SigTag, value) -> bytes:
    """Convert a caller-supplied field value into the raw bytes to store."""
    if isinstance(value, int):
        if tag not in _SIG_INT_ACCEPTING:
            raise TypeError(f'{tag.name} expects a bytes-like value, got int')
        # bytes([v]) stores v itself; bytes(v) would store v zero bytes.
        return bytes([value])
    return bytes(value)

def create_signature(sig_fields: dict) -> bytes:
    """Serialize known fields of *sig_fields* into a config-region image.

    Fields are emitted as TLVC records (see create_field) in _SIG_WRITE_ORDER.
    The remainder of the USER_SIG_LENGTH-byte image is filled with 0xFF.
    Keys that are not in _SIG_WRITE_ORDER are ignored, including the 'error'
    key that parse_sig may add. Serial-number fields with an empty value are
    skipped.

    Args:
        sig_fields: Mapping of SigTag -> value. Fan_Speed, Prox_Disable,
            Linkbox_v1 and FATP_Mode accept an int (stored as one byte) or a
            bytes-like value; all other tags require a bytes-like value.

    Returns:
        Exactly USER_SIG_LENGTH bytes.

    Raises:
        ValueError: The encoded fields do not fit in USER_SIG_LENGTH bytes,
            or a value is out of range for a byte.
        TypeError: A value has an unsupported type.
    """
    sig_bytes = bytearray([0xFF] * USER_SIG_LENGTH)
    sig_ptr = 0
    for tag in _SIG_WRITE_ORDER:
        if tag not in sig_fields:
            continue
        raw = sig_fields[tag]
        if tag in _SIG_OMIT_WHEN_EMPTY and len(raw) == 0:
            continue
        block = create_field(tag, _sig_value_bytes(tag, raw))
        end = sig_ptr + len(block)
        # Slice assignment past the end would silently grow the image;
        # refuse instead so callers never write a truncated field.
        if end > USER_SIG_LENGTH:
            raise ValueError(
                f'signature fields need at least {end} bytes; '
                f'region is only {USER_SIG_LENGTH}'
            )
        sig_bytes[sig_ptr:end] = block
        sig_ptr = end

    return bytes(sig_bytes)

# Required Value length for fixed-size tags, plus the label used in the
# 'tag length mismatch (...)' error message.
_PARSE_FIXED_LENGTHS = {
    SigTag.Fan_Speed: (1, 'Fan speed'),
    SigTag.Prox_Cal: (2, 'prox cal'),
    SigTag.RGB_Color: (3, 'rgb color'),
    SigTag.Prox_Disable: (1, 'prox disable'),
    SigTag.Linkbox_v1: (1, 'linkbox v1'),
    SigTag.FATP_Mode: (1, 'FATP mode'),
}
# Known tags whose Value may be any length.
_PARSE_VARIABLE_LENGTH_TAGS = frozenset(
    {SigTag.Serial, SigTag.HMD_Serial, SigTag.Tracking_Serial}
)
# Smallest possible field: tag + length + CRC with an empty Value.
_PARSE_MIN_FIELD_LEN = 3

def parse_sig(sig_bytes: bytes) -> dict:
    """Parse the config region and return the fields it contains.

    The region holds consecutive TLVC (Tag, Length, Value, CRC) records:
      Byte 0:           Tag, which field is stored (see SigTag)
      Byte 1:           Length, number of bytes in Value
      Bytes 2..2+L-1:   Value
      Byte 2+L:         CRC-8 (crc8()) over tag, length and value bytes
    A tag byte of 0xFF (SigTag.Invalid) marks the end of stored data. An
    all-0xFF region is valid and yields an empty result.

    Only the first USER_SIG_LENGTH bytes of *sig_bytes* are examined. A
    shorter buffer is treated as ending where the buffer ends.

    Returns:
        Dict mapping SigTag -> Value slice of *sig_bytes* (same type as the
        input). Unknown tags are skipped after their CRC is checked. On a
        structural problem, parsing stops: the dict holds the fields read
        so far plus an 'error' key ('overrun', 'crc mismatch' or
        'tag length mismatch (...)').
    """
    region_len = min(len(sig_bytes), USER_SIG_LENGTH)
    sig_fields = {}
    sig_ptr = 0
    while sig_ptr < region_len:
        tag = sig_bytes[sig_ptr]
        if tag == SigTag.Invalid:
            # Reached the end of saved data; everything so far was valid.
            return sig_fields
        # Even an empty-Value field needs tag, length and CRC bytes.
        if sig_ptr + _PARSE_MIN_FIELD_LEN > region_len:
            sig_fields['error'] = 'overrun'
            return sig_fields
        taglen = sig_bytes[sig_ptr + 1]
        crc_pos = sig_ptr + 2 + taglen
        # Ensure this specific field (including its CRC byte) fits.
        if crc_pos >= region_len:
            sig_fields['error'] = 'overrun'
            return sig_fields
        tagval = sig_bytes[sig_ptr + 2:crc_pos]
        if crc8(bytes([tag, taglen]) + bytes(tagval)) != sig_bytes[crc_pos]:
            sig_fields['error'] = 'crc mismatch'
            return sig_fields
        if tag in _PARSE_FIXED_LENGTHS:
            expected_len, label = _PARSE_FIXED_LENGTHS[tag]
            if taglen != expected_len:
                sig_fields['error'] = f'tag length mismatch ({label})'
                return sig_fields
            sig_fields[SigTag(tag)] = tagval
        elif tag in _PARSE_VARIABLE_LENGTH_TAGS:
            sig_fields[SigTag(tag)] = tagval
        # Any other tag is ignored: the region can still be valid, it just
        # contains a field type this tool does not know about yet.
        sig_ptr = crc_pos + 1
    return sig_fields

# HID transfer geometry: the config region moves in 16 chunks of 32 bytes.
_CHUNK_SIZE = 32
_CHUNK_COUNT = USER_SIG_LENGTH // _CHUNK_SIZE
_REPORT_SIZE = 65
_READ_TIMEOUT_MS = 100
# Upper bound on waiting for a matching reply, so an unresponsive device
# cannot hang the script forever.
_RESPONSE_DEADLINE_S = 5.0

def _await_response(dev, marker: str) -> bytes:
    """Read input reports until one starts with *marker*, and return it.

    Empty reads (per-read timeout) and reports starting with any other
    byte are discarded.

    Raises:
        TimeoutError: No matching report arrived within _RESPONSE_DEADLINE_S.
    """
    want = ord(marker)
    deadline = time.monotonic() + _RESPONSE_DEADLINE_S
    while True:
        resp = dev.read(_REPORT_SIZE, timeout=_READ_TIMEOUT_MS)
        if len(resp) > 0 and resp[0] == want:
            return resp
        if time.monotonic() >= deadline:
            raise TimeoutError(f"no '{marker}' response from device")

def _read_config(dev) -> bytearray:
    """Read the whole config region chunk by chunk with the 'U' command.

    Each 'U' reply carries the chunk length at byte 1 and data from byte 2.
    Bytes a chunk does not cover stay 0xFF.

    Raises:
        RuntimeError: A reply is malformed (length > chunk size, or the
            report is shorter than the length it declares). Writing such
            data into the buffer would shift or overrun later chunks.
    """
    config = bytearray([0xFF] * USER_SIG_LENGTH)
    for i in range(_CHUNK_COUNT):
        dev.send_feature_report(bytes([0, ord('U'), i]))
        resp = _await_response(dev, 'U')
        if len(resp) < 2:
            raise RuntimeError(f'config chunk {i}: reply too short')
        chunk_len = resp[1]
        chunk = resp[2:2 + chunk_len]
        if chunk_len > _CHUNK_SIZE or len(chunk) != chunk_len:
            raise RuntimeError(f'config chunk {i}: bad length {chunk_len}')
        start = i * _CHUNK_SIZE
        config[start:start + chunk_len] = chunk
    return config

def _write_config(dev, image: bytes) -> None:
    """Write *image* with one 'W' command per chunk, then send 'V'.

    Each command is acknowledged by a report starting with '$'. 'V'
    presumably commits the written data (the script then asks for a power
    cycle); TODO confirm against the firmware protocol.
    """
    for i in range(_CHUNK_COUNT):
        start = i * _CHUNK_SIZE
        dev.send_feature_report(
            bytes([0, ord('W'), i]) + image[start:start + _CHUNK_SIZE]
        )
        _await_response(dev, '$')
    dev.send_feature_report(bytes([0, ord('V')]))
    _await_response(dev, '$')

def main() -> int:
    """Remove the proximity calibration field from the HMD config.

    Returns the process exit code: 0 on success or when nothing needed
    changing, and 1 when the stored config failed to parse. In the failure
    case nothing is written, because rewriting would drop every field after
    the parse error.
    """
    bigs = hid.Device(vid=BIGSCREEN_VID, pid=BEYOND_PID)
    try:
        config_bytes = _read_config(bigs)
        config_fields = parse_sig(config_bytes)

        print('Found config data: ')
        print(config_fields)

        if 'error' in config_fields:
            print('Config region failed to parse; refusing to rewrite it. '
                  'No changes were made.')
            return 1

        if SigTag.Prox_Cal not in config_fields:
            print('Prox calibration not found. No changes were made.')
            return 0

        config_fields.pop(SigTag.Prox_Cal)

        print('Removing proximity calibration and saving new config:')
        print(config_fields)

        new_config_bytes = create_signature(config_fields)
        _write_config(bigs, new_config_bytes)

        print('Saved new config. Power cycle HMD.')
        return 0
    finally:
        bigs.close()

if __name__ == '__main__':
    raise SystemExit(main())