import json
import enum

# CRC-8 parameters used by crc8(): the initial register value and the
# generator polynomial (0x07, i.e. x^8 + x^2 + x + 1 with the x^8 term implicit).
CRC_INIT_VAL = 0xFF
CRC_POLY = 0x07

class SigTag(enum.IntEnum):
    """Tag byte values identifying fields in the TLV-encoded config memory.

    Each field is stored as 1 tag byte, 1 length byte, the data bytes and
    1 CRC byte. Only some tags are interpreted in this file; the meaning of
    the remaining tags is not visible here.
    """
    Invalid = 0xFF  # end-of-data marker; also the fill value used by rewrite_config
    Serial = 0x01  # PCB serial: 16 bytes starting with "XCNL" (checked by validate_config)
    RGB_Color = 0x02
    Fan_Speed = 0x03
    Prox_Disable = 0x04
    Linkbox_v1 = 0x05
    Prox_Cal = 0x06  # kept by validate_config without forcing a rewrite
    FATP_Mode = 0x07
    HMD_Serial = 0x08  # written by hmd_config_write_serial; kept by validate_config
    Tracking_Serial = 0x09
    Brightness = 0x0A
    Prox_Thresh = 0x0B  # kept by validate_config without forcing a rewrite
    Prox_Hyst = 0x0C  # kept by validate_config without forcing a rewrite


def crc8(input_data: bytes) -> int:
    """Return the CRC-8 checksum of *input_data*.

    Bits are processed most-significant first using generator polynomial
    CRC_POLY, starting from register value CRC_INIT_VAL. There is no
    input/output reflection and no final XOR.
    """
    register = CRC_INIT_VAL
    for value in input_data:
        # Fold the next byte into the register, then run it through
        # eight shift-and-conditionally-xor polynomial division steps.
        register ^= value
        for _step in range(8):
            top_bit_set = register & 0x80
            register = (register << 1) & 0xFF
            if top_bit_set:
                register ^= CRC_POLY
    return register


def _iter_config_fields(config):
    """Yield ``(offset, tag, data_length)`` for each TLV field in *config*.

    A field is 1 tag byte, 1 length byte, ``data_length`` data bytes and
    1 CRC byte. Iteration stops at the first ``SigTag.Invalid`` (0xFF) tag,
    or when fewer than two bytes (tag + length) remain in *config*.
    """
    sig_ptr = 0
    while sig_ptr + 1 < len(config):
        tag = config[sig_ptr]
        if tag == SigTag.Invalid:  # End of data
            return
        data_length = config[sig_ptr + 1]
        yield sig_ptr, tag, data_length
        sig_ptr += 3 + data_length  # 1 byte tag, 1 byte length, 1 byte crc, rest data


