############################################
# syn_fw_loader.py
# Version: 2
# Date: January 13, 2025
# Author: David Miller
#
# Updates the firmware on the Beyond's
# VXR7200 DisplayPort-to-MIPI bridge.
# 
# To delete all firmware and completely
# clear the VXR7200's internal flash, use the
# "-e" option.
# Example:
# > python .\syn_firmware_loader.py -e
# 
# To load a new firmware file, use the
# "-f" option.
# Example:
# > python .\syn_firmware_loader.py -f <path\to\firmware\file.fullrom>
#
############################################

import argparse
import struct
import sys
import time

import hid

# HID command / reply codes. Outgoing feature reports are sent as
# [0, CMD_..., payload...]; incoming reports are matched on their first byte.
CMD_GET_ACTIVE_VXR_TAGS = 0x54  # 'T': query which config/firmware banks are active
CMD_DEL_VXRFLASH_BANK = 0x44    # 'D': erase a Flash block
CMD_GET_VXRFLASH_CHKSUM = 0x4B  # 'K': checksum a Flash address range
CMD_VXRFLASH_PROGRAM = 0x41     # 'A': program up to 32 bytes of Flash
CMD_SUCCESS = 0x24              # '$': generic success reply
CMD_ERROR = 0x45                # 'E': error reply; next byte is the error code

def hid_open_safe(vid = 0x35bd, pid = 0x0101):
    """Open the HID device ``vid``:``pid`` with whichever "hid" package is installed.

    The ``hid`` and ``hidapi`` Python packages share the import name ``hid``
    but expose different APIs:
      * ``hid`` provides ``hid.Device(vid=..., pid=...)``;
      * ``hidapi`` provides ``hid.device()`` followed by ``.open(...)``.
    The presence of a (non-None) ``hid.Device`` attribute selects the API.

    Returns:
        An open device handle from the detected package.
    """
    if getattr(hid, 'Device', None) is not None:
        # "hid" package: the constructor opens the device.
        return hid.Device(vid=vid, pid=pid)

    # "hidapi" package: create the handle, then open it explicitly.
    handle = hid.device()
    handle.open(vendor_id=vid, product_id=pid)
    return handle

class Firmware_Blocks:
    # Bit flags for the active-tags byte returned by check_vxr_tags();
    # also used to select a bank in the delete/invalidate helpers.
    Config0 = 0x1
    Config1 = 0x2
    Firmware0 = 0x4
    Firmware1 = 0x8

    # Region sizes in bytes.
    Config_Length = 0x08000     # 32 kB
    Firmware_Length = 0x18000   # 96 kB
    ESM_Length = 0x40000        # 256 kB

    # Flash addresses; *_End values are inclusive. The last 16 bytes of
    # every config/firmware region are its tag area.

    # Bank 0
    Config0_Start = 0x00000
    Config0_Tag_Start = 0x07FF0
    Config0_Tag_End = 0x07FFF
    Config0_End = Config0_Tag_End
    Firmware0_Start = 0x08000
    Firmware0_Tag_Start = 0x1FFF0
    Firmware0_Tag_End = 0x1FFFF
    Firmware0_End = Firmware0_Tag_End

    # Bank 1
    Config1_Start = 0x20000
    Config1_Tag_Start = 0x27FF0
    Config1_Tag_End = 0x27FFF
    Config1_End = Config1_Tag_End
    Firmware1_Start = 0x28000
    Firmware1_Tag_Start = 0x3FFF0
    Firmware1_Tag_End = 0x3FFFF
    Firmware1_End = Firmware1_Tag_End

    # HDCP ESM region (single copy, no bank redundancy)
    ESM_Start = 0x40000
    ESM_End = 0x7FFFF

def config_and_fw_in_same_bank(active_tags):
    """Return True if an active config and an active firmware share a bank.

    ``active_tags`` is the bitmask from check_vxr_tags(). Bank 0 is
    Config0 + Firmware0 and bank 1 is Config1 + Firmware1; the result is True
    when both flags of either bank are set.
    """
    bank_masks = (
        Firmware_Blocks.Config0 | Firmware_Blocks.Firmware0,
        Firmware_Blocks.Config1 | Firmware_Blocks.Firmware1,
    )
    for mask in bank_masks:
        if (active_tags & mask) == mask:
            return True
    return False

