import argparse
import struct
import time

import serial

from uart_packet import *

# How long wait_for_uart_response() polls for a reply, in nanoseconds.
# NOTE(review): the value below is 1 second; the comment here used to say
# "0.5 seconds" -- confirm which timeout is actually intended.
TIMEOUT_NS = int(1e9 * 1) # 1 second

# Packet type bytes used on the UART link (request and/or reply types).
# Request for the active-tag bitmask; the reply carries the same type.
CMD_GET_ACTIVE_VXR_TAGS = ord('T')
# Request to delete one flash block; payload is the flash block number.
CMD_DEL_VXRFLASH_BANK = ord('D')
# Request for a checksum over (start address, length); reply has 4 data bytes.
CMD_GET_VXRFLASH_CHKSUM = ord('K')
# Request to program flash; payload is length, start address, then the data.
CMD_VXRFLASH_PROGRAM = ord('A')
# Reply type that vxr_program() requires in order to treat a write as done.
CMD_SUCCESS = ord('$')
# Reply type treated as a failure by the delete/program helpers.
CMD_ERROR = ord('E')

class Firmware_Blocks:
    """Bit flags, region lengths and flash addresses used by this script.

    The four flag values are tested against the active-tag bitmask with
    ``&``.  Each config/firmware region ends with a 16-byte tag
    (``*_Tag_Start`` .. ``*_Tag_End``), and ``*_End`` is the last tag byte.
    """

    # --- region flags (one bit per region) ---
    Config0 = 1
    Config1 = 2
    Firmware0 = 4
    Firmware1 = 8

    # --- region lengths in bytes ---
    Config_Length = 0x08000
    Firmware_Length = 0x18000
    ESM_Length = 0x40000

    # --- bank 0: config ---
    Config0_Start = 0x00000
    Config0_Tag_Start = 0x07FF0
    Config0_Tag_End = 0x07FFF
    Config0_End = Config0_Tag_End

    # --- bank 0: firmware ---
    Firmware0_Start = 0x08000
    Firmware0_Tag_Start = 0x1FFF0
    Firmware0_Tag_End = 0x1FFFF
    Firmware0_End = Firmware0_Tag_End

    # --- bank 1: config ---
    Config1_Start = 0x20000
    Config1_Tag_Start = 0x27FF0
    Config1_Tag_End = 0x27FFF
    Config1_End = Config1_Tag_End

    # --- bank 1: firmware ---
    Firmware1_Start = 0x28000
    Firmware1_Tag_Start = 0x3FFF0
    Firmware1_Tag_End = 0x3FFFF
    Firmware1_End = Firmware1_Tag_End

    # --- HDCP ESM region (the last block of the fullrom file) ---
    ESM_Start = 0x40000
    ESM_End = 0x7FFFF

def config_and_fw_in_same_bank(active_tags):
    """Report whether the active config and firmware share a bank.

    Args:
        active_tags: Bitmask built from the ``Firmware_Blocks`` region flags.

    Returns:
        True when both bank-0 flags (Config0 and Firmware0) are set, or both
        bank-1 flags (Config1 and Firmware1) are set; False otherwise.
    """
    both_in_bank0 = (active_tags & Firmware_Blocks.Config0) and (active_tags & Firmware_Blocks.Firmware0)
    both_in_bank1 = (active_tags & Firmware_Blocks.Config1) and (active_tags & Firmware_Blocks.Firmware1)
    return bool(both_in_bank0 or both_in_bank1)