def validate_config(config):
    """Validate the config memory, rewrite it if needed, and back it up to JSON.

    The fields are scanned once. Any PCB serial (``SigTag.Serial``) field must
    be 16 bytes long and start with ``XCNL``. If any field other than the PCB
    serial, the proximity values (``Prox_Cal``, ``Prox_Thresh``,
    ``Prox_Hyst``) or ``HMD_Serial`` is present, the whole memory is reset by
    rewrite_config so that it holds only the PCB serial. The serial found
    anywhere in memory is used, or a default if none was found. Finally, every
    field is dumped to ``config_mem_backup.json`` in the current directory. The
    keys are the decimal tag numbers as strings. The values are the field data
    decoded as UTF-8, with undecodable bytes dropped.

    Args:
        config: mutable sequence of byte values (e.g. a list of ints) holding
            the raw config memory; modified in place when a rewrite happens.

    Raises:
        ValueError: if a PCB serial field is not 16 bytes, does not start with
            ``XCNL``, or is not valid UTF-8 (UnicodeDecodeError).
    """
    # Default PCB serial; will be overwritten if found in the config memory
    pcb_serial_value = b'XCNL000000000000'
    # Tags that may remain alongside the PCB serial without forcing a rewrite.
    preserved_tags = (SigTag.Prox_Cal, SigTag.Prox_Thresh, SigTag.Prox_Hyst, SigTag.HMD_Serial)
    needs_rewrite = False

    for offset, tag, data_length in _iter_config_fields(config):
        if tag == SigTag.Serial:
            pcb_serial_value = config[offset + 2: offset + 2 + data_length]

            if len(pcb_serial_value) != 16:
                raise ValueError("The length of the PCB Serial is not 16 characters.")

            pcb_serial_value_decoded = bytes(pcb_serial_value).decode()
            if not pcb_serial_value_decoded.startswith("XCNL"):
                raise ValueError("The PCB Serial does not start with 'XCNL'.")
        elif tag not in preserved_tags:
            needs_rewrite = True

    # Rewrite only after the full scan. A PCB serial stored after an unknown
    # field is then preserved, and the walk never continues over memory that
    # has just been overwritten.
    if needs_rewrite:
        rewrite_config(config, pcb_serial_value)

    # The config memory is now valid (possibly rewritten); save it to a file.
    config_dict = {}
    for offset, tag, data_length in _iter_config_fields(config):
        data_slice = config[offset + 2: offset + 2 + data_length]
        config_dict[str(tag)] = bytes(data_slice).decode(errors='ignore')  # Decoding, but ignoring errors if any.

    with open("config_mem_backup.json", 'w') as backup_file:
        json.dump(config_dict, backup_file)


def _read_config_memory(device):
    """Read the device config memory as 16 READ_SIG blocks and return it as one list.

    Each reply contributes its bytes 2..33 (32 bytes). Blocks until a reply
    whose first byte is ``'U'`` arrives for every block; there is no timeout.
    """
    raw_config = []
    for i in range(16):
        device.send_feature_report([0, ord('U'), i])  # ord('U') is 0x55, READ_SIG
        # Wait for the READ_SIG reply for this block; other reports are ignored.
        while True:
            data = device.read(64)
            if data and data[0] == ord('U'):
                print(data, flush=True)
                raw_config += data[2:34]
                break
    return raw_config


def _wait_for_success(device):
    """Block (no timeout) until *device* returns a report whose first byte is 0x24."""
    while True:
        data = device.read(64)
        if data and data[0] == 0x24:
            return


def _set_hmd_serial_field(raw_config, serial):
    """Insert or replace the HMD serial (tag 0x08) field of *raw_config* in place.

    The field chain is rebuilt rather than patched. A new serial whose length
    differs from the stored one therefore shifts the following fields instead
    of overwriting them or leaving stale bytes behind. The result is padded
    with 0xFF back to the original length, and everything after the last
    field is reset to 0xFF.

    Raises:
        ValueError: if a stored field runs past the end of memory, or if the
            updated fields no longer fit in the memory.
    """
    new_field = list(create_field(0x08, serial.strip().encode()))
    fields = []
    found = False
    sig_ptr = 0
    while sig_ptr + 1 < len(raw_config):
        tag = raw_config[sig_ptr]
        if tag == 0xFF:  # End of data
            break
        field_end = sig_ptr + 3 + raw_config[sig_ptr + 1]  # tag, length, data, crc
        if field_end > len(raw_config):
            raise ValueError(f"Config field at offset {sig_ptr} runs past the end of config memory.")
        if tag == 0x08:
            fields.append(new_field)
            found = True
        else:
            fields.append(list(raw_config[sig_ptr:field_end]))
        sig_ptr = field_end

    if not found:  # If no HMD Serial field was found, create it
        fields.append(new_field)

    rebuilt = [byte for field in fields for byte in field]
    if len(rebuilt) > len(raw_config):
        raise ValueError("The HMD Serial field does not fit in config memory.")
    # Keep the total size unchanged so the 16 x 32-byte block write covers it exactly.
    raw_config[:] = rebuilt + [0xFF] * (len(raw_config) - len(rebuilt))


