import hid
import struct
import time
import os
import sys
from progressbar import progressbar

import tkinter as tk
from tkinter import ttk, messagebox, filedialog

from threading import Thread
from queue import Queue
import time

from ihex_loader import *

# Flash layout: two small 8kB subsectors at the start of Flash, then one large
# subsector filling out the rest of the first 128kB sector, then three more
# full-size sectors.
SECTOR_SIZE = 0x20000            # 128kB sector size, 4 sectors in total
SMALL_SUBSECTOR_SIZE = 0x2000    # size of each of the two small subsectors
LARGE_SUBSECTOR_SIZE = SECTOR_SIZE - 2 * SMALL_SUBSECTOR_SIZE  # 0x1C000 = 128kB - 16kB
PAGE_SIZE = 0x200                # a program-page command is issued per PAGE_SIZE bytes

# Application image placement and transfer parameters.
APP_ADDRESS = 0x00404000         # address the application image must start at
ALLOWED_APP_SIZE = LARGE_SUBSECTOR_SIZE + 3 * SECTOR_SIZE  # large subsector + 3 sectors
CHUNKSIZE = 32                   # data bytes carried by each write packet
# Smallest erase unit used: 16 pages = 8kB. The small subsectors can be erased
# in smaller increments, but the large subsector and full-size sectors cannot.
MIN_ERASABLE_SIZE = 16 * PAGE_SIZE
def crc8_calc(crc8_init, databytes):
    """Return the CRC-8 of *databytes* using generator polynomial 0x07.

    ``crc8_init`` seeds the CRC register. Bits are processed MSB-first. The
    input and output are not reflected and no final XOR is applied, so a
    previous result can be passed back in as ``crc8_init`` to continue a CRC
    over more data. ``databytes`` may be ``bytes`` or any iterable of ints.
    """
    crc = crc8_init
    for byte in databytes:
        crc ^= byte
        for _ in range(8):
            carry = crc & 0x80
            crc = (crc << 1) & 0xFF
            if carry:
                crc ^= 0x07
    return crc


def crc32_calc(crc32_init, databytes):
    """Return the bitwise-inverted CRC-32 of *databytes* (polynomial 0x04C11DB7).

    The register is seeded with ``crc32_init`` exactly as given (it is not
    inverted on entry). Bits are processed MSB-first without reflection, and
    the final register value is inverted. With ``crc32_init=0xFFFFFFFF`` this
    is CRC-32/BZIP2.

    Because the output is inverted, continuing a CRC over more data requires
    passing ``~previous_result & 0xFFFFFFFF`` as ``crc32_init``.
    """
    poly = 0x04C11DB7
    mask = 0xFFFFFFFF
    reg = crc32_init
    for byte in databytes:
        reg ^= byte << 24
        for _ in range(8):
            top_bit = reg & 0x80000000
            reg = (reg << 1) & mask
            if top_bit:
                reg ^= poly
    return ~reg & mask