def load_fullrom_file(filename):
    """Read a fullrom image and split it into its regions.

    The file is read as config0, firmware0, config1, firmware1 and then
    everything that remains as the HDCP ESM block.

    Args:
        filename: Path of the fullrom file, opened in binary mode.

    Returns:
        Tuple ``(config0_block, firmware0_block, esm_block)`` of bytes.

    Raises:
        RuntimeError: A region does not have its expected length, or the
            bank-1 copies differ from the bank-0 blocks.
    """
    # config1 and fw1 are copies of config0 and fw0, but they still have to
    # be read to reach the last block (HDCP ESM).
    region_lengths = (
        Firmware_Blocks.Config_Length,
        Firmware_Blocks.Firmware_Length,
        Firmware_Blocks.Config_Length,
        Firmware_Blocks.Firmware_Length,
    )
    with open(filename, "rb") as rom:
        regions = [rom.read(region_length) for region_length in region_lengths]
        esm = rom.read()
    config0, firmware0, config1, firmware1 = regions

    # Confirm a few things:
    lengths_ok = all(
        len(region) == region_length
        for region, region_length in zip(regions, region_lengths)
    ) and len(esm) == Firmware_Blocks.ESM_Length
    if not lengths_ok:
        raise RuntimeError("Firmware file has invalid format (length incorrect)")
    if config0 != config1 or firmware0 != firmware1:
        raise RuntimeError("Firmware file has invalid format (blocks 0 and 1 do not match)")

    return (config0, firmware0, esm)

def wait_for_uart_response(comport, wordy = False):
    """Poll the serial port until one complete packet arrives or time runs out.

    Received bytes are accumulated and passed to
    ``DataPacket.extract_packet`` on every pass of the polling loop.

    Args:
        comport: Open serial port object; its ``in_waiting`` attribute and
            ``read`` method are used.
        wordy: When true, print the newly received bytes and the whole
            accumulated buffer as hex each time data arrives.

    Returns:
        The packet object returned by ``DataPacket.extract_packet`` on
        success, or ``None`` if ``TIMEOUT_NS`` nanoseconds elapse first.

    Raises:
        RuntimeError: ``extract_packet`` returned a result code other than
            ``PKRES_SUCCESS`` or ``PKRES_NO_PACKET``.
    """
    if wordy: print("Waiting for new reply...")
    # Use the monotonic clock: the wall clock (time.time_ns) can be stepped by
    # the OS while we wait, which would cut the timeout short or extend it.
    deadline = time.monotonic_ns() + TIMEOUT_NS
    uart_rx_buf = b''
    while(time.monotonic_ns() < deadline):
        # Grab all available bytes at the comport
        if(comport.in_waiting != 0):
            newbytes = comport.read(comport.in_waiting)
            if wordy: print("Newly arrived: ",end="")
            if wordy: print(''.join('0x{:02X}, '.format(bb) for bb in newbytes))
            uart_rx_buf = uart_rx_buf + newbytes
            if wordy: print("Full buffer: ",end="")
            if wordy: print(''.join('0x{:02X}, '.format(bb) for bb in uart_rx_buf))
        # Check for a packet
        (errcode, startpos, retpkt) = DataPacket.extract_packet(uart_rx_buf)
        if(errcode == DataPacket.PKRES_SUCCESS):
            return retpkt
        elif(errcode == DataPacket.PKRES_NO_PACKET):
            # continue waiting
            # clear unnecessary bytes from the buffer
            if(startpos != -1):
                uart_rx_buf = uart_rx_buf[startpos:]
        else:
            # some other error has occurred
            raise RuntimeError('Error in packet reply: 0x{:02X}'.format(errcode))

    return None

def check_vxr_tags(beyond_comport):
    """Ask the device which config/firmware tags are active.

    Sends a ``CMD_GET_ACTIVE_VXR_TAGS`` packet with no payload and waits for
    a reply of the same type.

    Args:
        beyond_comport: Open serial port used for the request and the reply.

    Returns:
        The first data byte of the reply; callers test it against the
        ``Firmware_Blocks`` region flags.

    Raises:
        RuntimeError: No reply arrived in time, the reply had a different
            packet type, or it carried no data.
    """
    beyond_comport.write(DataPacket(CMD_GET_ACTIVE_VXR_TAGS,[]).create_packet())

    # Wait until we get the reply packet or an error
    recpacket = wait_for_uart_response(beyond_comport)

    tags = None
    if recpacket is not None and recpacket.PktType == CMD_GET_ACTIVE_VXR_TAGS:
        # An empty payload must end in the RuntimeError below rather than an
        # IndexError from indexing PktData.
        if len(recpacket.PktData) >= 1:
            tags = recpacket.PktData[0]

    if tags is None:
        raise RuntimeError('Could not read VXR active tags')
    return tags

