import struct
from hmd_config import validate_config

## CRC-8 parameters consumed by crc8(): the initial register value and the
## generator polynomial x^8 + x^2 + x + 1 (leading x^8 term implicit).
CRC_INIT_VAL, CRC_POLY = 0xFF, 0x07


## CRC-8 checksum helper.
def crc8(input_data: bytes) -> int:
    """Return the CRC-8 checksum of *input_data*.

    The register starts at ``CRC_INIT_VAL``; each input byte is XORed into it
    and then 8 bits are clocked out most-significant first, XORing in
    ``CRC_POLY`` whenever a 1 bit is shifted out. There is no reflection and
    no final XOR, so the register value is returned as-is.

    :param input_data: bytes (or any iterable of ints) to checksum.
    :return: the checksum as an int in 0..255.
    """
    crc = CRC_INIT_VAL
    for byte in input_data:
        crc ^= byte
        for _bit in range(8):
            carry = crc & 0x80
            crc = (crc << 1) & 0xFF
            if carry:
                crc ^= CRC_POLY
    return crc


# Resets the device proximity values to zero to begin proximity calibration.
def reset_proximity(device):
    """Zero the proximity calibration and proximity threshold fields on the device.

    Reads the 16 config blocks (request 0x55, 32 payload bytes each), sets
    the 16-bit values of the proximity calibration field (tag 0x06) and the
    proximity threshold field (tag 0x0B) to zero, appending any field that is
    missing, writes every block back (request 0x57) and saves the config
    (request 0x56).

    :param device: handle exposing ``send_feature_report(report)`` and ``read(size)``.
    :return: the updated config as a list of byte values.
    :raises IndexError: if a missing field must be appended but the config
        has no room left for it.
    """
    # Send requests for 16 blocks of current config memory.
    raw_config = []
    for i in range(16):
        device.send_feature_report([0, 0x55, i])
        # Wait for the 0x55 reply carrying this block; payload is bytes 2..33.
        while True:
            data = device.read(64)
            if data and data[0] == 0x55:
                print(data, flush=True)
                raw_config += data[2:34]
                break

    validate_config(raw_config)

    # The config is a sequence of records: 1 byte tag, 1 byte length,
    # <length> data bytes, 1 byte CRC-8. A tag of 0xFF marks the end of data.
    for field_tag in (0x06, 0x0B):  # proximity calibration, proximity threshold
        sig_ptr = 0
        field_found = False
        while True:
            tag = raw_config[sig_ptr]
            if tag == 0xFF:  # End of data
                break
            data_length = raw_config[sig_ptr + 1]
            if tag == field_tag:
                raw_config[sig_ptr + 2] = 0x0  # LSB of value
                raw_config[sig_ptr + 3] = 0x0  # MSB of value
                raw_config[sig_ptr + 4] = crc8(bytes([field_tag, data_length, 0x0, 0x0]))  # checksum
                field_found = True
            sig_ptr += 3 + data_length

        if not field_found:
            # Field missing: append a zeroed one where the end marker was, then
            # re-terminate so later scans (and the device) still find the end.
            new_field = create_field(field_tag, bytes([0x0, 0x0]))
            field_end = sig_ptr + len(new_field)
            if field_end > len(raw_config):
                raise IndexError("no room in config to append field 0x%02X" % field_tag)
            raw_config[sig_ptr:field_end] = new_field
            if field_end < len(raw_config):
                raw_config[field_end] = 0xFF

    # Write the new config block by block; wait for a 0x24 reply after each.
    for i in range(16):
        block = [0, 0x57, i]
        block += bytearray(raw_config[i * 32: i * 32 + 32])
        print(block, flush=True)
        device.send_feature_report(block)
        while True:
            data = device.read(64)
            if data and data[0] == 0x24:
                break

    # Save the config and wait for the 0x24 reply.
    device.send_feature_report([0, 0x56])
    while True:
        data = device.read(64)
        if data and data[0] == 0x24:
            print("Success", flush=True)
            break
    return raw_config


# Read proximity from empty side to serve as new calibration value.
def proximity_read_proximity(device):
    """Poll the device until a 0x23 report arrives and return its proximity reading.

    Empty reads and reports with any other first byte are discarded. The
    reading is report bytes 4..5, decoded as an unsigned big-endian 16-bit int.

    :param device: handle exposing ``read(size)``.
    :return: the proximity value as an int.
    """
    report = device.read(64)
    while not report or report[0] != 0x23:
        report = device.read(64)
    (proximity,) = struct.unpack('>H', bytearray(report[4:6]))
    return proximity


