"""
Beyond HMD Firmware Library

Provides functions for reading HMD and FPGA firmware versions, switching the HMD into bootloader or DFU mode, and updating HMD firmware.
"""

import os
import hid
import struct
import time
import hashlib
import winreg
from typing import Tuple, Optional, List
import bs_usb_tools


# Constants
# HID command/response tags (first byte of the request and of the matching reply)
HMD_GET_SERIAL = ord('&')  # request/reply tag for the HMD serial number
HMD_GET_SW_VER = ord('*')  # request/reply tag for the firmware version string
HMD_ERROR = ord('E')  # reply tag the device uses to signal an error
CHUNKSIZE = 32  # firmware bytes carried by each 0x44 write message
APP_ADDRESS = 0x00404000  # flash address the application image must start at
PAGE_SIZE = 0x200  # a 0x50 page-program command is issued per PAGE_SIZE bytes written
SMALL_SUBSECTOR_SIZE = 0x2000  # not referenced in the visible code
LARGE_SUBSECTOR_SIZE = 0x1C000  # first erase region starting at APP_ADDRESS
SECTOR_SIZE = 0x20000  # size of each following erase region
ALLOWED_APP_SIZE = LARGE_SUBSECTOR_SIZE + 3 * SECTOR_SIZE  # maximum accepted application image size
MIN_ERASABLE_SIZE = 0x2000  # 8 kB erase block; 16 kB blocks are 2x this
DEFAULT_HID_TIMEOUT = 10  # per-read timeout passed as timeout_ms (milliseconds)

# USB identifiers
BS_VID = 0x35BD  # vendor ID of the HMD
APP_PID = 0x0101  # product ID while running the application firmware
BOOT_PID = 0x4004  # product ID while in the bootloader
DFU_PID = 0x0282  # presumably the FPGA DFU-mode product ID; not used in the visible code


class FirmwareUpdateError(Exception):
    """Raised when a firmware update cannot be completed (bad payload, connection, erase or write failure)."""


class BootloaderEntryError(Exception):
    """Raised when the HMD cannot be switched into, or confirmed in, bootloader mode."""


def crc8_calc(crc8_init: int, databytes: bytes) -> int:
    """
    Compute a bitwise CRC-8 with generator polynomial 0x07 (MSB-first, no reflection).

    Args:
        crc8_init: Starting CRC value (use 0 for a fresh CRC).
        databytes: Bytes to fold into the CRC.

    Returns:
        The updated CRC. It is masked to 8 bits once any byte is processed; with
        empty input, ``crc8_init`` is returned unchanged.
    """
    poly = 0x07
    crc = crc8_init
    for byte in databytes:
        crc ^= byte
        for _ in range(8):
            top_bit_set = crc & 0x80
            crc = (crc << 1) & 0xFF
            if top_bit_set:
                crc ^= poly
    return crc


def hex_id_to_str(hex_id: int) -> str:
    """
    Format an integer USB ID as lower-case hex without the ``0x`` prefix.

    The result is zero-padded to at least 4 digits, e.g. ``0x101 -> "0101"``.

    Args:
        hex_id: Non-negative integer (e.g. a VID or PID).

    Returns:
        Lower-case hex string of at least 4 characters.
    """
    return format(hex_id, 'x').zfill(4)


def get_hmd_serial(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None) -> Optional[str]:
    """
    Get the serial number from a connected HMD.

    Sends a HMD_GET_SERIAL feature report and polls for the matching reply for
    up to about one second.

    Args:
        device_path: HID device path (bytes) used to open the device (optional).
            If neither this nor ``direct_device`` is given, the first device
            matching BS_VID/APP_PID is opened.
        direct_device: Already-open hid.device object (optional). It is used
            as-is and is never closed by this function.

    Returns:
        Serial number string (NULs removed, whitespace stripped), or None if no
        serial reply arrives in time or the reply is not ASCII.

    Raises:
        OSError: If the device cannot be opened.
    """
    device = direct_device if direct_device is not None else hid.device()
    try:
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)

        # Request the serial number via feature report (application mode)
        device.send_feature_report(bytes([0, HMD_GET_SERIAL]))

        # Poll with a short per-read timeout and an overall deadline so that
        # silence or a stream of unrelated reports cannot block forever.
        raw_serial = None
        deadline = time.monotonic() + 1.0
        while raw_serial is None and time.monotonic() < deadline:
            data = device.read(64, timeout_ms=DEFAULT_HID_TIMEOUT)
            if data and data[0] == HMD_GET_SERIAL:
                raw_serial = bytes(data[1:63])

        if raw_serial is None:
            return None

        try:
            return raw_serial.decode('ascii').replace('\x00', '').strip()
        except UnicodeDecodeError:
            # Serial is not plain ASCII; report it as unavailable.
            return None
    finally:
        # Only close handles this function opened itself.
        if direct_device is None:
            device.close()


