import hid
from threading import Thread
from queue import Queue
from typing import Optional
import time
import argparse
import os
import sys
import struct
import crc
#from crc import Crc32, CrcCalculator

from ihex_loader import *

# Default overwriter image, resolved relative to the current working
# directory. Built with os.path.join so the separator matches the host OS;
# a hard-coded backslash-separated literal only resolves on Windows.
BL_OVERWRITE_HEX = os.path.join("..", "build", "BL_Overwriter_v013_230519_FW.hex")
# Flash layout: the bootloader image starts at the beginning of flash and
# the application region begins 0x4000 bytes after it.
FLASH_START_ADDRESS = 0x00400000
BOOTLOADER_START_ADDRESS = FLASH_START_ADDRESS
APP_START_ADDRESS = BOOTLOADER_START_ADDRESS + 0x4000
# Offset inside an image of the little-endian u32 that is read as its length.
LENGTH_OFFSET = 0x400
# Bytes of image data carried by each flash-write HID message.
CHUNKSIZE = 32
# A page-program command is issued once per PAGE_SIZE bytes written.
PAGE_SIZE = 0x200

# USB vendor ID plus the product IDs the device is opened under in each mode.
BIGSCREEN_VID = 0x35BD
BEYOND_PID = 0x0101      # "Beyond" (normal firmware)
BOOTLOADER_PID = 0x4004  # "Bootloader"
OVERWRITER_PID = 0x5004  # "Overwriter application"

# Seconds (one open attempt per second) to wait for re-enumeration after
# the device is told to switch modes.
HID_CONNECT_TIMEOUT = 20

def wait_for_hid_device(vid: int, pid: int, descriptive_name: str, timeout_sec: int) -> Optional[hid.Device]:
    """Open the HID device ``vid:pid``, retrying once per second.

    Args:
        vid: USB vendor ID to open.
        pid: USB product ID to open.
        descriptive_name: Human-readable name used in progress messages.
        timeout_sec: Number of one-second open attempts. At least one
            attempt is always made, even when this is 0 or negative.

    Returns:
        The opened ``hid.Device``, or None if every attempt failed.
    """
    # Use hid.HIDException when the binding defines it, otherwise OSError.
    hid_exception = getattr(hid, "HIDException", OSError)
    verbose = timeout_sec > 1
    print('Connecting to {}'.format(descriptive_name))
    if verbose:
        print('Waiting up to {} seconds'.format(timeout_sec))
    attempts = max(1, timeout_sec)
    for attempt in range(attempts):
        try:
            device = hid.Device(vid=vid, pid=pid)
        except hid_exception:
            pass  # not present (yet); retry below
        else:
            if verbose:
                print('')
            return device
        # Only wait when another attempt follows; sleeping after the final
        # failure merely delayed the caller by a second.
        if attempt < attempts - 1:
            time.sleep(1.0)
            print('...', end='')
            sys.stdout.flush()
    if verbose:
        print('')
    return None

def generate_lut_entry(lut_index):
    """Return entry ``lut_index`` of the reflected CRC-32 table (poly 0xEDB88320).

    The entry is computed bit by bit on every call; no table is stored.
    """
    value = lut_index
    for _ in range(8):
        low_bit_set = value & 0x01
        value >>= 1
        if low_bit_set:
            value ^= 0xEDB88320
    return value

def calc_end_crc(bfr, length):
    """Return the CRC-32 (poly 0xEDB88320, init 0xFFFFFFFF) of ``bfr[:length]``
    WITHOUT the customary final XOR with 0xFFFFFFFF.

    This equals ``zlib.crc32(data) ^ 0xFFFFFFFF``, so the C implementation in
    zlib is used instead of a per-byte Python loop.

    Args:
        bfr: Sequence of byte values (bytes, bytearray or list of ints 0-255).
        length: Number of leading bytes to include. A value <= 0 covers no
            data and yields 0xFFFFFFFF.

    Returns:
        The 32-bit CRC as an unsigned int.

    Raises:
        ValueError: if ``length`` exceeds ``len(bfr)``.
    """
    import zlib

    if length <= 0:
        # Guard explicitly: a negative slice bound would count from the end.
        return 0xFFFFFFFF
    if length > len(bfr):
        raise ValueError('CRC length {} exceeds buffer size {}'.format(length, len(bfr)))
    return zlib.crc32(bytes(bfr[:length])) ^ 0xFFFFFFFF