def load_fullrom_file(filename):
    """Read a .fullrom image and return its (config, firmware, esm) blocks.

    The file is laid out as Config0 | Firmware0 | Config1 | Firmware1 | ESM,
    where the ESM block is everything after Firmware1. Bank 1 must be an
    exact copy of bank 0, so only the bank-0 blocks are returned.

    Raises:
        RuntimeError: if any section has the wrong length, or the bank-0 and
            bank-1 copies differ.
    """
    bank_sizes = (
        Firmware_Blocks.Config_Length,
        Firmware_Blocks.Firmware_Length,
        Firmware_Blocks.Config_Length,
        Firmware_Blocks.Firmware_Length,
    )
    with open(filename, "rb") as image:
        # Bank 1 is read only to validate it and to reach the ESM block.
        cfg_a, fw_a, cfg_b, fw_b = [image.read(size) for size in bank_sizes]
        esm = image.read()

    expected_sizes = bank_sizes + (Firmware_Blocks.ESM_Length,)
    actual_sizes = tuple(len(part) for part in (cfg_a, fw_a, cfg_b, fw_b, esm))
    if actual_sizes != expected_sizes:
        raise RuntimeError("Firmware file has invalid format (length incorrect)")
    if cfg_a != cfg_b or fw_a != fw_b:
        raise RuntimeError("Firmware file has invalid format (blocks 0 and 1 do not match)")

    return (cfg_a, fw_a, esm)