# path is optional; if not provided, open by VID/PID
def get_hmd_fw_version(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None, retry_time: Optional[int] = None) -> Optional[str]:
    """
    Get the firmware version from a connected HMD.

    First asks via a HMD_GET_SW_VER feature report (application mode) and
    waits up to ~1 s for the reply. If none arrives, it falls back to a plain
    output report and treats a non-echo reply as the version string.

    Args:
        device_path: HID device path (bytes) used to open the device (optional).
            If neither this nor ``direct_device`` is given, the first device
            matching BS_VID/APP_PID is opened.
        direct_device: Already-open hid.device object (optional). It is used
            as-is and never closed by this function.
        retry_time: Seconds to keep retrying if an attempt fails or returns
            None (optional). At least one attempt is always made. OSError from
            an attempt is treated as a failed attempt.

    Returns:
        Firmware version string, or None if the version cannot be retrieved
        (timeout, error reply, echo reply, empty or non-ASCII reply).

    Raises:
        OSError: If the device cannot be opened (only when ``retry_time`` is None).
    """
    if retry_time is not None:
        deadline = time.monotonic() + retry_time
        while True:
            try:
                version = get_hmd_fw_version(device_path, direct_device=direct_device)
                if version is not None:
                    return version
            except OSError:
                # Device may be re-enumerating; keep trying until the deadline.
                pass
            if time.monotonic() >= deadline:
                return None
            # Only pause between failed attempts, never after a success.
            time.sleep(0.5)

    device = direct_device if direct_device is not None else hid.device()
    try:
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)

        # Try to get version via feature report (application mode)
        device.send_feature_report(bytes([0, HMD_GET_SW_VER]))

        timeout_ms = 1000
        deadline_ns = time.monotonic_ns() + timeout_ms * 1_000_000
        while time.monotonic_ns() < deadline_ns:
            response = device.read(65, timeout_ms=DEFAULT_HID_TIMEOUT)
            if not response:
                continue
            response = bytes(response)
            if response[0] == HMD_GET_SW_VER:
                try:
                    version = response[1:].rstrip(b'\x00').decode('ascii')
                except UnicodeDecodeError:
                    return None
                return version or None
            if response[0] == HMD_ERROR:
                return None

        # If feature report fails, try direct write (bootloader mode)
        device.write(bytes([0, HMD_GET_SW_VER]))
        resp = device.read(65, timeout_ms=1000)
        if not resp:
            return None

        resp = bytes(resp)
        if resp[0] == HMD_GET_SW_VER:
            # Echo response means no bootloader version
            return None
        try:
            return resp.rstrip(b'\x00').decode('ascii') or None
        except UnicodeDecodeError:
            return None
    finally:
        # Only close handles this function opened itself.
        if direct_device is None:
            device.close()


def fw_version_indicates_bootloader(version_str: Optional[str]) -> bool:
    """
    Check if the firmware version string indicates bootloader mode.

    Bootloader versions start with ``"0.1."`` or ``"0.2."``.

    Args:
        version_str: Version string, or None (as returned by
            ``get_hmd_fw_version`` when the version could not be read).

    Returns:
        True for a bootloader version; False otherwise, including for None or "".
    """
    if not version_str:
        return False
    return version_str.startswith(("0.1.", "0.2."))


def get_utility_path() -> Optional[str]:
    r"""
    Get the utility root path using the Steam installation path from the Windows registry.

    Reads ``SteamPath`` from ``HKEY_CURRENT_USER\Software\Valve\Steam`` and
    appends the Bigscreen Beyond Driver install folder.

    Returns:
        Absolute path of the driver folder, or None if the Steam key or the
        ``SteamPath`` value does not exist.
    """
    try:
        # The context manager closes the key even if QueryValueEx raises.
        with winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Valve\Steam") as key:
            steam_path = winreg.QueryValueEx(key, "SteamPath")[0]
    except FileNotFoundError:
        return None
    return os.path.abspath(steam_path + r"/steamapps/common/Bigscreen Beyond Driver")


def get_utility_beyondfw_path() -> Optional[str]:
    """
    Get the utility HMD firmware path using the Steam installation path from the Windows registry.

    Returns:
        Path to ``bin/latest.beyondfw`` inside the driver folder, or None if the
        Steam installation could not be located (instead of raising TypeError).
    """
    utility_path = get_utility_path()
    if utility_path is None:
        return None
    return utility_path + r"/bin/latest.beyondfw"


def get_file_fw_version(file_path: str) -> str:
    """
    Read the firmware version string from the header of a .beyondfw file.

    The file starts with a single length byte followed by that many ASCII
    bytes of version text; nothing after the version is read.

    Args:
        file_path: Path to the .beyondfw firmware file.

    Returns:
        The version string.

    Raises:
        FileNotFoundError: If the file doesn't exist.
        ValueError: If the length byte is missing, the version text is
            truncated, or it is not ASCII.
    """
    with open(file_path, 'rb') as fw_file:
        length_byte = fw_file.read(1)
        if len(length_byte) != 1:
            raise ValueError("Invalid firmware file: cannot read version length")
        (version_len,) = length_byte
        version_raw = fw_file.read(version_len)

    if len(version_raw) != version_len:
        raise ValueError("Invalid firmware file: version string truncated")
    return version_raw.decode('ascii')