def crc8_calc(crc8_init, databytes):
    """Fold ``databytes`` into a CRC-8 (poly 0x07, MSB-first, no final XOR).

    Args:
        crc8_init: Starting CRC value; 0 is used when framing HID messages.
        databytes: Iterable of byte values.

    Returns:
        The updated CRC. With empty input, ``crc8_init`` is returned unchanged.
    """
    poly = 0x07
    crc = crc8_init
    for byte in databytes:
        crc ^= byte
        for _ in range(8):
            shifted = (crc << 1) & 0xFF
            crc = shifted ^ poly if crc & 0x80 else shifted
    return crc

def create_hid_message(hidtype:int, datalen:int, addr:int, data:bytes) -> bytes:
    """Frame one command: <u8 type><u8 datalen><u32 LE addr><data><u8 CRC-8>.

    The trailing CRC-8 (``crc8_calc`` seeded with 0) covers every byte before
    it. ``data`` may be bytes or a list of ints; ``datalen`` is written as
    given and is not checked against ``len(data)``.
    """
    payload = struct.pack('<BBI', hidtype, datalen, addr) + bytes(data)
    return payload + bytes([crc8_calc(0, payload)])

def erase_app(bootloader_hid: hid.Device) -> bool:
    """Send the erase command (bytes 0x00 0x22) and wait for its status report.

    Returns True only if a reply of at least two bytes arrives within 10 s
    and its second byte is 0.
    """
    erase_command = bytes([0x00, 0x22])
    bootloader_hid.write(erase_command)
    # Erasing flash is slow, so allow up to 10 seconds for the reply.
    reply = bootloader_hid.read(64, timeout=10000)
    return len(reply) > 1 and reply[1] == 0

def checkstatus(status: bytes) -> bool:
    """Return True iff ``status`` has a second byte and that byte is 0.

    An empty or one-byte reply counts as failure.
    """
    return len(status) >= 2 and status[1] == 0

def checkstatus_emptyok(status: bytes) -> bool:
    """Like ``checkstatus`` but treats a missing/short reply as success.

    Returns False only when a second byte exists and is non-zero.
    """
    return len(status) <= 1 or status[1] == 0

def _send_command(bootloader_hid: hid.Device, message: bytes, timeout_ms: int, error_prefix: str) -> bool:
    """Write one framed message (prefixed with a 0 byte) and check the reply.

    A missing reply is accepted (``checkstatus_emptyok``). On a non-zero
    status byte, prints ``"<error_prefix>: <code>"`` and returns False.
    """
    bootloader_hid.write(bytes([0]) + message)
    status = bootloader_hid.read(64, timeout=timeout_ms)
    if not checkstatus_emptyok(status):
        print('{}: {}'.format(error_prefix, status[1]))
        return False
    return True


def load_app(bootloader_hid: hid.Device, new_app: bytes, start_address: int) -> bool:
    """Write ``new_app`` to flash starting at ``start_address``.

    The image is streamed in CHUNKSIZE-byte write messages (type 0x44). After
    every PAGE_SIZE bytes, a page-program message (type 0x50) for the current
    page is sent, and a final one covers any trailing partial page.

    Args:
        bootloader_hid: Open HID handle of the bootloader/overwriter.
        new_app: Image data as bytes or a list of ints.
        start_address: Flash address of ``new_app[0]``. Page boundaries are
            detected with ``addr % PAGE_SIZE``, so this is presumably
            page-aligned.

    Returns:
        True on success, False on the first error status (already printed).
    """
    addr = start_address
    cur_page_addr = start_address
    num_chunks = len(new_app) // CHUNKSIZE

    for curchunk in range(num_chunks):
        bin_chunk = new_app[CHUNKSIZE * curchunk:CHUNKSIZE * (curchunk + 1)]
        message = create_hid_message(0x44, CHUNKSIZE, addr, bin_chunk)
        # Chunk writes poll without blocking (timeout 0): an empty reply
        # counts as success, and only an error reply aborts.
        if not _send_command(bootloader_hid, message, 0, '1 Write Flash returned with error code'):
            return False
        addr += CHUNKSIZE

        # Program a page once all of its data has been sent.
        if addr % PAGE_SIZE == 0:
            message = create_hid_message(0x50, 0, cur_page_addr, b'')
            if not _send_command(bootloader_hid, message, 1000, '2 Write Flash returned with error code'):
                return False
            cur_page_addr += PAGE_SIZE
            print("{}% done".format(int(100 * curchunk / num_chunks)))

    tail = new_app[CHUNKSIZE * num_chunks:]
    if len(tail) > 0:
        # One more partial chunk. Convert to bytes so padding works whether
        # new_app is bytes or a list (bytes + list raised TypeError before).
        bin_len = len(tail)
        tail = bytes(tail)
        if bin_len % 8 != 0:
            tail += b'\xff' * (8 - bin_len % 8)
        # NOTE(review): the header advertises the unpadded length while the
        # payload is padded to a multiple of 8 - confirm firmware expects this.
        message = create_hid_message(0x44, bin_len, addr, tail)
        if not _send_command(bootloader_hid, message, 0, '3 Write Flash returned with code'):
            return False

        # Program the final (partial) page.
        message = create_hid_message(0x50, 0, cur_page_addr, b'')
        if not _send_command(bootloader_hid, message, 1000, '4 Program Flash returned with error code'):
            return False
    elif addr > cur_page_addr:
        # Whole chunks ended mid-page: program what the loop left unprogrammed.
        message = create_hid_message(0x50, 0, cur_page_addr, b'')
        if not _send_command(bootloader_hid, message, 1000, '5 Program Flash returned with error code'):
            return False

    return True