def vxr_get_checksum(beyond_comport, start_addr, length):
    """Request the device's checksum over a flash range.

    Args:
        beyond_comport: Open serial port used for the request and the reply.
        start_addr: First flash address of the range (packed as little-endian
            unsigned 32-bit).
        length: Number of bytes in the range (packed the same way).

    Returns:
        The checksum, decoded from the reply's 4 data bytes as a
        little-endian unsigned 32-bit integer.

    Raises:
        RuntimeError: No reply arrived, or the reply did not have the
            checksum packet type with exactly 4 data bytes.
    """
    request = DataPacket(CMD_GET_VXRFLASH_CHKSUM, struct.pack("<II", start_addr, length))
    beyond_comport.write(request.create_packet())

    # Wait until we get the reply packet or an error
    reply = wait_for_uart_response(beyond_comport)
    reply_is_valid = (
        reply is not None
        and reply.PktType == CMD_GET_VXRFLASH_CHKSUM
        and len(reply.PktData) == 4
    )
    if not reply_is_valid:
        raise RuntimeError("Error in received checksum format")

    (checksum,) = struct.unpack("<I", bytes(reply.PktData))
    return checksum

def vxr_program(beyond_comport, start_addr, prog_bytes):
    """Program 1 to 32 bytes of flash with one command and require success.

    Args:
        beyond_comport: Open serial port used for the request and the reply.
        start_addr: Flash address of the first byte to program.
        prog_bytes: The bytes to program; length must be 1..32.

    Raises:
        RuntimeError: The length is out of range, no reply arrived, or the
            reply type was not ``CMD_SUCCESS``.
    """
    byte_count = len(prog_bytes)
    if not (1 <= byte_count <= 32):
        raise RuntimeError("Invalid programming length. Should be 1-32. Got {}".format(byte_count))

    # Command payload: 1-byte length, 4-byte little-endian address, then data.
    header = struct.pack("<BI", byte_count, start_addr)
    request = DataPacket(CMD_VXRFLASH_PROGRAM, header + prog_bytes)
    beyond_comport.write(request.create_packet())

    reply = wait_for_uart_response(beyond_comport)
    if reply is None or reply.PktType != CMD_SUCCESS:
        raise RuntimeError("Could not program Flash. Addr: 0x{:05X}, length: {}".format(start_addr, byte_count))

def invalidate_vxr_tag(beyond_comport, region):
    """Invalidate the tag of every region selected in ``region``.

    Sets the first and last bytes of the tag (final 16 bytes of a region) to
    zero. This causes the VXR7200 bootloader to consider the region invalid.

    Args:
        beyond_comport: Open serial port to the device.
        region: Bitmask of ``Firmware_Blocks`` region flags. Several flags
            may be combined; every selected region is invalidated, in the
            order Config0, Config1, Firmware0, Firmware1.

    Raises:
        RuntimeError: Propagated from ``vxr_program`` if a write fails.
    """
    tag_bounds = (
        (Firmware_Blocks.Config0, Firmware_Blocks.Config0_Tag_Start, Firmware_Blocks.Config0_Tag_End),
        (Firmware_Blocks.Config1, Firmware_Blocks.Config1_Tag_Start, Firmware_Blocks.Config1_Tag_End),
        (Firmware_Blocks.Firmware0, Firmware_Blocks.Firmware0_Tag_Start, Firmware_Blocks.Firmware0_Tag_End),
        (Firmware_Blocks.Firmware1, Firmware_Blocks.Firmware1_Tag_Start, Firmware_Blocks.Firmware1_Tag_End),
    )
    # Each flag is tested independently. An if/elif chain would stop at the
    # first match and silently skip the rest of a combined mask such as
    # Config1 | Firmware1.
    for region_flag, tag_start, tag_end in tag_bounds:
        if region & region_flag:
            vxr_program(beyond_comport, tag_start, bytes([0]))
            vxr_program(beyond_comport, tag_end, bytes([0]))