def _create_hid_message(hidtype: int, datalen: int, addr: int, data: bytes) -> bytes:
    """
    Build a bootloader message frame terminated by a CRC-8.

    Layout: ``hidtype`` (1 byte), ``datalen`` (1 byte), ``addr`` (uint32
    little-endian), then ``data``, then ``crc8_calc(0, <all preceding bytes>)``.
    The HID report-ID byte is NOT included; callers prepend it.
    """
    frame = struct.pack('<BBI', hidtype, datalen, addr) + data
    return frame + bytes([crc8_calc(0, frame)])


def _send_erase_command(device: hid.device, erase_type: int, erase_addr: int, timeout_ms: int) -> bool:
    """
    Send one bootloader erase command (opcode 0x65) and report whether it succeeded.

    The frame is ``0x65, erase_type, erase_addr (uint32 little-endian)`` plus a
    CRC-8, which is exactly what ``_create_hid_message`` builds for an empty
    payload. A leading 0 byte is prepended as the HID report ID for ``write()``.

    Args:
        device: Open hid.device for the bootloader.
        erase_type: Erase selector (0x00 = 8 kB block, 0x01 = 16 kB block,
            0x02 = sector, as used by the wrappers below).
        erase_addr: Flash address to erase.
        timeout_ms: How long to wait for the status reply.

    Returns:
        True if a reply of at least two bytes arrived and byte 1 is 0.
    """
    device.write(bytes([0]) + _create_hid_message(0x65, erase_type, erase_addr, b''))
    status = device.read(64, timeout_ms=timeout_ms)
    return len(status) > 1 and status[1] == 0


def _erase_8kb_block(device: hid.device, erase_addr: int) -> bool:
    """Erase an 8kB block of flash at ``erase_addr``; True on a zero status reply."""
    return _send_erase_command(device, 0x00, erase_addr, timeout_ms=1000)


def _erase_16kb_block(device: hid.device, erase_addr: int) -> bool:
    """Erase a 16kB block of flash at ``erase_addr``; True on a zero status reply."""
    return _send_erase_command(device, 0x01, erase_addr, timeout_ms=1000)


def _erase_sector(device: hid.device, erase_addr: int) -> bool:
    """Erase a full sector of flash at ``erase_addr``; True on a zero status reply."""
    # Sector erase gets a longer reply timeout than the block erases.
    return _send_erase_command(device, 0x02, erase_addr, timeout_ms=2000)


def _erase_blocks(device: hid.device, erase_addr: int, bytes_to_erase: int) -> bool:
    """
    Erase ``bytes_to_erase`` bytes starting at ``erase_addr`` using 16 kB and 8 kB erases.

    While more than MIN_ERASABLE_SIZE (8 kB) remains, 16 kB blocks are erased.
    Any leftover (at most 8 kB) is finished with one 8 kB erase. The region may
    therefore be rounded up to the next block boundary.

    Returns:
        False as soon as any erase command reports failure, otherwise True.
    """
    block16 = 2 * MIN_ERASABLE_SIZE
    addr = erase_addr
    remaining = bytes_to_erase
    while remaining > MIN_ERASABLE_SIZE:
        if not _erase_16kb_block(device, addr):
            return False
        addr += block16
        remaining -= block16
    return _erase_8kb_block(device, addr) if remaining > 0 else True


def _erase_app_by_blocks(device: hid.device, erase_addr: int, bytes_to_erase: int) -> bool:
    """
    Erase the application region region-by-region.

    The application area is walked as one LARGE_SUBSECTOR_SIZE region followed
    by three SECTOR_SIZE regions. Each region that is completely covered (the
    remaining size is strictly larger than it) gets one sector-erase command.
    The region where the image ends is erased with finer 16/8 kB blocks via
    ``_erase_blocks``.

    Returns:
        False as soon as any erase fails, otherwise True.
    """
    region_sizes = (LARGE_SUBSECTOR_SIZE, SECTOR_SIZE, SECTOR_SIZE, SECTOR_SIZE)
    addr = erase_addr
    remaining = bytes_to_erase

    for region in region_sizes:
        if remaining <= 0:
            break
        if remaining > region:
            if not _erase_sector(device, addr):
                return False
            addr += region
            remaining -= region
        else:
            if not _erase_blocks(device, addr, remaining):
                return False
            remaining = 0

    return True


def _erase_app_full(device: hid.device) -> bool:
    """
    Erase the entire application region with the single 0x22 command.

    Described by the original author as a fallback for older bootloaders.
    Waits up to 10 s for a reply.

    Returns:
        ``reply[1] == 0`` when a reply of at least two bytes arrives. True when
        no (or a too-short) reply arrives — silence is treated as success.
    """
    device.write(b'\x00\x22')
    reply = device.read(64, timeout_ms=10000)
    if len(reply) <= 1:
        return True  # Assume success if no response
    return reply[1] == 0