'''
def oldstyle():
    for dev in hid.enumerate(vid=0xB165):
        boot = hid.Device(path=dev['path'])
        boot.write(b'\x00\x22')
        # give it 10 seconds to clear the app flash
        status = boot.read(64, timeout_ms=10000)

        if(len(status) > 1):
            print("Flash Erase returned with code: {}".format(status[1]))
            if(status[1] == 0):
                print('Successfully erased')
            else:
                print('Error in erasing flash. Quitting.')
                sys.exit()
        inputfile = r"Bigleap_Freertos_fw.hex"
        with open(inputfile, 'rb') as bin:
            app_data = bin.read()
        
        applen = len(app_data)
        addr = APP_ADDRESS
        cur_page_addr = APP_ADDRESS
        num_chunks = int(len(app_data) / CHUNKSIZE)
        for curchunk in progressbar(range(num_chunks)):
            bin_chunk = app_data[(CHUNKSIZE*curchunk):(CHUNKSIZE*(curchunk+1))]
            bin_len = CHUNKSIZE

            bytes_out = struct.pack('<BBI', 0x44, bin_len, addr)
            bytes_out = bytes_out + bin_chunk
            crcbyte = crc8_calc(0, bytes_out)
            bytes_out = bytes_out + bytes([crcbyte])

            boot.write(bytes([0]) + bytes_out)
            status = boot.read(65,timeout_ms=1) # no timeout, if nothing is there just skip

            # print('Sent: {}'.format(bytes_out.hex()))
            if(len(status) > 1):
                # print('Write Flash returned with code: {}'.format(status[1]))
                if(status[1] != 0):
                    print('Write Flash returned with error code: {}'.format(status[1]))

            addr = addr + bin_len

            # program a page once all data has been sent for this page
            if((addr % PAGE_SIZE) == 0):
                bytes_out = struct.pack('<BBI', 0x50, 0, cur_page_addr)
                bytes_out = bytes_out + bytes([crc8_calc(0, bytes_out)])
                
                boot.write(bytes([0]) + bytes_out)
                status = boot.read(65, timeout_ms = 1000)
                if(len(status) > 1):
                    if(status[1] != 0):
                        print('Program Flash returned with error code: {}'.format(status[1]))
                cur_page_addr = cur_page_addr + PAGE_SIZE

        if((len(app_data) % CHUNKSIZE) != 0):
            # Still one more partial chunk to send
            bin_chunk = app_data[(CHUNKSIZE*num_chunks):]
            bin_len = len(bin_chunk)
            if((bin_len % 8) != 0):
                # append 0xFF as filler
                bin_chunk = bin_chunk + b'\xFF'*(8 - bin_len % 8)
            bytes_out = struct.pack('<BBI', 0x44, bin_len, addr)
            bytes_out = bytes_out + bin_chunk
            crcbyte = crc8_calc(0, bytes_out)
            bytes_out = bytes_out + bytes([crcbyte])

            boot.write(bytes([0]) + bytes_out)
            status = boot.read(65, timeout_ms=1)

            if(len(status) > 1):
                # print('Write Flash returned with code: {}'.format(status[1]))
                if(status[1] != 0):
                    print('Write Flash returned with error code: {}'.format(status[1]))
            
            # Program final page
            bytes_out = struct.pack('<BBI', 0x50, 0, cur_page_addr)
            bytes_out = bytes_out + bytes([crc8_calc(0, bytes_out)])
            
            boot.write(bytes([0]) + bytes_out)
            status = boot.read(65, timeout_ms = 1000)
            if(len(status) > 1):
                if(status[1] != 0):
                    print('Program Flash returned with error code: {}'.format(status[1]))
        else:
            # See if we need to program the last stuff written by the for loop
            if(addr > cur_page_addr):
                # Program final page
                bytes_out = struct.pack('<BBI', 0x50, 0, cur_page_addr)
                bytes_out = bytes_out + bytes([crc8_calc(0, bytes_out)])
                
                boot.write(bytes([0]) + bytes_out)
                status = boot.read(65, timeout_ms = 1000)
                if(len(status) > 1):
                    if(status[1] != 0):
                        print('Program Flash returned with error code: {}'.format(status[1]))

        print('Application written to Flash')

        # Check CRC. The one built-in to the firmware and the one calculated by the bootloader.
        boot.write(bytes([0, ord('C')]))
        stored_crc = boot.read(65, timeout_ms=1000)
        printable_stored_crc = ''.join(['{:02X}'.format(bb) for bb in stored_crc[0:4]])
        boot.write(bytes([0,ord('M')]))
        calc_crc = boot.read(65, timeout_ms=1000)
        printable_calc_crc = ''.join(['{:02X}'.format(bb) for bb in calc_crc[0:4]])

        print('Saved CRC = 0x{}'.format(printable_stored_crc))
        print('Calculated CRC = 0x{}'.format(printable_calc_crc))

        if(printable_stored_crc == printable_calc_crc):
            print('CRCs match! Booting into application.')
            boot.write(bytes([0, ord('B')]))
        else:
            print('Error. Incorrect CRC. Application will not boot, check firmware file.')
'''
class EVENT_TYPE:
    """Integer codes for the events the loader thread posts to the GUI queue."""
    (USBDRIVERFAIL,
     USBDRIVERSWAP,
     ERASESTART,
     ERASEFAIL,
     ERASEPASS,
     PROGRAMERR,
     PROGRAMPROGRESS,
     PROGRAMDONE) = range(8)

class QueueItem:
    """One event passed from the loader thread to the GUI.

    ``type`` holds an ``EVENT_TYPE`` code. ``data`` holds an optional payload
    (e.g. an error code or a progress percentage), or ``None``.
    """
    def __init__(self, evt_type, data):
        self.data = data
        self.type = evt_type