def vxr_compare_file_to_flash_checksum(beyond_comport, filename):
    """Decide which regions of flash differ from the fullrom file.

    For config and firmware, the copy compared is bank 0 when its flag is set
    in the active tags, otherwise bank 1. The ESM region is always compared.
    A region "differs" when the device's checksum for it is not equal to the
    plain byte sum of the corresponding block from the file.

    Args:
        beyond_comport: Open serial port to the device.
        filename: Path of the fullrom file.

    Returns:
        Tuple ``(update_config, update_fw, update_esm)``. A true entry means
        that region needs to be updated in flash from the file; a false
        entry means it can be skipped.
    """
    # get the blocks from the file
    (config_image, firmware_image, esm_image) = load_fullrom_file(filename)

    # get which banks are active
    active_tags = check_vxr_tags(beyond_comport)

    def differs_from_flash(start_addr, length, image):
        # Query the device first, then sum the file block.
        flash_checksum = vxr_get_checksum(beyond_comport, start_addr, length)
        return flash_checksum != sum(image)

    if active_tags & Firmware_Blocks.Config0:
        config_start = Firmware_Blocks.Config0_Start
    else:
        config_start = Firmware_Blocks.Config1_Start
    update_config = differs_from_flash(config_start, Firmware_Blocks.Config_Length, config_image)

    if active_tags & Firmware_Blocks.Firmware0:
        firmware_start = Firmware_Blocks.Firmware0_Start
    else:
        firmware_start = Firmware_Blocks.Firmware1_Start
    update_fw = differs_from_flash(firmware_start, Firmware_Blocks.Firmware_Length, firmware_image)

    update_esm = differs_from_flash(Firmware_Blocks.ESM_Start, Firmware_Blocks.ESM_Length, esm_image)

    return (update_config, update_fw, update_esm)

def vxr_delete_bank(beyond_comport, bank_to_delete):
    """Delete the flash block(s) associated with one region flag.

    Args:
        beyond_comport: Open serial port to the device.
        bank_to_delete: Exactly one ``Firmware_Blocks`` region flag. Any
            other value deletes nothing.

    Raises:
        RuntimeError: A delete request got no reply or a ``CMD_ERROR`` reply.
    """
    # Flash block numbers to delete per region flag. The firmware entries
    # cover two blocks, the first of which is also the config entry's block.
    blocks_by_region = {
        Firmware_Blocks.Config0: [0],
        Firmware_Blocks.Config1: [2],
        Firmware_Blocks.Firmware0: [0, 1],
        Firmware_Blocks.Firmware1: [2, 3],
    }

    for block_number in blocks_by_region.get(bank_to_delete, []):
        request = DataPacket(CMD_DEL_VXRFLASH_BANK, bytes([block_number]))
        beyond_comport.write(request.create_packet())
        reply = wait_for_uart_response(beyond_comport)
        if reply is None or reply.PktType == CMD_ERROR:
            raise RuntimeError("Could not delete Flash block {}.".format(block_number))

def vxr_delete_esm(beyond_comport):
    """Delete flash blocks 4 through 7, in ascending order.

    Args:
        beyond_comport: Open serial port to the device.

    Raises:
        RuntimeError: A delete request got no reply or a ``CMD_ERROR`` reply.
    """
    for esm_block in (4, 5, 6, 7):
        request = DataPacket(CMD_DEL_VXRFLASH_BANK, bytes([esm_block]))
        beyond_comport.write(request.create_packet())
        reply = wait_for_uart_response(beyond_comport)
        if reply is None or reply.PktType == CMD_ERROR:
            raise RuntimeError("Could not delete Flash block {}.".format(esm_block))