def load_beyondfw_file(file_path: str) -> Tuple[bytearray, int]:
    """
    Load, parse and integrity-check a .beyondfw file.

    Layout as parsed here:
        1 byte        version string length (vlen)
        vlen bytes    version string (skipped)
        4 bytes       start address, little-endian uint32
        4 bytes       payload length N, little-endian uint32
        N bytes       firmware payload
        64 bytes      SHA-512 digest of every byte preceding it

    Args:
        file_path: Path to the .beyondfw file.

    Returns:
        Tuple of (payload as bytearray, start address).

    Raises:
        FileNotFoundError: If the file doesn't exist.
        ValueError: If the file is empty or truncated, or the hash does not match.
    """
    with open(file_path, 'rb') as f:
        data = f.read()

    if not data:
        raise ValueError("Invalid firmware file: empty file")

    # Skip the length-prefixed version string
    offset = 1 + data[0]

    header_end = offset + 8
    if len(data) < header_end:
        raise ValueError("Invalid firmware file: header truncated")
    start_addr, data_len = struct.unpack_from('<II', data, offset)

    payload_end = header_end + data_len
    hash_end = payload_end + 64  # SHA-512 digest size
    if len(data) < hash_end:
        raise ValueError("Invalid firmware file: payload or hash truncated")

    binary_data = bytearray(data[header_end:payload_end])

    # Verify SHA512 hash over header + version + payload
    expected_hash = data[payload_end:hash_end]
    computed_hash = hashlib.sha512(data[:payload_end]).digest()
    if expected_hash != computed_hash:
        raise ValueError("Firmware file hash verification failed")

    return binary_data, start_addr


def _get_existing_bootloader_paths() -> set:
    """
    Snapshot the HID paths of every bootloader-mode HMD currently enumerated.

    Returns:
        Set of ``path`` values (bytes) for all devices matching BS_VID/BOOT_PID.
    """
    return {info['path'] for info in hid.enumerate(vendor_id=BS_VID, product_id=BOOT_PID)}


def _wait_for_new_bootloader(existing_paths: set, timeout_seconds: int = 10) -> Optional[bytes]:
    """
    Poll (every 0.5 s) for a bootloader-mode device that is not in ``existing_paths``.

    Args:
        existing_paths: Bootloader paths that were present before the mode switch.
        timeout_seconds: Maximum time to keep polling.

    Returns:
        The HID path of the first newly appeared bootloader, or None on timeout.
    """
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        newcomer = next(
            (info['path'] for info in hid.enumerate(vendor_id=BS_VID, product_id=BOOT_PID)
             if info['path'] not in existing_paths),
            None,
        )
        if newcomer is not None:
            return newcomer
        time.sleep(0.5)
    return None


def device_is_beyond(device_path: bytes) -> int:
    """
    Classify a HID device path as a Beyond HMD (and which mode) or not.

    The path is searched, case-insensitively, for the ``VID_xxxx`` / ``PID_xxxx``
    markers built from BS_VID, APP_PID and BOOT_PID. Windows HID paths appear
    in either upper or lower case depending on the API/hidapi version.

    Args:
        device_path: HID device path (bytes; a str is also accepted).

    Returns:
        0 if not a Beyond HMD, 1 if application mode, 2 if bootloader mode.
    """
    if isinstance(device_path, str):
        path_str = device_path
    else:
        path_str = device_path.decode('utf-8', errors='ignore')
    path_str = path_str.upper()

    if f"VID_{BS_VID:04X}" not in path_str:
        return 0  # Not a Beyond HMD
    if f"PID_{APP_PID:04X}" in path_str:
        return 1  # Application mode
    if f"PID_{BOOT_PID:04X}" in path_str:
        return 2  # Bootloader mode
    return 0  # Beyond vendor, but not a recognised HMD mode


def enter_hmd_bootloader(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None, confirm_bootloader: bool = True) -> Optional[bytes]:
    """
    Switch an HMD from application mode into bootloader mode.

    If the device is already in bootloader mode nothing is sent. Otherwise a
    ``'B'`` feature report is sent to the application-mode device. When
    ``confirm_bootloader`` is True, this blocks (up to ~10 s) until a bootloader
    device that was not present beforehand appears and can be opened.

    Args:
        device_path: HID device path (bytes) for the HMD. If it already names a
            bootloader-mode device, it is returned unchanged.
        direct_device: Optional already-open hid.device for the HMD. It is used
            as-is and never closed by this function.
        confirm_bootloader: Whether to wait for, and test-open, the new
            bootloader device (optional, default True).

    Returns:
        The bootloader device path (bytes) when it is known. That is, either
        ``device_path`` was already a bootloader path, confirmation found one,
        or (with confirmation off) a pre-existing bootloader was picked after
        the application device could not be opened. Otherwise None, including
        when ``direct_device`` was already in bootloader mode.

    Raises:
        BootloaderEntryError: If the firmware version of ``direct_device``
            cannot be read, or no new bootloader device appears in time.
        OSError: If the new bootloader device cannot be opened.
    """
    if device_path and device_is_beyond(device_path) == 2:
        # The path already names a bootloader-mode device.
        return device_path
    if direct_device is not None:
        version = get_hmd_fw_version(direct_device=direct_device)
        if version is None:
            raise BootloaderEntryError("Could not read firmware version from the provided device")
        if fw_version_indicates_bootloader(version):
            return None

    # Survey existing bootloader devices BEFORE switching modes, so the new one
    # can be told apart from any other HMDs already in bootloader mode.
    existing_bootloaders = _get_existing_bootloader_paths()
    bootloader_path = None

    device = direct_device if direct_device is not None else hid.device()
    try:
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)
        device.send_feature_report(bytes([0, ord('B')]))
    except OSError:
        # Device might already be disconnected or in bootloader. Without an
        # explicit path, fall back to any bootloader that was already present.
        if not device_path and existing_bootloaders:
            bootloader_path = existing_bootloaders.pop()
    finally:
        if direct_device is None:
            device.close()

    if not confirm_bootloader:
        return bootloader_path

    # Wait for the NEW bootloader device to appear (one that wasn't in our survey)
    if not bootloader_path:
        bootloader_path = _wait_for_new_bootloader(existing_bootloaders, timeout_seconds=10)
    if not bootloader_path:
        raise BootloaderEntryError("Failed to enter bootloader mode - no new bootloader device appeared")

    # Prove the bootloader can be opened; this probe handle is ours and is
    # always released, regardless of whether direct_device was supplied.
    probe = hid.device()
    try:
        probe.open_path(bootloader_path)
    except OSError as e:
        raise OSError(f"Failed to open new bootloader device: {e}") from e
    finally:
        probe.close()
    return bootloader_path


