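# Creates a firmware image ready for upload with the USB bootloader.
# The tool writes the padded firmware length into the image, appends a CRC-32
# checksum, and can optionally prepend the bootloader to produce a full flash image.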

import os
import sys
from typing import NamedTuple
import ihex_loader
from crc import Crc32, Calculator
import re
from datetime import datetime

import argparse

# Static description of a target microcontroller used when building images.
MCU_Config = NamedTuple('MCU_Config', [
    ('name', str),             # human-readable part name, e.g. 'SAMG55'
    ('flash_address', int),    # base address of on-chip flash
    ('flash_size', int),       # total flash size in bytes
    ('bootloader_size', int),  # bytes reserved at flash_address before the USB bootloader
    ('signature', int),        # presumably a device signature value; not read by this script
    ('page_size', int),        # not read by this script
    ('vector_offset', int),    # offset inside the application where its length is written
    ('crc_bits', int),         # not read by this script
    ('rows_per_block', int),   # not read by this script
])

# SAMG55 target. The application image is placed at
# flash_address + bootloader_size + USB_BOOTLOADER_SIZE.
samg55 = MCU_Config(
    name='SAMG55',
    flash_address=0x00400000,
    flash_size=0x00080000,
    bootloader_size=0,
    signature=0x1302FD67,
    page_size=0x200,
    vector_offset=0x400,
    crc_bits=16,
    rows_per_block=4,
)

# 16 KiB reserved for the USB bootloader; the application starts right after it.
USB_BOOTLOADER_SIZE = 0x4000

def get_sw_ver_from_file(hex_location, zero_pad=False):
    """Build a version string from the project's ``src/sw_ver.h`` header.

    The header is looked up one directory above ``hex_location`` (the build
    output folder), in a sibling ``src`` folder::

        {proj_dir}
        -- Debug      <- hex_location
        -- src
           -- sw_ver.h

    ``MAJOR_VERSION``, ``MINOR_VERSION`` and ``PATCH_VERSION`` are read from
    ``#define`` lines; if a macro is defined more than once, the last one wins.

    Args:
        hex_location: Path of the folder that holds the build output.
        zero_pad: When True, each purely numeric component is formatted with
            at least two digits (``1`` -> ``"01"``); non-numeric components
            are kept verbatim. Defaults to False, which concatenates the
            components exactly as written in the header.

    Returns:
        The concatenated version string (e.g. ``"123"``, or ``"010203"`` with
        ``zero_pad``), or ``"XXXXXX"`` if the header is missing or any of the
        three macros is not found.
    """
    proj_dir = os.path.abspath(os.path.join(hex_location, '..'))
    sw_ver_h_path = os.path.join(proj_dir, 'src', 'sw_ver.h')

    # Whitespace is required after the macro name so that e.g.
    # MAJOR_VERSION_STR is not mistaken for MAJOR_VERSION, the value must be
    # non-empty, and the match is anchored at line start so commented-out
    # defines ("// #define ...") are ignored.
    define_re = re.compile(r"^\s*#\s*define\s+(MAJOR|MINOR|PATCH)_VERSION\s+(\w+)")

    found = {}
    if os.path.exists(sw_ver_h_path):
        with open(sw_ver_h_path, 'r') as fil:
            for line in fil:
                match = define_re.search(line)
                if match:
                    found[match[1]] = match[2]

    parts = [found.get(key) for key in ('MAJOR', 'MINOR', 'PATCH')]
    if any(part is None for part in parts):
        return "XXXXXX"

    if zero_pad:
        def pad(part):
            # Numeric components become two digits; anything else (hex
            # literals, suffixed constants) is used unchanged.
            try:
                return "{:02}".format(int(part))
            except ValueError:
                return part

        parts = [pad(part) for part in parts]

    return ''.join(parts)


def load_bin(filename):
    """Read a raw binary file.

    Args:
        filename: Path of the ``.bin`` file to read.

    Returns:
        The file contents as a mutable ``list`` of byte values (0-255), so
        callers can pad and patch the image in place.
    """
    with open(filename, 'rb') as f:
        # Iterating bytes yields ints; list() copies them without a
        # redundant identity comprehension.
        return list(f.read())

def save_bin(filename: str, bindata: list[int]):
    """Write ``bindata`` to ``filename`` as raw bytes, overwriting any existing file.

    The values are converted before the file is opened, so invalid input
    (a value outside 0-255) raises without touching the destination.
    """
    payload = bytes(bindata)
    with open(filename, 'wb') as out_file:
        out_file.write(payload)

def generate_lut_entry(lut_index):
    """Compute one entry of the reflected CRC-32 lookup table.

    The value is shifted right eight times; whenever the bit shifted out is
    1, the reflected polynomial 0xEDB88320 is XOR-ed in.

    Args:
        lut_index: Table index, normally 0-255.

    Returns:
        The table entry for ``lut_index``.
    """
    entry = lut_index
    for _ in range(8):
        shifted_out = entry & 0x01
        entry >>= 1
        if shifted_out:
            entry ^= 0xEDB88320
    return entry

def calc_end_crc(bfr, length):
    """Run a reflected CRC-32 (polynomial 0xEDB88320) over a buffer.

    The register starts at 0xFFFFFFFF and no final XOR is applied (the
    common IEEE CRC-32 would additionally XOR the result with 0xFFFFFFFF).
    Table entries are computed on the fly by ``generate_lut_entry``.

    Args:
        bfr: Indexable sequence of byte values.
        length: Number of leading bytes of ``bfr`` to process.

    Returns:
        The raw 32-bit CRC register after the last processed byte.
    """
    register = 0xFFFFFFFF
    for pos in range(length):
        table_index = (register & 0xFF) ^ bfr[pos]
        register = generate_lut_entry(table_index) ^ (register >> 8)
    return register

def record_16_crc(crc, ch):
    """Shift one byte through a bitwise CRC-16 register (polynomial 0x8005).

    The register (bits 8-23) and the incoming byte (bits 0-7) are packed into
    one integer and shifted left eight times; whenever a 1 reaches bit 24 the
    polynomial (aligned to bit 8) is XOR-ed in. Bits above 23 are discarded
    by the final mask.

    Args:
        crc: Current CRC register value.
        ch: Next byte value to shift in.

    Returns:
        The updated 16-bit register.
    """
    aligned_poly = 0x8005 << 8
    top_bit = 1 << 24
    work = ch | (crc << 8)
    for _ in range(8):
        work <<= 1
        if work & top_bit:
            work ^= aligned_poly
    return 0xFFFF & (work >> 8)

def get_crc16(frame):
    """Compute the CRC-16 of ``frame``, excluding its first two bytes.

    Each byte from index 2 onward is shifted through ``record_16_crc``
    starting from 0, then two zero bytes flush the final 16 bits out of the
    register.

    Args:
        frame: Sequence of byte values.

    Returns:
        The 16-bit CRC value.
    """
    crc_val = 0
    for byte_val in frame[2:]:
        crc_val = record_16_crc(crc_val, byte_val)
    for _ in range(2):
        crc_val = record_16_crc(crc_val, 0)
    return crc_val

def get_crc32(frame):
    """Return ``calc_end_crc`` over ``frame`` with its first two bytes skipped.

    The result is the raw CRC-32 register (no final XOR).
    """
    payload = frame[2:]
    return calc_end_crc(payload, len(payload))

def check_is_blank_or_ff(input_bytes):
    """Return True when every value in ``input_bytes`` is 0x00 or 0xFF.

    An empty sequence counts as blank.
    """
    return all(value in (0x00, 0xFF) for value in input_bytes)

def create_full_image_with_bootloader(input_file, bootloader_file, output_file):
    """Build one flash image: the USB bootloader followed by the application.

    The application is padded with 0xFF so that, once the 4-byte CRC-32 is
    appended, its length is a multiple of 32 bytes. The padded length
    (excluding the CRC) is written little-endian into the 4 bytes at
    ``samg55.vector_offset``; those bytes must be 0x00/0xFF in the input,
    i.e. left empty by the linker script. A CRC-32 (``Crc32.BZIP2``
    configuration) of the patched application is then appended
    little-endian. The bootloader is padded with 0xFF to exactly
    ``USB_BOOTLOADER_SIZE`` bytes so the application follows it contiguously.

    Args:
        input_file: Application image, ``*.hex`` or ``*.bin``.
        bootloader_file: Bootloader image, ``*.hex`` or ``*.bin``.
        output_file: Destination, ``*.hex`` or ``*.bin``.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, message)``.
    """
    boot_start = samg55.flash_address + samg55.bootloader_size
    app_start = boot_start + USB_BOOTLOADER_SIZE

    # --- Load the application image ---
    app_ext = os.path.split(input_file)[1].split('.')[-1]
    if app_ext == 'hex':
        try:
            app_image, startaddr = ihex_loader.load_ihex(input_file)
        except Exception:
            return (False, "Failed to load hex file, check file format.")
        if startaddr != app_start:
            return (False, "Start address of application doesn't match configuration. Expected 0x{:08X}, saw 0x{:08x}.".format(
                app_start, startaddr))
    elif app_ext == 'bin':
        try:
            app_image = load_bin(input_file)
        except OSError as err:
            return (False, "Failed to read bin file: {}".format(err))
    else:
        return (False, "Unknown input file type. Only accepts *.bin or *.hex files.")
    # Normalise to a list so padding/patching and the final concatenation
    # work regardless of which loader produced the data.
    app_image = list(app_image)

    # --- Load the bootloader image ---
    boot_ext = os.path.split(bootloader_file)[1].split('.')[-1]
    if boot_ext == 'hex':
        try:
            boot_image, bootstartaddr = ihex_loader.load_ihex(bootloader_file)
        except Exception:
            return (False, "Failed to load hex file, check file format.")
        if bootstartaddr != boot_start:
            return (False, "Start address of bootloader doesn't match configuration. Expected 0x{:08X}, saw 0x{:08x}.".format(
                boot_start, bootstartaddr))
    elif boot_ext == 'bin':
        try:
            boot_image = load_bin(bootloader_file)
        except OSError as err:
            return (False, "Failed to read bin file: {}".format(err))
    else:
        return (False, "Unknown bootloader file type. Only accepts *.bin or *.hex files.")
    boot_image = list(boot_image)

    # Pad so that image + 4-byte CRC-32 is a multiple of 32 bytes.
    app_image.extend([0xFF] * ((28 - len(app_image)) % 32))

    # Place application length at the predetermined slot. The slot must
    # exist and be empty of program data (the linker script must skip it).
    length_slot = samg55.vector_offset
    if len(app_image) < length_slot + 4:
        return (False, "Application image is too small to contain the length slot at start location + 0x{:X}.".format(
            length_slot))
    if not check_is_blank_or_ff(app_image[length_slot:length_slot + 4]):
        return (False, "Vector location is not blank. Data exists at start location + 0x400.")
    app_image[length_slot:length_slot + 4] = list(len(app_image).to_bytes(4, 'little'))

    # CRC-32 over the padded, length-patched application, appended little-endian.
    crccalc = Calculator(Crc32.BZIP2, True)
    crc32val = crccalc.checksum(bytes(app_image)) & 0xFFFFFFFF
    app_image.extend(crc32val.to_bytes(4, 'little'))

    if samg55.bootloader_size + USB_BOOTLOADER_SIZE + len(app_image) > samg55.flash_size:
        return (False, "Application length exceeds flash region.")

    # Join bootloader and application with no gap: pad the bootloader to
    # exactly its reserved size.
    if len(boot_image) > USB_BOOTLOADER_SIZE:
        return (False, "Bootloader region exceeds maximum reserved size (0x{:08X} bytes of 0x{:08X} bytes available)".format(
            len(boot_image), USB_BOOTLOADER_SIZE))
    boot_image.extend([0xFF] * (USB_BOOTLOADER_SIZE - len(boot_image)))
    full_image = boot_image + app_image

    out_ext = os.path.split(output_file)[1].split('.')[-1]
    if out_ext == 'hex':
        ihex_loader.save_ihex(output_file, full_image, boot_start)
        return (True, "")
    elif out_ext == 'bin':
        save_bin(output_file, full_image)
        return (True, "")
    return (False, 'Incompatible output format. Must be .hex or .bin')

def create_fw_image_for_usb_bootloader(input_file, output_file):
    """Build an application image for upload through the USB bootloader.

    The image is padded with 0xFF so that, once the 4-byte CRC-32 is
    appended, its length is a multiple of 32 bytes. The padded length
    (excluding the CRC) is written little-endian into the 4 bytes at
    ``samg55.vector_offset``; those bytes must be 0x00/0xFF in the input,
    i.e. left empty by the linker script. A CRC-32 (``Crc32.BZIP2``
    configuration) of the patched image is then appended little-endian.

    Args:
        input_file: Application image, ``*.hex`` or ``*.bin``.
        output_file: Destination, ``*.hex`` or ``*.bin``.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, message)``.
    """
    app_start = samg55.flash_address + samg55.bootloader_size + USB_BOOTLOADER_SIZE

    in_ext = os.path.split(input_file)[1].split('.')[-1]
    if in_ext == 'hex':
        try:
            app_image, startaddr = ihex_loader.load_ihex(input_file)
        except Exception:
            return (False, "Failed to load hex file, check file format.")
        if startaddr != app_start:
            return (False, "Start address doesn't match configuration.")
    elif in_ext == 'bin':
        try:
            app_image = load_bin(input_file)
        except OSError as err:
            return (False, "Failed to read bin file: {}".format(err))
    else:
        # Can't read any other type
        return (False, "Unknown file type. Only accepts *.bin or *.hex files.")
    app_image = list(app_image)

    # Pad so that image + 4-byte CRC-32 is a multiple of 32 bytes.
    app_image.extend([0xFF] * ((28 - len(app_image)) % 32))

    # Place application length at the predetermined slot. The slot must
    # exist and be empty of program data (the linker script must skip it).
    length_slot = samg55.vector_offset
    if len(app_image) < length_slot + 4:
        return (False, "Application image is too small to contain the length slot at start location + 0x{:X}.".format(
            length_slot))
    if not check_is_blank_or_ff(app_image[length_slot:length_slot + 4]):
        return (False, "Vector location is not blank. Data exists at start location + 0x400.")
    app_image[length_slot:length_slot + 4] = list(len(app_image).to_bytes(4, 'little'))

    # CRC-32 over the padded, length-patched image, appended little-endian.
    crccalc = Calculator(Crc32.BZIP2, True)
    crc32val = crccalc.checksum(bytes(app_image)) & 0xFFFFFFFF
    app_image.extend(crc32val.to_bytes(4, 'little'))

    if samg55.bootloader_size + USB_BOOTLOADER_SIZE + len(app_image) > samg55.flash_size:
        return (False, "Application length exceeds flash region.")

    out_ext = os.path.split(output_file)[1].split('.')[-1]
    if out_ext == 'hex':
        ihex_loader.save_ihex(output_file, app_image, app_start)
        return (True, "")
    elif out_ext == 'bin':
        save_bin(output_file, app_image)
        return (True, "")

    return (False, 'Incompatible output format. Must be .hex or .bin')


def file_pick(filename):
    """Load an application image and check it against the SAMG55 memory map.

    Prints a short report on success, or an ``Error: ...`` line on failure.

    Args:
        filename: Path of a ``*.hex`` or ``*.bin`` application image.

    Returns:
        ``(app_image, startaddr)`` on success, or ``None`` after printing an
        error message.
    """
    app_start = samg55.flash_address + samg55.bootloader_size + USB_BOOTLOADER_SIZE
    ext = os.path.split(filename)[1].split('.')[-1]

    if ext == 'hex':
        try:
            app_image, startaddr = ihex_loader.load_ihex(filename)
        except Exception:
            print("Error: Failed to load hex file, check file format.")
            return None

        # Check app start address
        if startaddr != app_start:
            print("Error: Start address doesn't match configuration (0x{:08X})".format(app_start))
            return None
        print('Hex file loaded: {} bytes, starting at 0x{:08X}'.format(len(app_image), startaddr))
    elif ext == 'bin':
        try:
            app_image = load_bin(filename)
        except OSError as err:
            print("Error: Failed to read bin file: {}".format(err))
            return None
        startaddr = app_start
        print('Bin file loaded: {} bytes, starting at 0x{:08X}'.format(len(app_image), startaddr))
    else:
        print('Error: App format must be either .hex or .bin')
        return None

    # The app must fit in flash after both bootloader regions, with 4 bytes
    # left over for the CRC-32 appended after the firmware image.
    if len(app_image) + samg55.bootloader_size + USB_BOOTLOADER_SIZE > samg55.flash_size - 4:
        print("Error: Application length exceeds flash region.")
        return None

    return app_image, startaddr

def create_usb_file(self, filename, suffix):
    """Write ``<stem><suffix>.hex`` next to ``filename`` for the USB bootloader.

    NOTE(review): although this is a module-level function it takes ``self``
    and uses ``self.filen`` and ``self.statusbar``; it looks like it was
    lifted from a GUI class, so ``self`` must provide those attributes.
    Nothing in this script calls it.

    Args:
        self: Object exposing ``filen`` and ``statusbar.config(text=...)``.
        filename: Application image (``*.hex`` or ``*.bin``).
        suffix: Text inserted before ``.hex`` in the output name, e.g. ``"_FW"``.
    """
    picked = file_pick(filename)
    if picked is None:
        # file_pick has already printed why the file was rejected.
        return
    app_image, startaddr = picked
    app_image = list(app_image)

    # Split the filename so we can add a suffix like "_FW" to the end.
    (head, tail) = os.path.split(filename)
    outfile = os.path.join(head, tail.split('.')[0] + suffix + '.hex')

    # Pad so that image + 4-byte CRC-32 is a multiple of 32 bytes.
    app_image.extend([0xFF] * ((28 - len(app_image)) % 32))

    # Place application length at the predetermined slot. The slot must
    # exist and be empty of program data (the linker script must skip it).
    length_slot = samg55.vector_offset
    if len(app_image) < length_slot + 4:
        print("Error: App image is too small to contain the length slot at start location + 0x{:X}.".format(length_slot))
        return
    if not check_is_blank_or_ff(app_image[length_slot:length_slot + 4]):
        print("Error: App length location in firmware is not blank. Data exists at start location + 0x400.")
        return
    app_image[length_slot:length_slot + 4] = list(len(app_image).to_bytes(4, 'little'))

    # CRC-32 over the padded, length-patched image, appended little-endian.
    crccalc = Calculator(Crc32.BZIP2, True)
    crc32val = crccalc.checksum(bytes(app_image)) & 0xFFFFFFFF
    app_image.extend(crc32val.to_bytes(4, 'little'))

    # Save new Hex file
    ihex_loader.save_ihex(outfile, app_image, startaddr)
    print('Firmware file created! {}'.format(outfile))

    # NOTE(review): this rebuilds the image from self.filen and overwrites
    # outfile written just above; confirm whether both steps are intended.
    create_file_pass, ret_msg = create_fw_image_for_usb_bootloader(self.filen, outfile)
    if create_file_pass:
        self.statusbar.config(text='Firmware file created! {}'.format(outfile))
    else:
        self.statusbar.config(text='Failed to create firmware for USB bootloader: {}'.format(ret_msg))

def main():
    """Command-line entry point: build ``.hex`` and ``.bin`` firmware images.

    The output name is taken from ``--output`` (``{VERSION}`` and ``{DATE}``
    placeholders are expanded), or derived from the input name plus
    ``--suffix`` (default ``_FW``). The same image is written once as
    ``.hex`` and once as ``.bin``.

    Returns:
        Process exit status: 0 when every image was created, 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Creates firmware hex for USB bootloader.")
    parser.add_argument('-s', '--suffix', help='suffix to add to filename before the .hex (default=_FW)')
    parser.add_argument('-o', '--output', help='output file name (.hex will be automatically appended)')
    parser.add_argument('-b', '--bootloader', help='Bootloader bin/hex file to join with the application file')
    parser.add_argument('file_name', help='Input filename')

    args = parser.parse_args()

    (inhead, intail) = os.path.split(args.file_name)

    if args.suffix and args.output:
        print('Cannot provide both --output and --suffix arguments, choose one or the other')
        return 1

    if args.output:
        (outhead, outtail) = os.path.split(args.output)
        parts = outtail.split('.')
        if len(parts) == 1 or parts[1] == '' or parts[1] == 'hex':
            tail_outfile = parts[0] + '.hex'
            swver = get_sw_ver_from_file(inhead)
            date = datetime.now().strftime("%y%m%d")  # zero-padded YYMMDD
            # Expand the only supported placeholders, {VERSION} and {DATE}.
            out_name = tail_outfile.replace('{VERSION}', swver).replace('{DATE}', date)
            # Without a directory in --output, write next to the input file.
            outfile = os.path.join(outhead if outhead else inhead, out_name)
        else:
            print('Incompatible output format. Must be .hex')
            return 1
    elif args.suffix:
        outfile = os.path.join(inhead, intail.split('.')[0] + args.suffix + '.hex')
    else:
        outfile = os.path.join(inhead, intail.split('.')[0] + '_FW.hex')

    # splitext strips only the final extension, so dots in directory names
    # (e.g. "./Debug/app_FW.hex") are preserved.
    bin_outfile = os.path.splitext(outfile)[0] + '.bin'

    exit_code = 0
    for target in (outfile, bin_outfile):
        if args.bootloader:
            passed, error_str = create_full_image_with_bootloader(args.file_name, args.bootloader, target)
        else:
            passed, error_str = create_fw_image_for_usb_bootloader(args.file_name, target)

        if passed:
            print('Firmware file created! {}'.format(target))
        else:
            print(error_str)
            exit_code = 1
    return exit_code


if __name__ == "__main__":
    sys.exit(main())