def vxr_program_region(beyond_comport, program_bytes, start_addr, report_after_every=8*1024):
    """Program a run of bytes into flash, 32 bytes per command.

    Args:
        beyond_comport: Open serial port to the device.
        program_bytes: The data to program; the final chunk may be shorter
            than 32 bytes.
        start_addr: Flash address where the first byte goes.
        report_after_every: Print a progress line whenever the number of
            bytes programmed so far is a multiple of this value.

    Raises:
        RuntimeError: A chunk got no reply (timeout) or a ``CMD_ERROR`` reply.
    """
    total_bytes = len(program_bytes)

    # Walk the data in 32-byte steps; slicing trims the last chunk for us.
    for offset in range(0, total_bytes, 32):
        chunk = program_bytes[offset:offset + 32]
        chunk_addr = start_addr + offset

        # Command payload: 1-byte length, 4-byte little-endian address, data.
        payload = struct.pack('<BI', len(chunk), chunk_addr) + chunk
        beyond_comport.write(DataPacket(CMD_VXRFLASH_PROGRAM, payload).create_packet())

        reply = wait_for_uart_response(beyond_comport)
        if reply is None:
            raise RuntimeError("Could not program Flash at 0x{:05X}. Timed out.".format(chunk_addr))
        if reply.PktType == CMD_ERROR:
            raise RuntimeError("Could not program Flash at 0x{:05X}. Received error response.".format(chunk_addr))

        programmed = offset + len(chunk)
        if programmed % report_after_every == 0:
            print('{} bytes programmed. {}% complete'.format(programmed, int(100*(programmed/total_bytes))))