def _fw_write_chunk(device: hid.device, addr: int, datalen: int, chunk: bytes) -> None:
    """Send one 0x44 data-write message at ``addr``; raise FirmwareUpdateError on a non-zero status."""
    device.write(bytes([0]) + _create_hid_message(0x44, datalen, addr, chunk))
    # 1 ms timeout: only a status reply that is already queued gets checked.
    status = device.read(65, timeout_ms=1)
    if len(status) > 1 and status[1] != 0:
        raise FirmwareUpdateError(f"Write error at address 0x{addr:08X}, code: {status[1]}")


def _fw_program_page(device: hid.device, page_addr: int) -> None:
    """Send one 0x50 page-program message for ``page_addr``; raise FirmwareUpdateError on a non-zero status."""
    device.write(bytes([0]) + _create_hid_message(0x50, 0, page_addr, b''))
    status = device.read(65, timeout_ms=1000)
    if len(status) > 1 and status[1] != 0:
        raise FirmwareUpdateError(f"Program error at page 0x{page_addr:08X}, code: {status[1]}")


def update_hmd_firmware(firmware_payload: Optional[Tuple[bytearray, int]] = None, bootloader_path: Optional[bytes] = None, direct_bootloader: Optional[hid.device] = None, reset_when_done: Optional[bool] = True) -> bool:
    """
    Flash an application image to an HMD that is already in bootloader mode.

    This function blocks until the update is complete. It will:
    - Connect to the bootloader (``direct_bootloader``, ``bootloader_path``, or
      the first device matching BS_VID/BOOT_PID)
    - Erase the application region block-wise, falling back to a full erase
    - Write the image in CHUNKSIZE-byte messages, issuing a page-program
      command for every PAGE_SIZE bytes and for the final partial page
    - Optionally command the device to boot back into application mode

    Args:
        firmware_payload: Tuple of (firmware data, start address). Create using load_beyondfw_file().
        bootloader_path: HID device path (bytes) for the HMD in bootloader mode (required for multi-device support)
        direct_bootloader: Optional already-open hid.device for the bootloader; never closed here
        reset_when_done: Whether to command the device to boot back into application mode when done (optional, default True)

    Returns:
        True when the update completes. Failures are raised, not returned as False.

    Raises:
        FirmwareUpdateError: If the target is not a bootloader, the payload is
            missing/too large/at the wrong address, the bootloader cannot be
            opened, or any erase/write/program step reports an error.
        Exceptions from the underlying hid.device calls are not caught here.
    """
    if bootloader_path:
        device_mode = device_is_beyond(bootloader_path)
        if device_mode == 0:
            raise FirmwareUpdateError("Provided device path is not a Beyond HMD")
        if device_mode == 1:
            raise FirmwareUpdateError("Provided device path is not in bootloader mode. Please use enter_hmd_bootloader() first")
    elif direct_bootloader is not None:
        version = get_hmd_fw_version(direct_device=direct_bootloader)
        if version is None:
            raise FirmwareUpdateError("Could not read firmware version from the provided direct device")
        if not fw_version_indicates_bootloader(version):
            raise FirmwareUpdateError("Provided direct device is not in bootloader mode. Please use enter_hmd_bootloader() first")

    if not firmware_payload:
        raise FirmwareUpdateError("Firmware payload must be provided")
    app_data = bytearray(firmware_payload[0])  # Make a copy
    app_addr = firmware_payload[1]

    # Validate firmware
    if len(app_data) > ALLOWED_APP_SIZE:
        raise FirmwareUpdateError("Firmware file too large for device memory")
    if app_addr != APP_ADDRESS:
        raise FirmwareUpdateError(f"Firmware address mismatch: expected 0x{APP_ADDRESS:08X}, got 0x{app_addr:08X}")

    device = direct_bootloader if direct_bootloader is not None else hid.device()
    if direct_bootloader is None:
        try:
            if bootloader_path:
                device.open_path(bootloader_path)
            else:
                device.open(vendor_id=BS_VID, product_id=BOOT_PID)
        except OSError as e:
            if bootloader_path:
                raise FirmwareUpdateError(f"Cannot connect to bootloader at specified path: {e}")
            raise FirmwareUpdateError(f"Cannot connect to bootloader: {e}")

    try:
        # Give the bootloader a moment to stabilize after opening
        time.sleep(0.5)

        # Erase flash, falling back to a full-region erase
        if not _erase_app_by_blocks(device, app_addr, len(app_data)):
            if not _erase_app_full(device):
                raise FirmwareUpdateError("Failed to erase application region")

        print("Programming firmware...")
        addr = app_addr
        cur_page_addr = app_addr
        num_chunks = len(app_data) // CHUNKSIZE

        for chunk_idx in range(num_chunks):
            chunk = app_data[CHUNKSIZE * chunk_idx:CHUNKSIZE * (chunk_idx + 1)]
            _fw_write_chunk(device, addr, CHUNKSIZE, chunk)
            addr += CHUNKSIZE
            # Program page when full
            if addr % PAGE_SIZE == 0:
                _fw_program_page(device, cur_page_addr)
                cur_page_addr += PAGE_SIZE

        tail_len = len(app_data) % CHUNKSIZE
        if tail_len:
            # Remaining partial chunk, padded with 0xFF to an 8-byte boundary.
            # NOTE(review): datalen carries the unpadded length while the payload
            # is padded — presumably the bootloader expects this; confirm.
            tail = app_data[CHUNKSIZE * num_chunks:]
            if tail_len % 8 != 0:
                tail = tail + bytearray([0xFF] * (8 - tail_len % 8))
            _fw_write_chunk(device, addr, tail_len, tail)
            _fw_program_page(device, cur_page_addr)
        elif addr > cur_page_addr:
            # Image ended on a chunk boundary but mid-page: program that page.
            _fw_program_page(device, cur_page_addr)

        if reset_when_done:
            # Boot back into application mode
            device.write(bytes([0, ord('B')]))
        return True
    finally:
        # Only close handles this function opened itself.
        if direct_bootloader is None:
            device.close()