def wait_for_hid_packet(hid_handle, matching_tags, timeout=None):
    """Read HID reports until one starts with a byte in ``matching_tags``.

    Empty reads and reports whose first byte is not in ``matching_tags`` are
    discarded.

    Args:
        hid_handle: open device handle (see hid_open_safe()).
        matching_tags: container of accepted first-byte values, e.g.
            [CMD_SUCCESS, CMD_ERROR].
        timeout: optional limit in seconds. None (the default) waits
            indefinitely.

    Returns:
        The matching report as ``bytes`` (at most 65 bytes are requested).

    Raises:
        RuntimeError: if ``timeout`` elapses without a matching report.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        # bytes() normalises whatever sequence read() returns (it may be a
        # list of ints depending on the installed hid package).
        packet = bytes(hid_handle.read(65, 0))
        if packet and packet[0] in matching_tags:
            return packet
        if deadline is not None and time.monotonic() >= deadline:
            raise RuntimeError("Timed out after {}s waiting for HID reply".format(timeout))

def check_vxr_tags():
    """Ask the device which config and firmware banks are active.

    Returns:
        The active-tags byte: a bitmask of Firmware_Blocks.Config0/Config1/
        Firmware0/Firmware1 flags.

    Raises:
        RuntimeError: if the device replies with CMD_ERROR.
    """
    bigs = hid_open_safe()
    try:
        bigs.send_feature_report(bytes([0, CMD_GET_ACTIVE_VXR_TAGS]))
        # Wait until we get the reply packet or an error
        recpacket = wait_for_hid_packet(bigs, [CMD_GET_ACTIVE_VXR_TAGS, CMD_ERROR])
    finally:
        # Release the HID handle even if sending or reading fails.
        bigs.close()

    if recpacket[0] == CMD_ERROR:
        raise RuntimeError("Cannot read VXR Flash tags, error code 0x{:02X}".format(recpacket[1]))
    # Only the two tags above can be returned, so this is the tags reply.
    return recpacket[1]

def vxr_get_checksum(start_addr, length):
    """Return the device-computed checksum of ``length`` Flash bytes at ``start_addr``.

    Callers in this file compare the result with ``sum()`` of the matching
    file bytes.

    Raises:
        RuntimeError: if the device replies with CMD_ERROR or the reply is not
            a 4-byte checksum.
    """
    bigs = hid_open_safe()
    try:
        partial_command = struct.pack("<II", start_addr, length)
        bigs.send_feature_report(bytes([0, CMD_GET_VXRFLASH_CHKSUM]) + partial_command)
        # Wait until we get the reply packet or an error
        recpacket = wait_for_hid_packet(bigs, [CMD_GET_VXRFLASH_CHKSUM, CMD_ERROR])
    finally:
        # Release the HID handle even if sending or reading fails.
        bigs.close()

    if recpacket[0] == CMD_ERROR:
        raise RuntimeError("Could not calculate checksum, error code 0x{:02X}".format(recpacket[1]))
    if recpacket[1] != 4:
        # always a 4-byte reply for the checksum
        raise RuntimeError("Error in received checksum format")
    # Little-endian uint32 payload following the length byte.
    return struct.unpack("<I", recpacket[2:6])[0]

def vxr_program(start_addr, prog_bytes):
    """Program 1-32 bytes of Flash at ``start_addr`` with one HID command.

    Raises:
        RuntimeError: if ``prog_bytes`` is empty or longer than 32 bytes, or
            the device replies with CMD_ERROR.
    """
    if (len(prog_bytes) == 0) or (len(prog_bytes) > 32):
        raise RuntimeError("Invalid programming length. Should be 1-32. Got {}".format(len(prog_bytes)))

    bigs = hid_open_safe()
    try:
        # Payload: <length:u8><address:u32 little-endian><data>
        partial_command = struct.pack("<BI", len(prog_bytes), start_addr)
        bigs.send_feature_report(bytes([0, CMD_VXRFLASH_PROGRAM]) + partial_command + prog_bytes)
        recpacket = wait_for_hid_packet(bigs, [CMD_SUCCESS, CMD_ERROR])
    finally:
        # Release the HID handle; this function is called once per write.
        bigs.close()

    if recpacket[0] == CMD_ERROR:
        raise RuntimeError("Could not program Flash. Addr: 0x{:05X}, length: {}, error code: 0x{:02X}".format(start_addr, len(prog_bytes), recpacket[1]))
    

def invalidate_vxr_tag(region):
    """Invalidate the tag of every region flagged in ``region``.

    ``region`` is a bitmask of Firmware_Blocks.Config0/Config1/Firmware0/
    Firmware1 flags and may combine several of them; each flagged region is
    handled. For each one, the first and last byte of its tag area are
    programmed to 0.
    """
    # Checked in the same order as before: Config0, Config1, Firmware0, Firmware1.
    tag_areas = (
        (Firmware_Blocks.Config0, Firmware_Blocks.Config0_Tag_Start, Firmware_Blocks.Config0_Tag_End),
        (Firmware_Blocks.Config1, Firmware_Blocks.Config1_Tag_Start, Firmware_Blocks.Config1_Tag_End),
        (Firmware_Blocks.Firmware0, Firmware_Blocks.Firmware0_Tag_Start, Firmware_Blocks.Firmware0_Tag_End),
        (Firmware_Blocks.Firmware1, Firmware_Blocks.Firmware1_Tag_Start, Firmware_Blocks.Firmware1_Tag_End),
    )
    for flag, tag_start, tag_end in tag_areas:
        # Independent checks (not elif) so a combined mask such as
        # Config1 | Firmware1 invalidates both regions.
        if region & flag:
            vxr_program(tag_start, bytes([0]))
            vxr_program(tag_end, bytes([0]))

def vxr_compare_file_to_flash_checksum(filename):
    """Decide which Flash regions differ from a .fullrom file.

    The device checksums the active config bank, the active firmware bank
    and the ESM region. Each is compared with the byte sum of the matching
    block in ``filename``. For config and firmware, bank 0 is checked when its
    flag is set in the active tags, otherwise bank 1.

    Returns:
        A tuple ``(update_config, update_fw, update_esm)``. Each entry is True
        when that region's Flash checksum differs from the file, meaning it
        needs to be updated from the file. It is False when the region can be
        skipped.
    """
    config, firmware, esm = load_fullrom_file(filename)
    active_tags = check_vxr_tags()

    if active_tags & Firmware_Blocks.Config0:
        config_start = Firmware_Blocks.Config0_Start
    else:
        config_start = Firmware_Blocks.Config1_Start

    if active_tags & Firmware_Blocks.Firmware0:
        firmware_start = Firmware_Blocks.Firmware0_Start
    else:
        firmware_start = Firmware_Blocks.Firmware1_Start

    # Queried in order: config, firmware, ESM.
    regions = (
        (config_start, Firmware_Blocks.Config_Length, config),
        (firmware_start, Firmware_Blocks.Firmware_Length, firmware),
        (Firmware_Blocks.ESM_Start, Firmware_Blocks.ESM_Length, esm),
    )
    return tuple(
        vxr_get_checksum(start, length) != sum(block)
        for start, length, block in regions
    )

def vxr_delete_bank(bank_to_delete):
    """Erase the Flash blocks backing one config or firmware bank.

    ``bank_to_delete`` must be exactly one of the Firmware_Blocks flags. It is
    mapped to device Flash block numbers as follows:
        Config0 -> block 0        Firmware0 -> blocks 0 and 1
        Config1 -> block 2        Firmware1 -> blocks 2 and 3
    Deleting a firmware bank therefore also erases the config of that bank,
    and deleting a config bank erases the block it shares with the start of
    that bank's firmware.

    Raises:
        RuntimeError: if ``bank_to_delete`` is not a single known bank flag,
            or the device replies with CMD_ERROR.
    """
    blocks_for_bank = {
        Firmware_Blocks.Config0: [0],
        Firmware_Blocks.Config1: [2],
        Firmware_Blocks.Firmware0: [0, 1],
        Firmware_Blocks.Firmware1: [2, 3],
    }
    flash_blocks = blocks_for_bank.get(bank_to_delete)
    if flash_blocks is None:
        # Previously this erased nothing and returned silently. Subsequent
        # programming would then write over unerased Flash.
        raise RuntimeError("Invalid bank to delete: {}".format(bank_to_delete))

    bigs = hid_open_safe()
    try:
        for flash_block_num in flash_blocks:
            # Changed in firmware version 0.2.11:
            # the delete command requires the passkey "VXRDELETE" after it.
            # NOTE(review): meaning of the 0 byte after the block number is not
            # documented here (the firmware comment mentions 4kB delete sizes) - confirm.
            bigs.send_feature_report(bytes([0, CMD_DEL_VXRFLASH_BANK, flash_block_num, 0]) + b'VXRDELETE')
            recpacket = wait_for_hid_packet(bigs, [CMD_SUCCESS, CMD_ERROR])
            if recpacket[0] == CMD_ERROR:
                raise RuntimeError("Could not delete Flash block {}, error code 0x{:02X}.".format(flash_block_num, recpacket[1]))
    finally:
        bigs.close()

def vxr_delete_esm():
    """Erase Flash blocks 4-7, which hold the HDCP ESM region.

    Raises:
        RuntimeError: if the device replies with CMD_ERROR for any block.
    """
    bigs = hid_open_safe()
    try:
        for flash_block_num in range(4, 8):
            # Changed in firmware version 0.2.11:
            # the delete command requires the passkey "VXRDELETE" after it.
            bigs.send_feature_report(bytes([0, CMD_DEL_VXRFLASH_BANK, flash_block_num, 0]) + b'VXRDELETE')
            recpacket = wait_for_hid_packet(bigs, [CMD_SUCCESS, CMD_ERROR])
            if recpacket[0] == CMD_ERROR:
                raise RuntimeError("Could not delete Flash block {}, error code 0x{:02X}.".format(flash_block_num, recpacket[1]))
    finally:
        # Release the HID handle even if a delete fails.
        bigs.close()
    
def vxr_program_region(program_bytes, start_addr, report_after_every=8*1024):
    """Program ``program_bytes`` into Flash starting at ``start_addr``.

    The data is sent as CMD_VXRFLASH_PROGRAM reports of up to 32 bytes (the
    last may be shorter). Each report must be answered with CMD_SUCCESS before
    the next one is sent. Callers in this file erase the target range first.
    A progress line is printed whenever the number of bytes sent is a
    multiple of ``report_after_every``.

    Raises:
        RuntimeError: if the device answers any chunk with CMD_ERROR.
    """
    chunk_size = 32  # maximum payload of one program command
    total = len(program_bytes)

    bigs = hid_open_safe()
    try:
        for offset in range(0, total, chunk_size):
            program_chunk = program_bytes[offset:offset + chunk_size]
            chunk_addr = start_addr + offset
            # Payload: <length:u8><address:u32 little-endian><data>
            program_cmd = (bytes([0, CMD_VXRFLASH_PROGRAM, len(program_chunk)])
                           + struct.pack('<I', chunk_addr) + program_chunk)
            bigs.send_feature_report(program_cmd)

            recpacket = wait_for_hid_packet(bigs, [CMD_SUCCESS, CMD_ERROR])
            if recpacket[0] == CMD_ERROR:
                raise RuntimeError("Could not program Flash at 0x{:05X}. Error code 0x{:02X}".format(chunk_addr, recpacket[1]))

            bytes_sent = offset + len(program_chunk)
            if bytes_sent % report_after_every == 0:
                print('{} bytes programmed. {}% complete'.format(bytes_sent, int(100*(bytes_sent/total))))
    finally:
        # Release the HID handle even if programming fails part-way.
        bigs.close()
def vxr_flash_update(filename):
    """Update the VXR7200 Flash from a .fullrom file, rewriting only what differs.

    Steps:
      1) Check which banks are active (config and firmware).
      2) Check, by checksum, whether the active config/firmware and the ESM
         already match the file.
      3) Delete and program only the regions that need updating. New config
         or firmware goes into the inactive bank. The tag of the copy it
         replaces is then invalidated.

    Erase side effects this function must compensate for:
      * deleting a config bank also erases the first 32 kB (Config_Length)
        of the firmware in that same bank;
      * deleting a firmware bank also erases the config in that same bank.
    So when the active config and firmware sit in different banks, whichever
    of them lives in the bank being erased must be rewritten in place. Its
    tag must then not be invalidated, because that slot holds live data again.
    """
    print('Checking requirements...')
    (config0_block, firmware0_block, esm_block) = load_fullrom_file(filename)
    (update_config, update_fw, update_esm) = vxr_compare_file_to_flash_checksum(filename)
    active_banks = check_vxr_tags()
    active_config_bank = active_banks & (Firmware_Blocks.Config0 | Firmware_Blocks.Config1)
    active_firmware_bank = active_banks & (Firmware_Blocks.Firmware0 | Firmware_Blocks.Firmware1)

    if active_config_bank == 0 or active_firmware_bank == 0:
        # At least one region is unprogrammed. Consider this a blank chip for
        # config and firmware.
        print('Config and/or Firmware are blank. Programming bank 0.')
        print('Deleting...')
        vxr_delete_bank(Firmware_Blocks.Firmware0)
        print('Programming config...')
        vxr_program_region(config0_block, Firmware_Blocks.Config0_Start)
        print('Programming firmware...')
        vxr_program_region(firmware0_block, Firmware_Blocks.Firmware0_Start)
        # Invalidate each bank-1 region with its own call so both tags are cleared.
        invalidate_vxr_tag(Firmware_Blocks.Config1)
        invalidate_vxr_tag(Firmware_Blocks.Firmware1)
    else:
        same_bank = config_and_fw_in_same_bank(active_banks)

        # The bank opposite the active firmware receives new firmware
        # (and, for a combined update, new config).
        if active_firmware_bank == Firmware_Blocks.Firmware1:
            fw_bank_to_delete = Firmware_Blocks.Firmware0
            fw_start_address = Firmware_Blocks.Firmware0_Start
            spare_config_start = Firmware_Blocks.Config0_Start
        else:
            fw_bank_to_delete = Firmware_Blocks.Firmware1
            fw_start_address = Firmware_Blocks.Firmware1_Start
            spare_config_start = Firmware_Blocks.Config1_Start

        if update_config and not update_fw:
            # Only updating config region: load into the config bank opposite
            # the active one.
            config_bank_to_delete = Firmware_Blocks.Config0 if active_config_bank == Firmware_Blocks.Config1 else Firmware_Blocks.Config1
            config_start_address = Firmware_Blocks.Config0_Start if active_config_bank == Firmware_Blocks.Config1 else Firmware_Blocks.Config1_Start
            print('Deleting config region...')
            vxr_delete_bank(config_bank_to_delete)
            print('Programming config...')
            vxr_program_region(config0_block, config_start_address)
            if not same_bank:
                # The active firmware sits in the bank whose config was just
                # deleted, so its first 32k was erased too. It matches the file
                # (update_fw is False), so rewrite that part from the file.
                vxr_program_region(firmware0_block[:Firmware_Blocks.Config_Length], config_start_address + Firmware_Blocks.Config_Length)
            # now invalidate the old config bank
            invalidate_vxr_tag(active_config_bank)

        elif update_config and update_fw:
            # Updating both: delete the whole bank (config + fw) opposite the
            # active firmware and place both there.
            print('Deleting config and firmware regions...')
            vxr_delete_bank(fw_bank_to_delete)
            print('Programming config...')
            vxr_program_region(config0_block, spare_config_start)
            print('Programming firmware...')
            vxr_program_region(firmware0_block, fw_start_address)
            if same_bank:
                # The old config lives in the bank that was kept; retire it.
                invalidate_vxr_tag(active_config_bank)
            # Otherwise the old config was in the bank just erased, and its
            # slot now holds the new config. Invalidating it would disable
            # the config that was just written.
            invalidate_vxr_tag(active_firmware_bank)

        elif update_fw:
            # Only updating firmware.
            print('Deleting firmware region...')
            vxr_delete_bank(fw_bank_to_delete)
            print('Programming firmware...')
            vxr_program_region(firmware0_block, fw_start_address)
            if not same_bank:
                # The active config lived in the bank that was just erased.
                # It matches the file (update_config is False), so rewrite it
                # in that same slot, next to the new firmware, and keep its tag
                # valid. The other config slot was not erased and must not be
                # programmed over.
                print('Restoring config...')
                vxr_program_region(config0_block, spare_config_start)
            # Retire the old firmware only once the new bank is complete.
            invalidate_vxr_tag(active_firmware_bank)

    if update_esm:
        print('Deleting HDCP ESM region...')
        vxr_delete_esm()
        print('Programming ESM...')
        vxr_program_region(esm_block, Firmware_Blocks.ESM_Start)

    print('Complete! Reset VXR7200 to continue.')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Load Synaptics firmware.')
    parser.add_argument('-f', '--filename', type=str,
                        help='path to the fullrom file')
    parser.add_argument('-e', '--erase', help='complete Flash erase all firmware on the VXR7200', action='store_true')
    args = parser.parse_args()
    if args.erase:
        # -e takes precedence over -f.
        # flush=True: these messages have no trailing newline, so without a
        # flush they may not appear until after the (slow) erase finishes.
        print("Deleting Firmware bank 0 region 128kB... (1/3)", end='', flush=True)
        vxr_delete_bank(Firmware_Blocks.Firmware0)
        print('...done.')
        print("Deleting Firmware bank 1 region 128kB... (2/3)", end='', flush=True)
        vxr_delete_bank(Firmware_Blocks.Firmware1)
        print('...done.')
        # Step 3 erases the HDCP ESM region (blocks 4-7), not a firmware bank.
        print("Deleting HDCP ESM region 256kB... (3/3)", end='', flush=True)
        vxr_delete_esm()
        print('...done. VXR7200 Flash completely erased. Please unplug and replug HMD to continue.')
    elif args.filename:
        vxr_flash_update(args.filename)
    else:
        parser.print_help()