#!/usr/bin/env python3
"""
HMD Serial Number Update Script

Updates the HMD_Serial field in the device configuration of a connected Bigscreen Beyond headset.
Reads the current configuration, updates the HMD_Serial to a hardcoded value,
writes it back to the device, and verifies the update.
"""

import hid
import sys
import time
import logging

# logging.basicConfig(level=logging.DEBUG)  # debug

# Constants

# USB vendor / product IDs passed to device.open() in connect_to_device()
BIGSCREEN_VID = 0x35BD
BEYOND_PID = 0x0101

# Configuration ("user signature") area: 16 blocks of 32 bytes
USER_SIG_LENGTH = 16 * 32
# TLVC tag of the HMD_Serial field and the value written into it
HMD_SERIAL_TAG = 0x08
NEW_HMD_SERIAL = "BS2EY00000000"

# CRC-8 parameters protecting each TLVC field
CRC_INIT_VAL = 0xFF
CRC_POLY = 0x07

def crc8(input_data):
    """
    Compute the CRC-8 of ``input_data``.

    Bitwise, MSB-first CRC seeded with CRC_INIT_VAL and using CRC_POLY as the
    generator; no input/output reflection and no final XOR.
    ``input_data`` is expected to be an iterable of byte values (bytes,
    bytearray, or a list of ints 0..255).
    """
    crc = CRC_INIT_VAL
    for byte in input_data:
        crc ^= byte
        for _bit in range(8):
            # Remember the bit that is about to be shifted out.
            carry = crc & 0x80
            crc = (crc << 1) & 0xFF
            if carry:
                crc ^= CRC_POLY
    return crc

def create_tlvc_field(tag, data):
    """
    Create a TLVC (Tag-Length-Value-CRC) field.

    Layout: ``[tag, len(value), *value, crc8([tag, len(value), *value])]``.

    Args:
        tag: Field tag; must fit in one byte (0..255).
        data: Field value, either a ``str`` (encoded as ASCII) or bytes.

    Returns:
        The encoded field as ``bytes`` (``len(value) + 3`` bytes long).

    Raises:
        ValueError: if ``tag`` is outside 0..255 or the value is longer than
            255 bytes (its length must fit in the single length byte).
            ``UnicodeEncodeError`` (a ValueError) for non-ASCII ``str`` data.
    """
    if isinstance(data, str):
        data = data.encode('ascii')

    # Validate up front so callers get a meaningful message instead of
    # bytes()'s generic "bytes must be in range(0, 256)".
    if not 0 <= tag <= 0xFF:
        raise ValueError(f"TLVC tag must be in range 0..255, got {tag}")
    if len(data) > 0xFF:
        raise ValueError(
            f"TLVC value is {len(data)} bytes long; at most 255 bytes fit in the length byte"
        )

    header_and_value = bytes([tag, len(data)]) + data
    # The CRC covers tag, length and value.
    return header_and_value + bytes([crc8(header_and_value)])

def parse_hmd_serial_from_config(config_raw_data):
    """
    Parse the HMD_Serial field from configuration data.

    The configuration is a sequence of TLVC fields (tag, length, value,
    CRC-8 over tag+length+value). It ends at a 0xFF tag byte, at the end of
    the user-signature area, or at the end of the supplied data, whichever
    comes first. Fields with a CRC mismatch are logged and skipped.

    Args:
        config_raw_data: Configuration bytes (any iterable of ints 0..255,
            e.g. the list returned by ``read_device_config``).

    Returns:
        The HMD serial string if a valid HMD_Serial field is found, ``None``
        otherwise (also ``None`` if its value is not valid ASCII).
    """
    config_data = bytes(config_raw_data)
    # Stay within both the signature area and the data actually supplied,
    # so a short buffer ends the scan instead of raising IndexError.
    limit = min(USER_SIG_LENGTH, len(config_data))
    sig_ptr = 0
    while sig_ptr < limit:
        tag = config_data[sig_ptr]
        if tag == 0xFF:
            # Reached end of saved data
            break

        # Smallest possible field: tag + length + CRC with an empty value
        if sig_ptr + 3 > limit:
            break

        taglen = config_data[sig_ptr + 1]
        crc_pos = sig_ptr + 2 + taglen
        # Ensure the whole field, including its CRC byte, is available
        if crc_pos >= limit:
            break

        tagval = config_data[sig_ptr + 2:crc_pos]
        tagcrc = config_data[crc_pos]

        # The CRC covers tag, length and value, i.e. everything before crc_pos
        if tagcrc != crc8(config_data[sig_ptr:crc_pos]):
            logging.warning(f"Warning: CRC mismatch for tag {tag:02X}")
            sig_ptr = crc_pos + 1
            continue

        # Check if this is the HMD_Serial tag
        if tag == HMD_SERIAL_TAG:
            try:
                return tagval.decode('ascii')
            except UnicodeDecodeError:
                logging.debug("Warning: Could not decode HMD_Serial as ASCII")
                return None

        sig_ptr = crc_pos + 1

    return None

