#!/usr/bin/env python3
"""
HMD Serial Number Update Script

Updates the HMD_Serial field in the device configuration of any Beyond.
Reads the current configuration, updates the HMD_Serial to a hardcoded value,
writes it back to the device, and verifies the update.

This creates a file output.uinf in the current directory. This file should be
shared with customer support to help identify the issue. Its presence prevents
rerunning the patch.
"""

import hid
import os
import sys
import time
import logging

# logging.basicConfig(level=logging.DEBUG)  # debug

# Constants
# Icon location: when the script runs from a PyInstaller bundle, sys._MEIPASS
# points at the unpacked bundle directory; otherwise the icon is looked up
# relative to the current working directory.
# NOTE(review): ICON_PATH is not referenced by any code visible in this file.
ICON_PATH = os.path.join(sys._MEIPASS, 'favicon.ico') if hasattr(sys, '_MEIPASS') else 'favicon.ico'
# USB vendor/product IDs handed to hid.device().open() in connect_to_device().
BIGSCREEN_VID = 0x35BD
BEYOND_PID = 0x0101
# Size in bytes of the configuration area walked by the parser; matches the
# 16 blocks x 32 bytes transferred by read_device_config/write_device_config.
USER_SIG_LENGTH = 512
# Tag byte that identifies the HMD_Serial field inside the configuration.
HMD_SERIAL_TAG = 0x08
# NEW_HMD_SERIAL = "BS2EY00000000"  # Dummy serial that passes eyetracking requirement
NEW_HMD_SERIAL = "BS2EO27010232"  # Specific SN for customer
# Parameters of the CRC8 that protects each field (see crc8()).
CRC_INIT_VAL = 0xFF
CRC_POLY = 0x07

def crc8(input_data):
    """
    Compute the CRC8 checksum that protects a configuration field.

    The checksum starts at CRC_INIT_VAL and folds in each value of
    `input_data` most-significant bit first, using polynomial CRC_POLY.
    Returns the checksum as an int; empty input yields CRC_INIT_VAL.
    """
    crc = CRC_INIT_VAL
    for byte in input_data:
        crc ^= byte
        for _bit in range(8):
            # Remember the top bit before it is shifted out of the 8-bit window.
            top_bit_set = crc & 0x80
            crc = (crc << 1) & 0xFF
            if top_bit_set:
                crc ^= CRC_POLY
    return crc

def create_tlvc_field(tag, data):
    """
    Build one TLVC (Tag-Length-Value-CRC) record.

    `data` may be a bytes-like payload or a str, which is encoded as ASCII.
    The result is: tag byte, payload length byte, payload, then the CRC8 of
    everything before it.
    """
    payload = data.encode('ascii') if isinstance(data, str) else data
    header = bytes((tag, len(payload)))
    record = header + payload
    checksum = crc8(record)
    return record + bytes((checksum,))

def parse_hmd_serial_from_config(config_raw_data):
    """
    Extract the HMD_Serial string from raw configuration data.

    The configuration is a sequence of TLVC fields (tag, length, value, CRC8)
    terminated by a 0xFF tag byte. Fields whose CRC does not match are logged
    and skipped.

    Args:
        config_raw_data: iterable of byte values (e.g. the list returned by
            read_device_config).

    Returns:
        The serial as a str, or None when no valid HMD_Serial field exists
        or its value is not ASCII.
    """
    config_data = bytes(config_raw_data)
    # Bound the walk by the bytes actually present so that a buffer shorter
    # than USER_SIG_LENGTH cannot be indexed past its end.
    limit = min(len(config_data), USER_SIG_LENGTH)
    sig_ptr = 0
    while sig_ptr < limit:
        if config_data[sig_ptr] == 0xFF:
            # Reached end of saved data
            break

        # Ensure there's enough room for a minimal field
        if sig_ptr > (limit - 4):
            break

        tag = config_data[sig_ptr]
        taglen = config_data[sig_ptr + 1]
        field_end = sig_ptr + 3 + taglen  # tag + length + value + crc

        # Ensure there's enough room for this field
        if field_end > limit:
            break

        tagval = config_data[sig_ptr + 2:field_end - 1]
        tagcrc = config_data[field_end - 1]

        # Check the CRC (computed over tag, length and value)
        computecrc = crc8(bytes([tag, taglen]) + tagval)
        if tagcrc != computecrc:
            logging.warning(f"Warning: CRC mismatch for tag {tag:02X}")
            sig_ptr = field_end
            continue

        # Check if this is the HMD_Serial tag
        if tag == HMD_SERIAL_TAG:
            try:
                return tagval.decode('ascii')
            except UnicodeDecodeError:
                logging.debug("Warning: Could not decode HMD_Serial as ASCII")
                return None

        sig_ptr = field_end

    return None