def hmd_config_write_serial(device, serial):
    """Read the device config, set its HMD serial field, and write it back.

    Steps:
    1. Read the 16 config blocks.
    2. Run validate_config, which may rewrite the config and writes a JSON backup.
    3. Insert or replace the HMD serial (tag 0x08) field.
    4. Write the 16 blocks back with WRITE_SIG, then commit them with SAVE_SIG.

    Each device round-trip blocks with no timeout.

    Args:
        device: device handle providing ``send_feature_report`` and ``read``
            (presumably a hidapi device -- confirm against caller).
        serial: HMD serial string; surrounding whitespace is stripped and the
            rest is encoded with str.encode() (UTF-8) before storing.

    Returns:
        list: the config memory as written to the device.

    Raises:
        ValueError: from validate_config, or if the new field cannot be stored.
    """
    raw_config = _read_config_memory(device)

    validate_config(raw_config)

    _set_hmd_serial_field(raw_config, serial)

    # Write the new config block by block, and save.
    for i in range(16):
        block = [0, ord('W'), i]  # ord('W') is 0x57, WRITE_SIG
        block += bytearray(raw_config[i * 32: i * 32 + 32])
        print(block, flush=True)
        device.send_feature_report(block)
        # Make sure it was a success:
        _wait_for_success(device)

    # Save config block out.
    device.send_feature_report([0, ord('V')])  # ord('V') is 0x56, SAVE_SIG, after all WRITE_SIG ops
    _wait_for_success(device)
    print("Success", flush=True)
    return raw_config


def hmd_config_set_brightness(device, serial):
    """Store *serial* in the HMD serial config field via hmd_config_write_serial.

    NOTE(review): despite its name, this function stores *serial* in the HMD
    serial field (tag 0x08), exactly like hmd_config_write_serial. It never
    writes SigTag.Brightness (0x0A). It delegates instead of duplicating that
    function's body. TODO: confirm the intended brightness encoding before
    changing the tag or payload.

    Args:
        device: device handle providing ``send_feature_report`` and ``read``.
        serial: value to store; passed through unchanged.

    Returns:
        Whatever hmd_config_write_serial returns (the config memory as written).
    """
    return hmd_config_write_serial(device, serial)


def rewrite_config(raw_config, pcb_serial):
    """Reset *raw_config* in place so it holds only a PCB serial field.

    Every byte is first set to 0xFF, then a tag-0x01 (PCB serial) field
    containing *pcb_serial* is written at offset 0.

    Args:
        raw_config: mutable sequence of byte values to overwrite in place.
        pcb_serial: serial payload; anything accepted by ``bytes()``
            (e.g. bytes or a list of ints).

    Returns:
        The same *raw_config* object.
    """
    # Fill with 0xFF (not zero): 0xFF is the end-of-data tag, so the
    # cleared memory parses as containing no fields.
    raw_config[:] = [0xFF] * len(raw_config)

    # Create the PCB Serial field at the start of memory.
    new_bytes = create_field(0x01, bytes(pcb_serial))
    raw_config[0:len(new_bytes)] = new_bytes

    return raw_config


def create_field(tag, data):
    """Encode one config field as tag, length, data, then a CRC-8 byte.

    The CRC (from crc8) covers the tag, length and data bytes.

    Args:
        tag: field tag in range 0-255 (an int or SigTag member).
        data: field payload as bytes (or bytearray).

    Returns:
        bytes: the encoded field, ``len(data) + 3`` bytes long.

    Raises:
        ValueError: if *data* is longer than 255 bytes (the length is stored
            in a single byte) or *tag* is outside 0-255.
    """
    if len(data) > 0xFF:
        raise ValueError(f"Config field data is {len(data)} bytes; the maximum is 255.")
    retbytes = bytes([tag, len(data)]) + data
    crc = crc8(retbytes)
    return retbytes + bytes([crc])