def update_hmd_serial_in_config(config_data, new_serial):
    """
    Update the HMD_Serial field in configuration data.

    Every existing HMD_Serial field is replaced. If none exists, a new field
    is added at the end-of-data position (the first 0xFF tag byte). The new
    field may be longer or shorter than the one it replaces, so the fields
    after it are shifted rather than overwritten. The image is then restored
    to its original length by dropping (or padding with 0xFF) unused space at
    the end. Field CRCs are not checked during this scan.

    Args:
        config_data: Mutable configuration image (e.g. the list returned by
            ``read_device_config``). It is modified in place only on success.
        new_serial: Serial to store. Surrounding whitespace is stripped, and
            it must be ASCII.

    Returns:
        ``config_data`` (the same object), updated.

    Raises:
        ValueError: if a field header is truncated, or if the updated fields
            would no longer fit in ``len(config_data)`` bytes.
            ``UnicodeEncodeError`` (a ValueError) for a non-ASCII serial.
    """
    # Encode via create_tlvc_field so the serial uses the same ASCII encoding
    # that parse_hmd_serial_from_config decodes.
    new_field = list(create_tlvc_field(HMD_SERIAL_TAG, new_serial.strip()))
    capacity = len(config_data)
    # Edit a copy so a failure part-way leaves the caller's data untouched.
    working = list(config_data)

    sig_ptr = 0
    hmd_serial_field_found = False
    while sig_ptr < len(working) and working[sig_ptr] != 0xFF:  # 0xFF tag = end of data
        if sig_ptr + 1 >= len(working):
            raise ValueError(f"Truncated configuration field at offset {sig_ptr}")
        field_len = 3 + working[sig_ptr + 1]  # tag + length + value + CRC
        if working[sig_ptr] == HMD_SERIAL_TAG:
            # Splice instead of overwriting: the new field's length may differ
            # from the old one, and overwriting would corrupt following fields.
            working[sig_ptr:sig_ptr + field_len] = new_field
            field_len = len(new_field)
            hmd_serial_field_found = True
        sig_ptr += field_len

    if not hmd_serial_field_found:
        # If no HMD Serial field was found, insert one at the end of the data.
        # The trailing free space this displaces is trimmed below.
        working[sig_ptr:sig_ptr] = new_field
        sig_ptr += len(new_field)

    if sig_ptr > capacity:
        raise ValueError(
            f"Updated configuration needs {sig_ptr} bytes but only {capacity} are available"
        )

    # Restore the original image size. Everything past `capacity` lies beyond
    # the last field, so dropping it loses no data.
    del working[capacity:]
    working.extend([0xFF] * (capacity - len(working)))
    config_data[:] = working
    return config_data

def _read_report(device, marker, timeout):
    """Return the next input report whose first byte is ``marker``.

    Reports with any other first byte are discarded. Raises TimeoutError
    (an OSError) if no matching report arrives within ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(
                f"No 0x{marker:02X} response from device within {timeout} s"
            )
        # Always pass a positive timeout: in cython-hidapi, timeout_ms=0
        # means a plain blocking read, which could wait forever.
        data = device.read(64, max(1, int(remaining * 1000)))
        if data and data[0] == marker:
            return data


def read_device_config(device, timeout=5.0):
    """
    Read configuration from the device (16 pages of 32 bytes each).

    Each page is requested with a READ_SIG ('U') feature report. The call then
    waits for an input report starting with 'U'; bytes 2..33 of that report
    hold the page payload.

    Args:
        device: An open ``hid.device``.
        timeout: Seconds to wait for each page before giving up.

    Returns:
        The configuration as a list of 16 * 32 = 512 ints.

    Raises:
        TimeoutError: if the device does not answer a page request in time.
        IOError: if a response is too short to contain a full 32-byte page.
    """
    logging.debug("Reading device configuration...")
    raw_config = []
    for i in range(16):
        device.send_feature_report([0, ord('U'), i])  # ord('U') is 0x55, READ_SIG
        data = _read_report(device, ord('U'), timeout)
        # A short page would silently shift every later page's offset.
        if len(data) < 34:
            raise IOError(
                f"Short response for configuration block {i}: {len(data)} bytes"
            )
        logging.debug(f"Received block {i} of configuration")
        raw_config += data[2:34]

    logging.debug("Configuration read successfully")
    return raw_config

def _wait_for_ack(device, timeout):
    """Wait for an input report starting with 0x24 ('$').

    The write code treats that byte as a success acknowledgement. Other
    reports are discarded. Raises TimeoutError (an OSError) if none arrives
    within ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"No acknowledgement from device within {timeout} s")
        # Positive timeout only: in cython-hidapi, timeout_ms=0 blocks indefinitely.
        data = device.read(64, max(1, int(remaining * 1000)))
        if data and data[0] == 0x24:
            return