class bootloader_thread(Thread):
    """Worker thread that erases and programs an application over USB HID.

    The thread never touches Tk widgets. It reports progress and results by
    putting ``QueueItem`` objects (carrying ``EVENT_TYPE`` codes) on
    ``self.queue``, which the GUI polls from its own thread.

    Every packet is written with a leading 0x00 byte (hidapi's report-number
    slot). In replies, byte 1 is treated as a status code where 0 means OK.
    """
    BOOT_VID = 0x35BD
    APP_VID = 0x35BD
    BOOT_PID = 0x4004
    APP_PID = 0x0101

    # Bootloader command bytes, as used by the methods below.
    CMD_WRITE_DATA = 0x44    # send up to CHUNKSIZE data bytes for an address
    CMD_PROGRAM_PAGE = 0x50  # program a page once all its data has been sent
    CMD_ERASE = 0x65         # erase; the length byte selects the erase size
    ERASE_8KB = 0x00
    ERASE_16KB = 0x01
    ERASE_SECTOR = 0x02

    def __init__(self, infile_app, tkroot):
        """Prepare (but do not start) a load of the HEX file ``infile_app``.

        ``tkroot`` is only stored; no method of this class uses it.
        """
        super().__init__()
        self.infile_app = infile_app
        self.tkroot = tkroot
        self.queue = Queue(maxsize=0)  # unbounded, so put() never blocks

    def create_hid_message(self, hidtype, datalen, addr, data):
        """Build one bootloader message (without the leading report byte).

        Layout: ``hidtype`` (u8), ``datalen`` (u8), ``addr`` (u32
        little-endian), the bytes of ``data``, then a CRC-8 over all of it.
        ``data`` may be ``bytes`` or a sequence of ints.
        """
        bytes_out = struct.pack('<BBI', hidtype, datalen, addr) + bytes(data)
        return bytes_out + bytes([crc8_calc(0, bytes_out)])

    def run(self):
        """Thread entry point: connect, validate, erase, then program.

        An unexpected exception is reported as a PROGRAMERR event (with the
        exception text as data) instead of ending the thread silently. The
        HID handle is always closed before the thread exits.
        """
        self.boot_hid = hid.device()
        try:
            self._flash_application()
        except Exception as err:  # thread boundary: surface the failure to the GUI
            self.queue.put(QueueItem(EVENT_TYPE.PROGRAMERR,
                                     '{}: {}'.format(type(err).__name__, err)))
        finally:
            self.boot_hid.close()

    def _connect(self):
        """Open the bootloader HID device, switching the app into it if needed.

        Returns True once the bootloader device is open. Otherwise posts
        USBDRIVERFAIL and returns False.
        """
        try:
            self.boot_hid.open(vendor_id=self.BOOT_VID, product_id=self.BOOT_PID)
            return True
        except OSError:
            pass

        # If bootloader device is not present, try entering bootloader from the
        # main application
        try:
            self.boot_hid.open(vendor_id=self.APP_VID, product_id=self.APP_PID)
            self.boot_hid.send_feature_report(bytes([0, ord('B')]))  # jump to bootloader command
            self.queue.put(QueueItem(EVENT_TYPE.USBDRIVERSWAP, None))
            time.sleep(10)  # give it a few secs to load the USB device
            self.boot_hid.close()
            self.boot_hid.open(vendor_id=self.BOOT_VID, product_id=self.BOOT_PID)
        except OSError:
            self.queue.put(QueueItem(EVENT_TYPE.USBDRIVERFAIL, None))
            return False
        return True

    def _flash_application(self):
        """Run the full update sequence, posting events along the way."""
        if not self._connect():
            return

        # open a hex firmware file, save it to a local array
        (app, app_addr) = load_ihex(self.infile_app)
        # Erase only the portion of the Flash that is being overwritten
        self.queue.put(QueueItem(EVENT_TYPE.ERASESTART, None))

        # Reject images that don't fit in the app region or that are not
        # located at APP_ADDRESS.
        if len(app) > ALLOWED_APP_SIZE or app_addr != APP_ADDRESS:
            self.queue.put(QueueItem(EVENT_TYPE.ERASEFAIL, None))
            return

        # First attempt to erase in smaller chunks. Only newer bootloader (v0.1.2 or
        # more recent) can do this. Older (unversioned) bootloader can only delete
        # the entire application region by deleting full sectors of Flash, so fall
        # back to the whole-app erase if the block erase fails.
        retval = self.erase_app_by_blocks(app_addr, len(app))
        if retval != EVENT_TYPE.ERASEPASS:
            retval = self.erase_app()

        self.queue.put(QueueItem(retval, None))
        if retval != EVENT_TYPE.ERASEPASS:
            return

        self._program_app(app, app_addr)

    def _program_app(self, app, app_addr):
        """Send ``app`` in CHUNKSIZE packets and program each page as it fills.

        Posts PROGRAMPROGRESS after each full page. On the first non-zero
        status it posts PROGRAMERR and stops. When everything has been written
        it posts progress 100 followed by PROGRAMDONE.
        """
        addr = app_addr
        cur_page_addr = app_addr
        num_chunks = len(app) // CHUNKSIZE
        for curchunk in range(num_chunks):
            bin_chunk = app[CHUNKSIZE * curchunk:CHUNKSIZE * (curchunk + 1)]
            # 1 ms read: don't wait for a reply to every data packet, only
            # check one if it is already available.
            message = self.create_hid_message(self.CMD_WRITE_DATA, CHUNKSIZE, addr, bin_chunk)
            if not self._send_checked(message, 1):
                return
            addr += CHUNKSIZE

            # program a page once all data has been sent for this page
            if addr % PAGE_SIZE == 0:
                if not self._program_page(cur_page_addr):
                    return
                cur_page_addr += PAGE_SIZE
                # curchunk is 0-based; +1 so the last full page reports 100%.
                self.queue.put(QueueItem(EVENT_TYPE.PROGRAMPROGRESS,
                                         int(100 * (curchunk + 1) / num_chunks)))

        tail = app[CHUNKSIZE * num_chunks:]
        if len(tail) > 0:
            # Still one more partial chunk to send. Pad it with 0xFF to a
            # multiple of 8 bytes. The length field keeps the unpadded size.
            # list() makes the padding work whether app is a list or bytes.
            bin_len = len(tail)
            padded = list(tail) + [0xFF] * (-bin_len % 8)
            message = self.create_hid_message(self.CMD_WRITE_DATA, bin_len, addr, padded)
            if not self._send_checked(message, 1):
                return
            if not self._program_page(cur_page_addr):
                return
        elif addr > cur_page_addr:
            # The last full chunk did not end on a page boundary: program it.
            if not self._program_page(cur_page_addr):
                return

        self.queue.put(QueueItem(EVENT_TYPE.PROGRAMPROGRESS, 100))
        self.queue.put(QueueItem(EVENT_TYPE.PROGRAMDONE, None))

    def _send_checked(self, bytes_out, timeout_ms):
        """Write one message and check the status byte of any reply.

        A reply shorter than two bytes (e.g. nothing arrived within
        ``timeout_ms``) is treated as success. On a non-zero status, a
        PROGRAMERR event carrying that code is posted and False is returned.
        """
        self.boot_hid.write(bytes([0]) + bytes_out)
        status = self.boot_hid.read(65, timeout_ms=timeout_ms)
        if len(status) > 1 and status[1] != 0:
            self.queue.put(QueueItem(EVENT_TYPE.PROGRAMERR, status[1]))
            return False
        return True

    def _program_page(self, page_addr):
        """Issue the program-page command for ``page_addr`` (1 s reply timeout)."""
        message = self.create_hid_message(self.CMD_PROGRAM_PAGE, 0, page_addr, b'')
        return self._send_checked(message, 1000)

    def _erase(self, size_code, erase_addr, timeout_ms):
        """Send one erase command and map the reply to ERASEPASS/ERASEFAIL.

        Only a reply whose status byte is 0 counts as success. No reply
        within ``timeout_ms`` counts as a failure.
        """
        bytes_out = self.create_hid_message(self.CMD_ERASE, size_code, erase_addr, b'')
        self.boot_hid.write(bytes([0]) + bytes_out)
        status = self.boot_hid.read(64, timeout_ms=timeout_ms)
        if len(status) > 1 and status[1] == 0:
            return EVENT_TYPE.ERASEPASS
        return EVENT_TYPE.ERASEFAIL

    def erase_sector(self, erase_addr: int):
        """Erase the sector at ``erase_addr``, allowing 2 s for the reply."""
        return self._erase(self.ERASE_SECTOR, erase_addr, 2000)

    def erase_blocks(self, erase_addr: int, bytes_to_erase: int):
        """Erase ``bytes_to_erase`` bytes from ``erase_addr`` in 16kB blocks,
        finishing with one 8kB block once at most 8kB remain.

        Returns ERASEPASS (also when there is nothing to erase) or the first
        failing result.
        """
        while bytes_to_erase > MIN_ERASABLE_SIZE:
            retval = self.erase_16kB_block(erase_addr)
            if retval != EVENT_TYPE.ERASEPASS:
                return retval
            bytes_to_erase -= 2 * MIN_ERASABLE_SIZE
            erase_addr += 2 * MIN_ERASABLE_SIZE
        if bytes_to_erase > 0:
            return self.erase_8kB_block(erase_addr)
        return EVENT_TYPE.ERASEPASS

    def erase_8kB_block(self, erase_addr: int):
        """Erase the 8kB block at ``erase_addr``, allowing 1 s for the reply."""
        return self._erase(self.ERASE_8KB, erase_addr, 1000)

    def erase_16kB_block(self, erase_addr: int):
        """Erase the 16kB block at ``erase_addr``, allowing 1 s for the reply."""
        return self._erase(self.ERASE_16KB, erase_addr, 1000)

    def erase_app_by_blocks(self, erase_addr: int, bytes_to_erase: int):
        """Erase only the Flash covered by an image of ``bytes_to_erase`` bytes.

        Walks the app region's divisions (the large subsector, then three full
        sectors). A division the image fully covers is erased with one sector
        command. The division where the image ends is erased block by block.
        Returns ERASEPASS or the first failing result.
        """
        divisions = [LARGE_SUBSECTOR_SIZE, SECTOR_SIZE, SECTOR_SIZE, SECTOR_SIZE]
        for div in divisions:
            if bytes_to_erase <= 0:
                break
            if bytes_to_erase >= div:
                # Whole division is covered: a single sector erase is enough.
                retval = self.erase_sector(erase_addr)
                erase_addr += div
                bytes_to_erase -= div
            else:
                # Image ends inside this division: erase only what it needs.
                retval = self.erase_blocks(erase_addr, bytes_to_erase)
                bytes_to_erase = 0
            if retval != EVENT_TYPE.ERASEPASS:
                return retval
        return EVENT_TYPE.ERASEPASS

    def erase_app(self):
        """Erase the whole application region with the single-byte 0x22 command.

        Waits up to 10 s for a reply. A missing reply is assumed to mean
        success; a reply with a non-zero status byte is ERASEFAIL.
        """
        self.boot_hid.write(b'\x00\x22')
        # give it 10 seconds to clear the app flash
        status = self.boot_hid.read(64, timeout_ms=10000)

        if len(status) > 1:
            if status[1] == 0:
                return EVENT_TYPE.ERASEPASS
            return EVENT_TYPE.ERASEFAIL
        return EVENT_TYPE.ERASEPASS  # if no status reply, assume it worked