def read_i2c(device: hid.device, reg_addr: int, num_bytes: int) -> List[int]:
    """
    Read from the FPGA via I2C through the HMD device.

    Sends the feature report ``'e' 'i' <num_bytes> <reg_addr>`` and polls for up
    to 200 ms. The reply starts with ``'e'``; byte 1 echoes the length and the
    data follows. A reply starting with ``'E'`` signals an error.

    Args:
        device: Open HMD hid.device (None is tolerated and yields []).
        reg_addr: FPGA register address to read from (0-255).
        num_bytes: Number of bytes to read (0-255).

    Returns:
        The data bytes as a list of ints (a slice of ``device.read()``'s result).
        An empty list is returned on error, a length mismatch, timeout, or a None device.
    """
    if device is None:
        return []

    # Format: 'e' + 'i' (I2C read) + length + register_address
    device.send_feature_report(b'\x00ei' + bytes([num_bytes, reg_addr]))

    # 200 ms overall deadline on a monotonic clock (immune to wall-clock jumps)
    deadline = time.monotonic() + 0.2
    while time.monotonic() < deadline:
        reply = device.read(65, timeout_ms=10)
        if not reply:
            continue
        if reply[0] == ord('e'):  # FPGA response
            if len(reply) > 1:
                return reply[2:2 + num_bytes] if reply[1] == num_bytes else []
        elif reply[0] == ord('E'):  # any error occurred
            return []

    return []  # timeout


def fpga_mode_is_dfu(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None) -> Optional[bool]:
    """
    Check if the FPGA is in DFU (bootloader) mode by reading its configuration ID register.

    Args:
        device_path: The device path to open with hidapi (optional; if omitted
            and no ``direct_device`` is given, the first BS_VID/APP_PID device is opened).
        direct_device: Already-open hid.device object (optional); never closed here.

    Returns:
        True if register 0xA0 reads 'B' (0x42, bootloader/DFU). False for any
        other value ('A' = 0x41 is camera mode). None if the register could not
        be read.

    Raises:
        OSError: If the device cannot be opened.
    """
    device = direct_device if direct_device is not None else hid.device()
    try:
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)

        # Read register 0xA0 (Configuration ID code)
        response = read_i2c(device, 0xA0, 1)
        if not response:
            return None

        # 'B' (0x42) = Bootloader, 'A' (0x41) = Camera
        return response[0] == 0x42
    finally:
        if direct_device is None:
            device.close()


def get_hmd_fpga_fw_version(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None) -> Optional[str]:
    """
    Get the FPGA software version string.

    Reads the BCD-encoded version from FPGA registers 0xA1 (high byte) and
    0xA2 (low byte). The high byte gives the two-digit major number; each
    nibble of the low byte gives one further digit.

    Args:
        device_path: The device path to open with hidapi (optional; if omitted
            and no ``direct_device`` is given, the first BS_VID/APP_PID device is opened).
        direct_device: Already-open hid.device object (optional); never closed here.

    Returns:
        Version string in format "X.Y.Z" (e.g. 0x04, 0x82 -> "4.8.2"), or None
        if either register could not be read.

    Raises:
        OSError: If the device cannot be opened.
    """
    device = direct_device if direct_device is not None else hid.device()
    try:
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)

        # Read register 0xA1 (SW version high byte, BCD)
        response_high = read_i2c(device, 0xA1, 1)
        if not response_high:
            return None
        high_byte = response_high[0]

        # Read register 0xA2 (SW version low byte, BCD)
        response_low = read_i2c(device, 0xA2, 1)
        if not response_low:
            return None
        low_byte = response_low[0]

        # Convert BCD: each nibble is one decimal digit
        major = ((high_byte >> 4) & 0x0F) * 10 + (high_byte & 0x0F)
        minor_high = (low_byte >> 4) & 0x0F
        minor_low = low_byte & 0x0F
        return f"{major}.{minor_high}.{minor_low}"
    finally:
        if direct_device is None:
            device.close()