def write_device_config(device, config_data, timeout=5.0):
    """
    Write configuration to the device and save it.

    Uses the same sequence as config_editor.py. It sends 16 WRITE_SIG ('W')
    feature reports carrying 32-byte pages of ``config_data``, waiting for a
    0x24 acknowledgement after each. It then sends SAVE_SIG ('V') and waits
    for a final acknowledgement.

    Args:
        device: An open ``hid.device``.
        config_data: Configuration image (indexable, 16 * 32 bytes expected).
        timeout: Seconds to wait for each acknowledgement.

    Returns:
        ``config_data`` on success; ``None`` if an I/O error or timeout
        occurred (the error is logged).
    """
    logging.debug("Writing device configuration...")

    try:
        # Write the new config block by block, and save.
        for i in range(16):
            block = [0, ord('W'), i]  # ord('W') is 0x57, WRITE_SIG
            block += bytearray(config_data[i * 32: i * 32 + 32])
            device.send_feature_report(block)
            logging.debug(f"Wrote block {i} of configuration")
            _wait_for_ack(device, timeout)

        # Save config block out, after all WRITE_SIG ops.
        device.send_feature_report([0, ord('V')])  # ord('V') is 0x56, SAVE_SIG
        _wait_for_ack(device, timeout)
        logging.debug("Success")
        return config_data
    except (IOError, OSError) as e:
        # TimeoutError is an OSError, so an unanswered request lands here too.
        logging.error(f"Failed to write device configuration: {e}")
        return None

def connect_to_device():
    """
    Connect to the HMD device.

    Opens a HID device with vendor ID BIGSCREEN_VID and product ID BEYOND_PID.

    Returns:
        The opened ``hid.device``.

    Raises:
        IOError: if the device cannot be created or opened; the underlying
            error is chained as ``__cause__``.
    """
    try:
        device = hid.device()
        device.open(BIGSCREEN_VID, BEYOND_PID)
    except Exception as e:
        # Keep the original message but preserve the real cause for debugging.
        raise IOError(f"Could not connect to HMD: {e}") from e
    print("Connected to HMD")
    return device

def main():
    """
    Main function to update HMD serial number.

    The steps are:
    1. Connect and read the configuration.
    2. Stop early if the HMD is already patched.
    3. Rewrite the HMD_Serial field to NEW_HMD_SERIAL.
    4. Write the configuration back.
    5. Read it again to verify.

    Returns:
        Process exit status: 0 on success or if the HMD is already patched,
        1 on any failure.
    """
    device = None

    try:
        print("=== Eyetracking Detection Patcher ===")
        logging.debug("=== HMD Serial Number Update Tool ===")
        logging.debug(f"Target serial number: {NEW_HMD_SERIAL}")
        logging.debug("")

        # Connect to device
        device = connect_to_device()

        # Read current configuration
        current_config = read_device_config(device)

        # Parse current HMD serial
        current_serial = parse_hmd_serial_from_config(current_config)
        if current_serial:
            logging.debug(f"Current HMD Serial: {current_serial}")
        else:
            logging.debug("Current HMD Serial: Not found or not set")

        # current_serial is None when the field is missing. Guard it so such
        # headsets fall through to the patch instead of raising TypeError.
        if current_serial is not None and (
            current_serial == NEW_HMD_SERIAL or 'BS2E' in current_serial
        ):
            print("Your HMD is already patched.")
            logging.debug("HMD Serial is already set to the target value!")
            return 0

        logging.debug("")

        # Update configuration with new serial
        print("Applying patch...")
        logging.debug("Updating configuration with new HMD Serial...")
        updated_config = update_hmd_serial_in_config(current_config, NEW_HMD_SERIAL)

        # Parse updated HMD serial
        updated_serial = parse_hmd_serial_from_config(updated_config)
        if updated_serial:
            logging.debug(f"Updated HMD Serial: {updated_serial}")
        else:
            logging.debug("Updated HMD Serial: Not found or not set")

        if updated_serial != NEW_HMD_SERIAL:
            print("ERROR: Patch application failed. (error 1)")
            logging.debug("DATA MANIPULATION ERROR: Updated HMD Serial does not match the target value!")
            return 1

        logging.debug("")

        # Write updated configuration; write_device_config returns None on I/O failure
        if write_device_config(device, updated_config) is None:
            print("ERROR: Patch application failed. (error 3)")
            logging.debug("WRITE ERROR: Could not write the updated configuration to the device!")
            return 1

        logging.debug("")

        # Verify the update by reading again
        print("Verifying patch...")
        verification_config = read_device_config(device)
        verified_serial = parse_hmd_serial_from_config(verification_config)

        if verified_serial == NEW_HMD_SERIAL:
            print("Patching complete. OK to close the window.")
            logging.debug(f"✓ SUCCESS: HMD Serial successfully updated to: {verified_serial}")
        else:
            print("ERROR: Patch application failed. (error 2)")
            logging.debug(f"✗ VERIFICATION FAILED: Expected '{NEW_HMD_SERIAL}', got '{verified_serial}'")
            return 1

    except Exception as e:
        logging.error(f"✗ ERROR: {e}")
        return 1

    finally:
        if device:
            try:
                device.close()
                logging.debug("Device connection closed")
                # Pause before exiting, presumably so the user can read the
                # output before a console window closes.
                time.sleep(5)
            except (Exception, KeyboardInterrupt):
                # Best-effort cleanup: close failures are ignored, and a
                # Ctrl+C during the pause is swallowed as well.
                pass

    return 0

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)