def update_hmd_serial_in_config(config_data, new_serial):
    """
    Replace (or add) the HMD_Serial field in configuration data.

    The field list is rebuilt rather than patched in place, so a serial whose
    length differs from the stored one cannot overwrite the following field
    or leave stale bytes behind. The overall length of `config_data` is
    preserved: space freed by a shorter field is filled with 0xFF, and bytes
    after the original terminator are kept wherever the rebuilt field list
    does not reach them.

    Args:
        config_data: mutable sequence of byte values (list or bytearray).
            It is updated in place.
        new_serial: serial number string; surrounding whitespace is removed.

    Returns:
        The same `config_data` object, now holding the updated configuration.

    Raises:
        ValueError: if the existing field list is malformed (a field runs
            past the end of the data) or the updated field list does not fit.
    """
    new_field = create_tlvc_field(HMD_SERIAL_TAG, new_serial.strip().encode())
    original = bytes(config_data)  # immutable snapshot to walk while rebuilding
    capacity = len(original)

    rebuilt = bytearray()
    sig_ptr = 0
    hmd_serial_field_found = False
    while sig_ptr < capacity and original[sig_ptr] != 0xFF:  # 0xFF tag = end of data
        if sig_ptr + 1 >= capacity:
            raise ValueError("Configuration data ends in the middle of a field header")
        tag = original[sig_ptr]
        field_end = sig_ptr + 3 + original[sig_ptr + 1]  # 1 byte tag, 1 byte length, 1 byte crc, rest data
        if field_end > capacity:
            raise ValueError(f"Field with tag {tag:02X} runs past the end of the configuration data")

        if tag == HMD_SERIAL_TAG:
            rebuilt += new_field
            hmd_serial_field_found = True
        else:
            rebuilt += original[sig_ptr:field_end]
        sig_ptr = field_end

    if not hmd_serial_field_found:  # If no HMD Serial field was found, create it
        rebuilt += new_field

    if len(rebuilt) > capacity:
        raise ValueError(
            f"Updated configuration needs {len(rebuilt)} bytes but only {capacity} are available"
        )

    # sig_ptr marks where the old field list ended. If the new list is
    # shorter, blank the gap so leftover bytes are not parsed as a field.
    if len(rebuilt) < sig_ptr:
        rebuilt += b'\xff' * (sig_ptr - len(rebuilt))
    rebuilt += original[len(rebuilt):]

    # Update the caller's buffer in place, as the original implementation did.
    config_data[:] = rebuilt
    return config_data