def get_utility_etfw_path() -> Optional[str]:
    """
    Get the utility eyetracking firmware path using the Steam installation path from the Windows registry.

    Returns:
        Path to ``bin/eyetracking_firm.dfu`` inside the driver folder, or None
        if the Steam installation could not be located (instead of raising TypeError).
    """
    utility_path = get_utility_path()
    if utility_path is None:
        return None
    return utility_path + r"/bin/eyetracking_firm.dfu"


def get_dfu_file_fw_version(dfu_file_path: str) -> Optional[str]:
    """
    Extract the firmware version string from a DFU file.

    The version is the little-endian uint16 at the start of the 16-byte DFU
    suffix (the last 16 bytes of the file). It is rendered as at least three
    lower-case hex digits and each digit is joined with '.', e.g. 0x482 -> "4.8.2".

    Args:
        dfu_file_path: Path to the .dfu file.

    Returns:
        Version string such as "4.8.2", or None if the file is shorter than 18
        bytes or cannot be read/parsed (the error is printed).
    """
    try:
        with open(dfu_file_path, 'rb') as dfu_file:
            total_len = dfu_file.seek(0, 2)  # SEEK_END; returns the file length
            if total_len < 18:
                return None

            dfu_file.seek(total_len - 16)
            raw = dfu_file.read(2)
            if len(raw) != 2:
                return None

            (version_word,) = struct.unpack('<H', raw)
            digits = format(version_word, '03x')
            if len(digits) < 2:
                return None
            return '.'.join(digits)

    except OSError as e:
        print(f"Error reading DFU file: {e}")
        return None
    except Exception as e:
        print(f"Error parsing DFU file: {e}")
        return None


def get_dfu_payload_fw_version(firmware_payload: bytearray) -> Optional[str]:
    """
    Extract the firmware version string from an in-memory DFU firmware payload.

    The version is the little-endian uint16 at the start of the 16-byte DFU
    suffix (the last 16 bytes of the payload). It is rendered as at least three
    lower-case hex digits and each digit is joined with '.', e.g. 0x482 -> "4.8.2".

    Args:
        firmware_payload: Firmware data loaded from file as bytearray.

    Returns:
        Version string such as "4.8.2", or None if the payload is shorter than
        18 bytes or cannot be parsed (the error is printed).
    """
    try:
        if len(firmware_payload) < 18:
            return None

        suffix_start = len(firmware_payload) - 16
        (version_word,) = struct.unpack('<H', firmware_payload[suffix_start:suffix_start + 2])

        digits = format(version_word, '03x')
        if len(digits) < 2:
            return None
        return '.'.join(digits)

    except Exception as e:
        print(f"Error parsing DFU payload: {e}")
        return None


def enter_fpga_dfu(device_path: Optional[bytes] = None, direct_device: Optional[hid.device] = None, exit_dfu: Optional[bool] = False) -> List[str]:
    """
    Switch the HMD's FPGA into (or out of) DFU mode and return its USB port chain.
    
    The HMD's serial number is read first, so the device can be located again
    through bs_usb_tools after the switch. If the FPGA is not already in the
    requested mode, the toggle feature report is sent. The mode is then polled
    every 0.5 s for up to 5 s until the switch is confirmed.
    
    Args:
        device_path: HID device path (bytes) for the HMD. If neither this nor
            direct_device is given, the first device matching BS_VID/APP_PID
            is opened.
        direct_device: Already-open hid.device object (optional). It is used
            as-is and is NOT closed by this function.
        exit_dfu: If True, leave DFU mode instead of entering it (optional,
            default False)

    Returns:
        The FPGA's USB port chain: the HMD's PortChain (as reported by
        bs_usb_tools) with its last two entries replaced by '2'.
        
    Raises:
        OSError: If the device cannot be opened, its serial number cannot be
            read, the mode switch is not confirmed within 5 s, or the port
            chain cannot be determined.
    """
    device = direct_device if direct_device is not None else hid.device()
    # The mode the FPGA should be in when we are done.
    dfu_target_state = not exit_dfu
    # Used only to make error messages match the requested direction.
    verb = "exiting" if exit_dfu else "entering"

    try:
        # Only open (and, in `finally`, close) a device this function created.
        if direct_device is None:
            if device_path:
                device.open_path(device_path)
            else:
                device.open(vendor_id=BS_VID, product_id=APP_PID)

        serial = device.get_serial_number_string()
        if not serial:
            raise OSError(f"Could not get Atmel SN before {verb} DFU mode")

        # Toggle only when the FPGA is not already in the requested mode.
        if fpga_mode_is_dfu(direct_device=device) != dfu_target_state:
            # Report ID 0 followed by the 'eB' command.
            # Presumably this is the FPGA DFU toggle command -- TODO confirm.
            device.send_feature_report(b'\x00eB')
            switched = False
            start_time = time.time()
            while (time.time() - start_time) < 5:
                time.sleep(0.5)
                if fpga_mode_is_dfu(direct_device=device) == dfu_target_state:
                    switched = True
                    break
            if not switched:
                action = "exit" if exit_dfu else "enter"
                raise OSError(f"Failed to {action} FPGA bootloader mode")

        # The PortChain may not be reported immediately after the switch, so
        # retry the lookup a few times.
        # Presumably this is because the USB device re-enumerates.
        usb_info = bs_usb_tools.get_info(f"*{serial}*")
        retries = 3
        while 'PortChain' not in usb_info and retries > 0:
            time.sleep(1)
            usb_info = bs_usb_tools.get_info(f"*{serial}*")
            retries -= 1

        if 'PortChain' not in usb_info:
            raise OSError(f"Could not get HMD {serial} PortChain after {verb} DFU mode")

        # FPGA is always at index 2 in chain (per original author): drop the
        # last two entries of the HMD's chain and append '2'.
        return usb_info['PortChain'][:-2] + ['2']
    finally:
        # Single close point for both the success and error paths; a
        # caller-supplied direct_device is left open for the caller.
        if direct_device is None:
            device.close()