class usb_bootloader_gui:
    """Tk window for choosing a HEX file and flashing it over USB.

    Flashing runs in a ``bootloader_thread``. This class polls that thread's
    queue from the Tk main loop and updates the widgets from there.
    """
    def __init__(self):
        self.root = tk.Tk()
        self.root.title('Bigscreen USB Firmware Loader')
        self.file_picked_app = False
        self.progressbarvalue = tk.IntVar(self.root, 0)
        # Per-load state, reset by load_firmware():
        self.load_succeeded = False  # PROGRAMDONE was received
        self.load_finished = False   # any terminal (success or failure) event was received
        self.load_ui()

    def load_ui(self):
        """Create the widgets and lay them out in a single grid frame."""
        self.content = ttk.Frame(self.root, padding=(5,5,5,5), width=400, height=300)
        self.content.grid(column=0, row=0, sticky=(tk.N, tk.S, tk.E, tk.W))

        self.lbl_app = ttk.Label(self.content, text='App File:')
        self.filenlbl_app = ttk.Label(self.content, text='Pick an application firmware file -->', relief=tk.SUNKEN, wraplength=500)
        self.btnFilePick_app = ttk.Button(self.content, text='...', command=self.file_pick_app)
        self.btnLoad = ttk.Button(self.content, text='Load Firmware', command=self.load_firmware)
        self.statusbar = ttk.Label(self.content, text='Ready', border=1, relief=tk.SUNKEN)
        self.progressbar = ttk.Progressbar(self.content, mode='determinate', variable=self.progressbarvalue)

        self.lbl_app.grid(column=0, row=0, sticky="nsew", padx=5, pady=5)
        self.filenlbl_app.grid(column=1, row=0, columnspan=2, sticky="ew", padx=5, pady=5)
        self.btnFilePick_app.grid(column=3, row=0, padx=5, pady=5)
        self.btnLoad.grid(column=0, row=1, columnspan=4, sticky="ew", padx=5, pady=5)
        self.statusbar.grid(column=0, row=2, columnspan=4,sticky="sew")
        self.progressbar.grid(column=0, row=3, columnspan=4, sticky="nswe")

        # Let the frame and its 4x4 grid cells grow with the window.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)
        for index in range(4):
            self.content.columnconfigure(index, weight=1)
            self.content.rowconfigure(index, weight=1)

    def start_main_loop(self):
        """Run the Tk event loop until the window is closed."""
        self.root.mainloop()

    def file_pick_app(self):
        """Ask for a .hex file and remember it. Cancelling keeps any earlier choice."""
        self.filen_app = filedialog.askopenfilename(filetypes=[('HEX',"*.hex")],initialdir=os.getcwd())
        if(len(self.filen_app) > 0):
            self.filenlbl_app.config(text=self.filen_app)
            self.file_picked_app = True

    def load_firmware(self):
        """Start a loader thread for the chosen file and begin polling it."""
        if not self.file_picked_app:
            messagebox.showwarning(title="Need file", message="Please pick a .bin/.hex file first")
            return
        # Reset state left over from a previous load.
        self.progressbarvalue.set(0)
        self.load_succeeded = False
        self.load_finished = False
        # Create firmware loader thread class
        self.loader = bootloader_thread(self.filen_app, self.root)
        self.loader.start()
        # Create checker to listen for queue items
        self.root.after(10, self.monitor_thread)
        self.btnLoad["state"] = "disabled"

    def monitor_thread(self):
        """Drain the loader's event queue, then reschedule or finish up."""
        # Sample liveness BEFORE draining. If the thread had already exited,
        # every event it will ever post is in the queue now, so none is lost
        # (checking afterwards could miss an event posted just before exit).
        thread_alive = self.loader.is_alive()
        while not self.loader.queue.empty():
            self._handle_event(self.loader.queue.get())

        if thread_alive:
            # rerun monitor
            self.root.after(10, self.monitor_thread)
            return

        # thread is done, re-enable button
        self.btnLoad["state"] = "normal"
        if not self.load_finished:
            # Every normal exit of the loader posts a terminal event, so
            # getting here means the thread died from an exception.
            messagebox.showerror(title="Loader stopped", message="The firmware loader stopped unexpectedly. See the console for details.")
        if not self.load_succeeded:
            # Keep the success message (with its reset instruction) visible.
            self.statusbar.config(text="Ready")

    def _handle_event(self, queue_item):
        """Update the UI for one QueueItem posted by the loader thread."""
        event = queue_item.type
        if event == EVENT_TYPE.USBDRIVERFAIL:
            self.load_finished = True
            messagebox.showerror(title="USB not connected", message="Could not connect to USB bootloader")
        elif event == EVENT_TYPE.ERASESTART:
            self.statusbar.config(text="Erasing app region...")
        elif event == EVENT_TYPE.ERASEFAIL:
            self.load_finished = True
            messagebox.showerror(title="Erase failed", message="Could not erase application region")
        elif event == EVENT_TYPE.ERASEPASS:
            self.statusbar.config(text="App region erased, programming...")
        elif event == EVENT_TYPE.PROGRAMERR:
            self.load_finished = True
            messagebox.showerror(title="Programming error", message="Programming returned error code: {}".format(queue_item.data))
        elif event == EVENT_TYPE.PROGRAMPROGRESS:
            self.progressbarvalue.set(queue_item.data)
        elif event == EVENT_TYPE.PROGRAMDONE:
            self.load_finished = True
            self.load_succeeded = True
            self.progressbarvalue.set(100)
            self.statusbar.config(text='Application firmware successfully loaded! Reset power to run it.')
            messagebox.showinfo(title="SUCCESS!", message = "Firmware load complete")
        elif event == EVENT_TYPE.USBDRIVERSWAP:
            self.statusbar.config(text="Resetting in bootloader mode. Please wait a moment.")

if __name__ == "__main__":
    widget = usb_bootloader_gui()
    # No thread-enabled Tcl check is needed here: the loader thread never
    # calls into Tk. It only posts to a Queue that the GUI polls from the Tk
    # main loop. (The former `tk.Tcl().eval('set tcl_platform(threaded)')`
    # only created a throwaway interpreter and would raise TclError on a
    # non-threaded Tcl build.)
    sys.exit(widget.start_main_loop())  # mainloop() returns None -> exit status 0