def read_device_config(device):
    """
    Read the whole configuration area from the device.

    The configuration is fetched as 16 blocks of 32 bytes. For every block a
    READ_SIG ('U', 0x55) feature report is sent, then input reports are read
    until one whose first byte is 'U' arrives; bytes 2..33 of that report are
    the block payload.

    Args:
        device: opened HID device exposing send_feature_report() and read().

    Returns:
        A list of 512 byte values.

    Raises:
        OSError: if a block response carries fewer than 32 payload bytes.
            Accepting it would shift every later block, and the misaligned
            buffer would afterwards be written back to the device.

    Note:
        There is no timeout: the call keeps waiting until the device answers.
    """
    block_count = 16
    block_size = 32

    logging.debug("Reading device configuration...")
    raw_config = []
    for block_index in range(block_count):
        device.send_feature_report([0, ord('U'), block_index])  # ord('U') is 0x55, READ_SIG
        # Wait for reception of this block, ignoring empty or unrelated reports
        while True:
            data = device.read(64)
            if data and data[0] == ord('U'):
                break

        payload = list(data[2:2 + block_size])
        if len(payload) != block_size:
            raise OSError(
                f"Short configuration block {block_index}: "
                f"expected {block_size} bytes, got {len(payload)}"
            )
        logging.debug(f"Received block {block_index} of configuration")
        raw_config += payload

    logging.debug("Configuration read successfully")
    return raw_config

def _wait_for_ack(device):
    """Read reports until one whose first byte is 0x24 (treated as success) arrives."""
    while True:
        data = device.read(64)
        if data and data[0] == 0x24:
            return


def write_device_config(device, config_data):
    """
    Write configuration to the device and ask it to save the result.

    The data is sent as 16 WRITE_SIG ('W', 0x57) feature reports of 32 bytes
    each, waiting for an acknowledgement after every block, followed by one
    SAVE_SIG ('V', 0x56) report that is acknowledged the same way. The
    original note on this function says the sequence mirrors config_editor.py.

    Args:
        device: opened HID device exposing send_feature_report() and read().
        config_data: sequence of byte values; the first 512 are written.

    Returns:
        `config_data` on success, or None when an I/O error interrupted the
        transfer (the error is logged instead of being raised).

    Note:
        There is no timeout while waiting for an acknowledgement.
    """
    logging.debug("Writing device configuration...")

    try:
        # Write the new config block by block, and save.
        for i in range(16):
            block = [0, ord('W'), i]  # ord('W') is 0x57, WRITE_SIG
            block += bytearray(config_data[i * 32: i * 32 + 32])
            device.send_feature_report(block)
            # Make sure it was a success:
            _wait_for_ack(device)
            logging.debug(f"Wrote block {i} of configuration")

        # Save config block out.
        device.send_feature_report([0, ord('V')])  # ord('V') is 0x56, SAVE_SIG, after all WRITE_SIG ops
        _wait_for_ack(device)
        logging.debug("Success")
        return config_data
    except (IOError, OSError) as e:
        # Keep the best-effort contract (no exception escapes), but leave a
        # trace so the failure is not completely silent.
        logging.error(f"Failed to write device configuration: {e}")
        return None

def connect_to_device():
    """
    Open the HMD over HID.

    Returns:
        The opened hid device for BIGSCREEN_VID / BEYOND_PID.

    Raises:
        Exception: if the device cannot be opened; the underlying error is
            attached as the cause.
    """
    try:
        device = hid.device()
        device.open(BIGSCREEN_VID, BEYOND_PID)
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise Exception(f"Could not connect to HMD: {e}") from e
    print("Connected to HMD")
    return device