def update_fpga_firmware(firmware_payload: Optional[bytearray] = None, dfu_device_port_chain: Optional[str] = None, reset_device_path: Optional[bytes] = None, reset_direct_device: Optional[hid.device] = None) -> bool:
    """
    Update FPGA firmware from a DFU firmware payload.
    
    This function blocks until the update is complete. It:
    - Validates the payload (it must carry a readable DFU-suffix version)
    - Checks that the target FPGA is present in DFU mode
    - Downloads the payload, trying up to 3 times
    - Optionally returns the FPGA to run-time mode when a reset target is given
    
    Args:
        firmware_payload: Firmware data loaded from file. Create using
            bs_usb_tools.load_file(). A private copy is used for the download.
        dfu_device_port_chain: Port chain of the FPGA in DFU mode (required for
            multi-device support). Can be obtained using enter_fpga_dfu().
        reset_device_path: HMD device path whose FPGA should be switched back
            to run-time mode after the update (optional, default None)
        reset_direct_device: Already-open hid.device for the HMD whose FPGA
            should be switched back after the update; takes precedence over
            reset_device_path (optional, default None)

    Returns:
        The value returned by bs_usb_tools.dfu_download() for the successful
        attempt (truthy). Failures are reported by raising, never by
        returning False.
        
    Raises:
        FirmwareUpdateError: If the payload is missing or invalid, the FPGA is
            not found in DFU mode, every download attempt fails, or switching
            the FPGA back to run-time mode fails.
    """
    # Validate the local payload before touching any hardware.
    if not firmware_payload:
        raise FirmwareUpdateError("Firmware payload must be provided")
    app_data = bytearray(firmware_payload)  # Make a copy

    if get_dfu_payload_fw_version(app_data) is None:
        raise FirmwareUpdateError("Invalid DFU firmware payload: cannot extract version number")

    vid = hex_id_to_str(BS_VID)
    pid = hex_id_to_str(DFU_PID)

    if not bs_usb_tools.dfu_device_exists(dfu_vid=vid, dfu_pid=pid, dfu_port_chain=dfu_device_port_chain):
        raise FirmwareUpdateError("Cannot connect to specified FPGA device. Is it in DFU mode?")

    update_result = False
    for _attempt in range(3):
        update_result = bs_usb_tools.dfu_download(app_data, vid, pid, dfu_device_port_chain)
        if update_result:
            break
    if not update_result:
        raise FirmwareUpdateError("FPGA firmware update failed")

    if reset_device_path is not None or reset_direct_device is not None:
        try:
            # Short pause after the download before asking the HMD to switch
            # the FPGA back.
            time.sleep(0.5)
            if reset_direct_device is None:
                enter_fpga_dfu(device_path=reset_device_path, exit_dfu=True)
            else:
                enter_fpga_dfu(direct_device=reset_direct_device, exit_dfu=True)
        except OSError as e:
            raise FirmwareUpdateError(f"Error resetting DFU reset target device at specified path: {e}") from e

    return update_result


# Example usage
if __name__ == "__main__":
    # Example: Get version from file
    try:
        version = get_file_fw_version("firmware.beyondfw")
        print(f"File firmware version: {version}")
    except Exception as e:
        print(f"Error reading file: {e}")
    
    # Example: Get version from connected device(s).
    # cython-hidapi's enumerate() takes vendor_id=/product_id= keywords; there
    # is no 'vid' keyword. It may return one entry per HID interface, so a
    # single HMD can appear more than once.
    for dev in hid.enumerate(vendor_id=BS_VID):
        try:
            version = get_hmd_fw_version(dev['path'])
            print(f"Device firmware version: {version}")
        except Exception as e:
            print(f"Error reading device: {e}")
    
    # Example: Update firmware (device_path would come from hid.enumerate())
    # update_hmd_firmware("firmware.beyondfw", device_path)