# Write new proximity calibration value to headset.
def proximity_write_to_headset(device, new_calibration, raw_config):
    """Store a new proximity calibration value in the config and save it to the device.

    Sets the proximity calibration field (tag 0x06) of ``raw_config`` to
    ``new_calibration`` as a little-endian 16-bit value with a fresh CRC-8.
    If the field is missing, it is appended where the 0xFF end-of-data tag was.
    Then all 16 x 32-byte blocks are written back (request 0x57) and the
    config is saved (request 0x56).

    :param device: handle exposing ``send_feature_report(report)`` and ``read(size)``.
    :param new_calibration: new calibration value; converted with ``int()`` and
        must fit in an unsigned 16-bit value (``OverflowError`` otherwise).
    :param raw_config: config as a mutable list of byte values; updated in place.
    :return: ``raw_config``.
    :raises IndexError: if the field must be appended but the config has no room.
    """
    new_calibration_lsb, new_calibration_msb = int(new_calibration).to_bytes(2, 'little')

    validate_config(raw_config)

    # Records are: 1 byte tag, 1 byte length, <length> data bytes, 1 byte CRC-8.
    sig_ptr = 0
    field_found = False
    while True:
        tag = raw_config[sig_ptr]
        if tag == 0xFF:  # End of data
            break
        data_length = raw_config[sig_ptr + 1]
        if tag == 0x06:
            raw_config[sig_ptr + 2] = new_calibration_lsb  # LSB of calibration
            raw_config[sig_ptr + 3] = new_calibration_msb  # MSB of calibration
            raw_config[sig_ptr + 4] = crc8(bytes([0x06, data_length, new_calibration_lsb,
                                                  new_calibration_msb]))  # checksum
            field_found = True
        sig_ptr += 3 + data_length

    if not field_found:
        # Previously the value was silently dropped here. Append the field where
        # the end marker was, then re-terminate the record list.
        new_field = create_field(0x06, bytes([new_calibration_lsb, new_calibration_msb]))
        field_end = sig_ptr + len(new_field)
        if field_end > len(raw_config):
            raise IndexError("no room in config to append field 0x06")
        raw_config[sig_ptr:field_end] = new_field
        if field_end < len(raw_config):
            raw_config[field_end] = 0xFF

    # Write the new config block by block; wait for a 0x24 reply after each.
    for i in range(16):
        block = [0, 0x57, i]
        block += bytearray(raw_config[i * 32: i * 32 + 32])
        print(block, flush=True)
        device.send_feature_report(block)
        while True:
            data = device.read(64)
            if data and data[0] == 0x24:
                break

    # Save the config and wait for the 0x24 reply.
    device.send_feature_report([0, 0x56])
    while True:
        data = device.read(64)
        if data and data[0] == 0x24:
            print("Success", flush=True)
            break
    return raw_config


# Write new proximity threshold value to headset.
def proximity_threshold_write_to_headset(device, proximity_threshold, raw_config):
    """Store a new proximity threshold in the config and save it to the device.

    Sets the proximity threshold field (tag 0x0B) of ``raw_config`` to
    ``proximity_threshold`` as a little-endian 16-bit value with a fresh CRC-8.
    If the field is missing, it is appended where the 0xFF end-of-data tag was.
    Then all 16 x 32-byte blocks are written back (request 0x57) and the
    config is saved (request 0x56).

    :param device: handle exposing ``send_feature_report(report)`` and ``read(size)``.
    :param proximity_threshold: new threshold; converted with ``int()`` (as
        ``proximity_write_to_headset`` does) and must fit in an unsigned
        16-bit value (``OverflowError`` otherwise).
    :param raw_config: config as a mutable list of byte values; updated in place.
    :return: ``raw_config``.
    :raises IndexError: if the field must be appended but the config has no room.
    """
    new_threshold_lsb, new_threshold_msb = int(proximity_threshold).to_bytes(2, 'little')

    validate_config(raw_config)

    # Search the records (1 byte tag, 1 byte length, <length> data bytes,
    # 1 byte CRC-8) and replace the proximity threshold value.
    sig_ptr = 0
    field_found = False
    while True:
        tag = raw_config[sig_ptr]
        if tag == 0xFF:  # End of data
            break
        data_length = raw_config[sig_ptr + 1]
        if tag == 0x0B:
            raw_config[sig_ptr + 2] = new_threshold_lsb  # LSB of threshold
            raw_config[sig_ptr + 3] = new_threshold_msb  # MSB of threshold
            raw_config[sig_ptr + 4] = crc8(bytes([0x0B, data_length, new_threshold_lsb,
                                                  new_threshold_msb]))  # checksum
            field_found = True
        sig_ptr += 3 + data_length

    if not field_found:
        # Previously the value was silently dropped here. Append the field where
        # the end marker was, then re-terminate the record list.
        new_field = create_field(0x0B, bytes([new_threshold_lsb, new_threshold_msb]))
        field_end = sig_ptr + len(new_field)
        if field_end > len(raw_config):
            raise IndexError("no room in config to append field 0x0B")
        raw_config[sig_ptr:field_end] = new_field
        if field_end < len(raw_config):
            raw_config[field_end] = 0xFF

    # Write the new config block by block; wait for a 0x24 reply after each.
    for i in range(16):
        block = [0, 0x57, i]
        block += bytearray(raw_config[i * 32: i * 32 + 32])
        print(block, flush=True)
        device.send_feature_report(block)
        while True:
            data = device.read(64)
            if data and data[0] == 0x24:
                break

    # Save the config and wait for the 0x24 reply.
    device.send_feature_report([0, 0x56])
    while True:
        data = device.read(64)
        if data and data[0] == 0x24:
            print("Success", flush=True)
            break
    return raw_config


def create_field(tag, data):
    """Encode one config record: tag byte, length byte, payload, CRC-8.

    The CRC-8 is computed over the tag, length and payload bytes.

    :param tag: record tag; must fit in one byte.
    :param data: payload as a bytes-like object; its length becomes the length
        byte and so must fit in one byte.
    :return: the encoded record as ``bytes``.
    """
    record = bytearray((tag, len(data)))
    record += data
    record.append(crc8(record))
    return bytes(record)