def vxr_flash_update_uart(beyond_comport, filename):
    """Bring the VXR7200 flash in line with a fullrom file over the UART link.

    Steps:
      1) Check which banks are active (config and firmware).
      2) Check if the active config/firmware and ESM already match what is in
         the file to be loaded.
      3) Delete and program the regions that need to be updated, then
         invalidate the tags of the regions that were replaced.

    New config/firmware is always written to the bank opposite the active
    one. Deleting a region removes more than that region (see
    ``vxr_delete_bank``: a config flag deletes one flash block, a firmware
    flag deletes two, and they share a block), so whenever a delete takes an
    active region with it, that region is written back from the file into
    the space that was just deleted.
    NOTE(review): this assumes flash must be deleted before it is programmed;
    confirm against the VXR7200 flash documentation.

    Args:
        beyond_comport: Open serial port to the device.
        filename: Path of the fullrom file.

    Raises:
        RuntimeError: Propagated from the file loader or any UART helper.
    """
    print('Checking requirements...')
    (config0_block, firmware0_block, esm_block) = load_fullrom_file(filename)
    (update_config, update_fw, update_esm) = vxr_compare_file_to_flash_checksum(beyond_comport, filename)
    active_banks = check_vxr_tags(beyond_comport)
    active_config_bank = active_banks & (Firmware_Blocks.Config0 | Firmware_Blocks.Config1)
    active_firmware_bank = active_banks & (Firmware_Blocks.Firmware0 | Firmware_Blocks.Firmware1)

    # Bank opposite the active firmware: the target of any firmware write.
    if active_firmware_bank == Firmware_Blocks.Firmware1:
        target_fw_bank = Firmware_Blocks.Firmware0
        target_fw_start = Firmware_Blocks.Firmware0_Start
        target_config_bank = Firmware_Blocks.Config0
        target_config_start = Firmware_Blocks.Config0_Start
    else:
        target_fw_bank = Firmware_Blocks.Firmware1
        target_fw_start = Firmware_Blocks.Firmware1_Start
        target_config_bank = Firmware_Blocks.Config1
        target_config_start = Firmware_Blocks.Config1_Start

    if(active_config_bank == 0 or active_firmware_bank == 0):
        # At least one region is unprogrammed. Consider this a blank chip for
        # config and firmware.
        print('Config and/or Firmware are blank. Programming bank 0.')
        print('Deleting...')
        vxr_delete_bank(beyond_comport, Firmware_Blocks.Firmware0)
        print('Programming config...')
        vxr_program_region(beyond_comport, config0_block, Firmware_Blocks.Config0_Start)
        print('Programming firmware...')
        vxr_program_region(beyond_comport, firmware0_block, Firmware_Blocks.Firmware0_Start)
        # Invalidate each bank-1 region with its own call so that both tags
        # are cleared.
        invalidate_vxr_tag(beyond_comport, Firmware_Blocks.Config1)
        invalidate_vxr_tag(beyond_comport, Firmware_Blocks.Firmware1)
    elif(update_config and update_fw):
        # Updating both config and firmware regions.
        # Delete the entire bank (config + fw) opposite of the currently
        # loaded firmware and place both config and firmware there.
        print('Deleting config and firmware regions...')
        vxr_delete_bank(beyond_comport, target_fw_bank)
        print('Programming config...')
        vxr_program_region(beyond_comport, config0_block, target_config_start)
        print('Programming firmware...')
        vxr_program_region(beyond_comport, firmware0_block, target_fw_start)
        # Only invalidate an old config that lives in the OTHER bank. If the
        # active config was in the target bank, it has just been replaced by
        # the new config and zeroing its tag would invalidate the new data.
        stale_config_bank = active_config_bank & ~target_config_bank
        if(stale_config_bank):
            invalidate_vxr_tag(beyond_comport, stale_config_bank)
        invalidate_vxr_tag(beyond_comport, active_firmware_bank)
    elif(update_config):
        # Only updating config region.
        # Load into the opposite bank of what's active right now.
        if(active_config_bank == Firmware_Blocks.Config1):
            config_bank_to_delete = Firmware_Blocks.Config0
            config_start_address = Firmware_Blocks.Config0_Start
        else:
            config_bank_to_delete = Firmware_Blocks.Config1
            config_start_address = Firmware_Blocks.Config1_Start
        print('Deleting config region...')
        vxr_delete_bank(beyond_comport, config_bank_to_delete)
        print('Programming config...')
        vxr_program_region(beyond_comport, config0_block, config_start_address)
        # Check if we just deleted some of the firmware region when updating.
        # If the config and FW are currently running in opposite banks, then
        # deleting the to-be-loaded config bank also deleted the first 32k of
        # the active firmware, so that part has to be reloaded too.
        if(not config_and_fw_in_same_bank(active_banks)):
            vxr_program_region(beyond_comport, firmware0_block[:Firmware_Blocks.Config_Length], config_start_address + Firmware_Blocks.Config_Length)
        # now invalidate the old config bank
        invalidate_vxr_tag(beyond_comport, active_config_bank)
    elif(update_fw):
        # Only updating firmware.
        print('Deleting firmware region...')
        vxr_delete_bank(beyond_comport, target_fw_bank)
        print('Programming firmware...')
        vxr_program_region(beyond_comport, firmware0_block, target_fw_start)
        invalidate_vxr_tag(beyond_comport, active_firmware_bank)

        if(active_config_bank & target_config_bank):
            # The active config shares the bank that was just deleted, so it
            # is gone. Write it back in place (the file matches what was
            # there, since config did not need an update) and leave its tag
            # valid. Writing it to the other bank instead would program over
            # flash that has not been deleted.
            vxr_program_region(beyond_comport, config0_block, target_config_start)

    if(update_esm):
        print('Deleting HDCP ESM region...')
        vxr_delete_esm(beyond_comport)
        print('Programming ESM...')
        vxr_program_region(beyond_comport, esm_block, Firmware_Blocks.ESM_Start)

    print('Complete! Reset VXR7200 to continue.')

if __name__ == "__main__":
    # Command-line entry point: flash the given fullrom file over a serial port.
    parser = argparse.ArgumentParser(description='Load Synaptics firmware.')
    parser.add_argument('filename', type=str,
                    help='path to the fullrom file')
    parser.add_argument('--port', type=str, default='COM3',
                    help='serial port connected to the Beyond (default: COM3)')

    args = parser.parse_args()

    # Open the com port to the Beyond. The with-block closes the port again
    # even when the update raises part-way through.
    with serial.Serial(args.port, 115200, timeout=1) as comport:
        vxr_flash_update_uart(comport, args.filename)