def main():
    """
    Apply the HMD serial patch and record the outcome in output.uinf.

    Steps: refuse to run if output.uinf already exists, connect to the HMD,
    read its configuration, skip patching when the stored serial already
    contains 'BS2E', otherwise write NEW_HMD_SERIAL and read the
    configuration back to verify it.

    Returns:
        0 when the patch was applied and verified, or when there was nothing
        to do (already applied / already patched); 1 on any failure.
    """
    # The log file doubles as the "patch already attempted" marker.
    if os.path.exists("output.uinf"):
        print("Patch has already been applied. To retry, delete output.uinf and run again.")
        logging.debug("Output file already exists. Exiting.")
        time.sleep(5)
        return 0

    device = None
    # The encoding is explicit because the log contains non-ASCII marks
    # (the check and cross symbols); with a legacy default code page writing
    # them raises UnicodeEncodeError, even after a successful patch.
    with open("output.uinf", "w", encoding="utf-8") as output_file:
        try:
            print("=== Eyetracking Detection Patcher ===")
            logging.debug("=== HMD Serial Number Update Tool ===")
            logging.debug(f"Target serial number: {NEW_HMD_SERIAL}")
            output_file.write(f"Target serial number: {NEW_HMD_SERIAL}\n")
            logging.debug("")

            # Connect to device
            device = connect_to_device()

            # Read current configuration
            current_config = read_device_config(device)

            # Parse current HMD serial
            current_serial = parse_hmd_serial_from_config(current_config)
            if current_serial:
                logging.debug(f"Current HMD Serial: {current_serial}")
                output_file.write(f"Current HMD Serial: {current_serial}\n")
            else:
                logging.debug("Current HMD Serial: Not found or not set")
                output_file.write("Current HMD Serial: Not found or not set\n")
                current_serial = ""

            if current_serial == NEW_HMD_SERIAL or 'BS2E' in current_serial:
                print("Your HMD is already patched.")
                logging.debug("HMD Serial is already set to the target value!")
                output_file.write("HMD Serial already passes eyetracking requirement.\n")
                time.sleep(5)
                return 0

            logging.debug("")

            # Update configuration with new serial
            print("Applying patch...")
            logging.debug("Updating configuration with new HMD Serial...")
            updated_config = update_hmd_serial_in_config(current_config, NEW_HMD_SERIAL)

            # Sanity-check the edited buffer before anything is sent to the device
            updated_serial = parse_hmd_serial_from_config(updated_config)
            if updated_serial:
                logging.debug(f"Updated HMD Serial: {updated_serial}")
                output_file.write(f"Updated HMD Serial: {updated_serial}\n")
            else:
                logging.debug("Updated HMD Serial: Not found or not set")
                output_file.write("Updated HMD Serial: Not found or not set\n")

            if updated_serial != NEW_HMD_SERIAL:
                print("ERROR: Patch application failed. (error 1)")
                logging.debug("DATA MANIPULATION ERROR: Updated HMD Serial does not match the target value!")
                output_file.write("DATA MANIPULATION ERROR: Updated HMD Serial does not match the target value!\n")
                time.sleep(5)
                return 1  # a failure must not exit with status 0

            logging.debug("")

            # Write updated configuration. The writer reports I/O errors by
            # returning None; record that, then let verification decide.
            if write_device_config(device, updated_config) is None:
                logging.error("WRITE ERROR: Device reported an I/O error while writing the configuration")
                output_file.write("WRITE ERROR: Device reported an I/O error while writing the configuration\n")

            logging.debug("")

            # Verify the update by reading again
            print("Verifying patch...")
            verification_config = read_device_config(device)
            verified_serial = parse_hmd_serial_from_config(verification_config)

            if verified_serial != NEW_HMD_SERIAL:
                print("ERROR: Patch application failed. (error 2)")
                logging.debug(f"✗ VERIFICATION FAILED: Expected '{NEW_HMD_SERIAL}', got '{verified_serial}'")
                output_file.write(f"✗ VERIFICATION FAILED: Expected '{NEW_HMD_SERIAL}', got '{verified_serial}'\n")
                time.sleep(5)
                return 1

            print("Patching complete. OK to close the window.")
            logging.debug(f"✓ SUCCESS: HMD Serial successfully updated to: {verified_serial}")
            output_file.write(f"✓ SUCCESS: HMD Serial successfully updated to: {verified_serial}\n")

        except Exception as e:
            # Top-level boundary: record the failure for support and exit non-zero.
            logging.error(f"✗ ERROR: {e}")
            output_file.write(f"✗ ERROR: {e}\n")
            time.sleep(5)
            return 1

        finally:
            if device is not None:
                try:
                    device.close()
                    logging.debug("Device connection closed")
                except Exception as close_error:
                    # Closing is best-effort; never mask the real outcome.
                    logging.debug(f"Ignoring error while closing device: {close_error}")

    # Give the user time to read the final message before the window closes.
    time.sleep(5)
    return 0

# Run the patcher only when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())