if __name__ == "__main__":
    # Entry point. Validates both hex images, then drives the device through
    # Beyond -> bootloader (loads the overwriter) -> overwriter (replaces the
    # bootloader) -> new bootloader (erases the overwriter). Every error path
    # exits with status 1 so wrapper scripts can detect failure.
    overwrite_hex = os.path.abspath(BL_OVERWRITE_HEX)

    parser = argparse.ArgumentParser(description="Updates the Beyond bootloader. "
                                    "NOTE: This is a dangerous operation! If interrupted, device will "
                                    "be non-responsive unless Flashed by JTAG.")
    parser.add_argument('-b', help="Bootloader update hex file: application code that acts as a bootloader. "
                        "If not specified, will use the default file at \"{}\"".format(overwrite_hex))
    parser.add_argument('file_name', help='New bootloader hex file')

    args = parser.parse_args()

    if args.b:
        overwrite_hex = os.path.abspath(args.b)
    if not os.path.exists(overwrite_hex):
        print('ERROR: bootloader overwrite firmware does not exist (\"{}\")'.format(overwrite_hex))
        sys.exit(1)

    new_bl_hex = os.path.abspath(args.file_name)
    if not os.path.exists(new_bl_hex):
        print('ERROR: new bootloader firmware does not exist (\"{}\")'.format(new_bl_hex))
        # Previously fell through and called load_ihex() on a missing file.
        sys.exit(1)

    # Check both hex files for the correct start address (CRCs checked below).
    (overwrite_bytes, overwrite_start) = load_ihex(overwrite_hex)
    (new_bl_bytes, new_bl_start) = load_ihex(new_bl_hex)

    if overwrite_start != APP_START_ADDRESS:
        print("ERROR: Overwrite hex file does not start at correct address (should be: 0x{:08X}, got: 0x{:08X})".format(APP_START_ADDRESS,overwrite_start))
        sys.exit(1)
    if new_bl_start != BOOTLOADER_START_ADDRESS:
        print("ERROR: New bootloader hex file does not start at correct address (should be: 0x{:08X}, got: 0x{:08X})".format(BOOTLOADER_START_ADDRESS,new_bl_start))
        sys.exit(1)

    # The little-endian u32 at LENGTH_OFFSET in each image is used as its length.
    (overwrite_len,) = struct.unpack('<I', bytes(overwrite_bytes[LENGTH_OFFSET:LENGTH_OFFSET+4]))
    (new_bl_len,) = struct.unpack('<I', bytes(new_bl_bytes[LENGTH_OFFSET:LENGTH_OFFSET+4]))

    # Overwriter: BZIP2 CRC-32 of [0, len) compared with the u32 stored at len.
    if hasattr(crc, "CrcCalculator"):
        # older version of crc package
        app_crc_calc = crc.CrcCalculator(crc.Crc32.BZIP2, True)
        overwrite_crc_local = app_crc_calc.calculate_checksum(bytes(overwrite_bytes[:overwrite_len]))
    else:
        # newer version
        app_crc_calc = crc.Calculator(crc.Crc32.BZIP2, True)
        overwrite_crc_local = app_crc_calc.checksum(bytes(overwrite_bytes[:overwrite_len]))
    (overwrite_crc_saved,) = struct.unpack('<I', bytes(overwrite_bytes[overwrite_len:overwrite_len+4]))

    # New bootloader: calc_end_crc of [0, len-4) compared with the u32 at len-4.
    new_bl_crc_local = calc_end_crc(new_bl_bytes, new_bl_len-4)
    (new_bl_crc_saved,) = struct.unpack('<I', bytes(new_bl_bytes[(new_bl_len-4):new_bl_len]))

    if overwrite_crc_local != overwrite_crc_saved:
        print("ERROR: Overwriter file checksum doesn't match (read 0x{:08X}, calculated 0x{:08x})".format(overwrite_crc_saved, overwrite_crc_local))
        sys.exit(1)
    if new_bl_crc_local != new_bl_crc_saved:
        print("ERROR: New bootloader file checksum doesn't match (read 0x{:08X}, calculated 0x{:08x})".format(new_bl_crc_saved, new_bl_crc_local))
        sys.exit(1)

    do_bootloader_load_overwriter = False  # true if we need to load overwriter program

    # Find the device in whichever mode it is currently in.
    bynd = wait_for_hid_device(BIGSCREEN_VID, BEYOND_PID, "Beyond", 1)

    if bynd is None:
        # Try the bootloader next
        bootl = wait_for_hid_device(BIGSCREEN_VID, BOOTLOADER_PID, "Bootloader", 1)
        if bootl is None:
            # Finally try the overwriter (e.g. resuming an interrupted run)
            overwriter = wait_for_hid_device(BIGSCREEN_VID, OVERWRITER_PID, "Overwriter application", 1)
            if overwriter is None:
                # Failed to connect for all devices, quit
                print('ERROR: Could not find any Beyond device. Quitting.')
                sys.exit(1)
        else:
            do_bootloader_load_overwriter = True
    else:
        # Connected to Beyond, try going to bootloader now
        print('Entering bootloader mode.')
        bynd.send_feature_report(bytes([0, ord('B')]))
        # Try opening bootloader
        bootl = wait_for_hid_device(BIGSCREEN_VID, BOOTLOADER_PID, "Bootloader", HID_CONNECT_TIMEOUT)
        if bootl is None:
            # Could not find bootloader after swap
            print('ERROR: Bootloader not found. Quitting.')
            sys.exit(1)
        do_bootloader_load_overwriter = True

    if do_bootloader_load_overwriter:
        # Delete application region
        print('Deleting application region...')
        if not erase_app(bootl):
            print('Erase failed. Exiting.')
            sys.exit(1)

        # Write the bootloader overwriter
        print('Loading overwriting program')
        if not load_app(bootl, overwrite_bytes, overwrite_start):
            print('Load failed, exiting.')
            sys.exit(1)

        # Restart into overwriter
        bootl.write(bytes([0, ord('B')]))
        overwriter = wait_for_hid_device(BIGSCREEN_VID, OVERWRITER_PID, "Overwriter application", HID_CONNECT_TIMEOUT)
        if overwriter is None:
            print('Overwriter not found. Exiting. Use firmware loader to reload application.')
            sys.exit(1)

    # Delete the bootloader. Cross your fingers.
    print('Deleting bootloader.')
    erase_result = False
    attempt_number = 0
    while not erase_result:
        erase_result = erase_app(overwriter)
        attempt_number = attempt_number + 1
        if not erase_result:
            if attempt_number > 5:
                print('Erase failed. Exiting. JTAG may be necessary to recover.')
                sys.exit(1)
            else:
                print('Erase failed. Retrying...')

    # Flashing the new bootloader
    print('Loading new bootloader')
    load_result = False
    attempt_number = 0
    while not load_result:
        load_result = load_app(overwriter, new_bl_bytes, new_bl_start)
        attempt_number = attempt_number + 1
        if not load_result:
            if attempt_number > 5:
                print('Load failed. Exiting. JTAG may be necessary to recover.')
                sys.exit(1)
            else:
                print('Load failed. Retrying...')

    print('New bootloader Flashed successfully!')
    print('Entering bootloader again to remove overwriter program')
    overwriter.write(bytes([0, ord('B')]))
    bootl = wait_for_hid_device(BIGSCREEN_VID, BOOTLOADER_PID, "Bootloader", HID_CONNECT_TIMEOUT)
    if bootl is None:
        print('Bootloader not found. Exiting.')
        sys.exit(1)

    print('Erasing updater program')
    if not erase_app(bootl):
        print('Erase failed. Exiting.')
        sys.exit(1)
    print('Bootloader update completed!')