import re
import os
import argparse
import struct
from typing import NamedTuple

class ListItem_t(NamedTuple):
    """One FreeRTOS ``ListItem_t`` decoded from the target's memory dump.

    pull_state_from_tcb builds these from five consecutive little-endian
    32-bit words (20 bytes). The pointer-named fields are kept as raw
    target addresses; nothing in this script dereferences them.
    NOTE(review): the 20-byte size assumes list-item integrity check
    fields are compiled out of the firmware — confirm if layouts differ.
    """

    xItemValue: int   # word 0
    pxNext: int       # word 1, raw target address
    pxPrevious: int   # word 2, raw target address
    pvOwner: int      # word 3, raw target address
    pvContainer: int  # word 4, raw target address

class tskTCB(NamedTuple):
    """Decoded FreeRTOS task control block (TCB), as read by pull_state_from_tcb.

    Field order follows the 92-byte little-endian layout that
    pull_state_from_tcb unpacks. The real TCB layout depends on the
    FreeRTOS configuration the firmware was built with.
    NOTE(review): confirm against the target's FreeRTOSConfig.h.
    """

    pxTopOfStack: int  # target address where the task's saved context begins
    xStateListItem: ListItem_t
    xEventListItem: ListItem_t
    uxPriority: int
    pxStack: int
    pcTaskName: bytes  # raw 16-byte name field, NUL-terminated; not decoded
    uxTCBNumber: int
    uxTaskNumber: int
    uxBasePriority: int
    uxMutexesHeld: int
    ulNotifiedValue: int
    ucNotifyState: int  # a single byte in memory, followed by 3 padding bytes

def load_error_file(filename:str):
    """Parse an HMDUtility error-handler log into a dict of processor state.

    The file is scanned line by line in two phases:

    * Header phase, until a line containing ``MEMORY REGION``: fault status
      registers and core/FPU registers are picked out with regexes wherever
      they appear on a line (a later match for the same register wins).
    * Memory phase: a line whose first 8 characters parse as a hex number is
      a dump line. The first dump line after a non-dump line opens a new
      region at the address that line carries. The hex digits after the
      address are read as consecutive byte pairs; whitespace between them
      is ignored. Any non-dump line closes the current region.

    Args:
        filename: Path to the error log.

    Returns:
        ``None`` if *filename* does not exist. Otherwise a dict with
        ``"regs"`` (16 ints, R0-R15), ``"fpu_regs"`` (32 ints, S0-S31),
        ``"regions"`` (list of ``{"start": int, "data": list[int]}``) and,
        for each one found in the log, ``"hfsr"``, ``"cfsr"``, ``"psr"``,
        ``"epsr"``, ``"msp"``, ``"psp"`` and ``"fpscr"``.
    """
    if not os.path.exists(filename):
        return None

    # Registers reported as a single "<name>: 0x<hex>" value. The PSR pattern
    # must not match the "Exception's PSR" line; otherwise, depending on line
    # order, the PSR would be clobbered by the exception PSR.
    scalar_patterns = (
        ("hfsr", re.compile(r"Hard fault status register:\s*0x([A-Fa-f\d]+)")),
        ("cfsr", re.compile(r"Configurable fault status register:\s*0x([A-Fa-f\d]+)")),
        ("psr", re.compile(r"(?<!Exception's )PSR:\s*0x([A-Fa-f\d]+)")),
        ("epsr", re.compile(r"Exception's PSR:\s*0x([A-Fa-f\d]+)")),
        ("msp", re.compile(r"MSP:\s*0x([A-Fa-f\d]+)")),
        ("psp", re.compile(r"PSP:\s*0x([A-Fa-f\d]+)")),
        ("fpscr", re.compile(r"FPSCR:\s*0x([A-Fa-f\d]+)")),
    )
    core_reg_re = re.compile(r"R(\d+)\)?:\s*0x([A-Fa-f\d]+)")  # e.g. "R0:" or "(R13):"
    fpu_reg_re = re.compile(r"S(\d+):\s*0x([A-Fa-f\d]+)")
    region_start_re = re.compile(r"([a-fA-F\d]{8})\s+[a-fA-F\d]+")

    file_contents = {"regs": [0] * 16, "fpu_regs": [0] * 32, "regions": []}
    memory_started = False
    in_mem_section = False

    with open(filename, "r") as fil:
        for eachline in fil:
            if not memory_started:
                for key, pattern in scalar_patterns:
                    mm = pattern.search(eachline)
                    if mm:
                        file_contents[key] = int(mm[1], 16)
                mm = core_reg_re.search(eachline)
                if mm:
                    file_contents["regs"][int(mm[1])] = int(mm[2], 16)
                mm = fpu_reg_re.search(eachline)
                if mm:
                    file_contents["fpu_regs"][int(mm[1])] = int(mm[2], 16)
                if "MEMORY REGION" in eachline:
                    memory_started = True
                continue

            # The first dump line after a gap opens a new region at its address.
            if not in_mem_section:
                mm = region_start_re.search(eachline)
                if mm:
                    in_mem_section = True
                    file_contents["regions"].append({"start": int(mm[1], 16), "data": []})
            try:
                int(eachline[:8], 16)  # ValueError => not a dump line
                # Drop all whitespace so byte pairs stay aligned even if the
                # dump groups its bytes with spaces.
                hex_digits = "".join(eachline[8:].split())
                # Build the full list before extending so a bad pair never
                # leaves a partially appended line behind.
                line_bytes = [int(hex_digits[i:i + 2], 16)
                              for i in range(0, len(hex_digits) - 1, 2)]
                file_contents["regions"][-1]["data"].extend(line_bytes)
            except (ValueError, IndexError):
                # Not a dump line (or no region opened yet): close the region.
                in_mem_section = False
    return file_contents

def hexstr(ins:list[int]|bytes|tuple[int]) -> str:
    return ''.join(['{:02X}'.format(bb) for bb in ins])

def create_crashdebug_file(filename: str, file_contents:dict):
    """Write *file_contents* as a CrashCatcher-format hex dump for CrashDebug.

    Output is one hex-encoded, little-endian record per line, in this order:

    * signature ``'c' 'C' 3 0`` and a flags word (bit 0 = FPU registers
      present; always set because S0-S31/FPSCR are always written),
    * R0-R7, R8-R15, then PSR, MSP, PSP and the exception PSR,
    * S0-S31 as four lines of 8 words, then FPSCR,
    * for each memory region: a start/end address line (end is exclusive),
      then the region's bytes, 16 per line, with any remainder on a last line,
    * if the log provided CFSR and/or HFSR: a region 0xE000ED28-0xE000ED3C
      holding CFSR, HFSR and three zero words. DFSR, MMFAR and BFAR are
      not captured; a missing CFSR or HFSR is written as 0.

    Args:
        filename: Path of the dump file to create (overwritten if present).
        file_contents: Dict shaped like ``load_error_file``'s result.
            ``regs``, ``fpu_regs``, ``regions``, ``psr``, ``msp``, ``psp``,
            ``epsr`` and ``fpscr`` are required.

    Raises:
        KeyError: A required register entry is missing.
        struct.error: ``regs``/``fpu_regs`` are too short, or a value does
            not fit in 32 bits.
    """
    eight_words = struct.Struct('<8I')

    with open(filename, "w") as fil:
        def write_record(data) -> None:
            # one record per line, bytes rendered as uppercase hex
            fil.write(hexstr(data) + '\n')

        # crash catcher signature and flags
        write_record([ord('c'), ord('C'), 3, 0])
        write_record(struct.pack('<I', 1))

        # core registers
        regs = file_contents['regs']
        write_record(eight_words.pack(*regs[0:8]))
        write_record(eight_words.pack(*regs[8:16]))
        write_record(struct.pack('<IIII', file_contents['psr'], file_contents['msp'],
                                 file_contents['psp'], file_contents['epsr']))

        # FPU registers
        fpu_regs = file_contents['fpu_regs']
        for first in range(0, 32, 8):
            write_record(eight_words.pack(*fpu_regs[first:first + 8]))
        write_record(struct.pack('<I', file_contents['fpscr']))

        # memory regions
        for region in file_contents['regions']:
            data = region['data']
            startaddr = region['start']
            write_record(struct.pack('<II', startaddr, startaddr + len(data)))
            for offset in range(0, len(data), 16):
                write_record(data[offset:offset + 16])

        # Optional fault status register region: skip it when the log had
        # neither HFSR nor CFSR instead of failing with KeyError.
        if 'cfsr' in file_contents or 'hfsr' in file_contents:
            write_record(struct.pack('<II', 0xE000ED28, 0xE000ED3C))
            write_record(struct.pack('<IIIII', file_contents.get('cfsr', 0),
                                     file_contents.get('hfsr', 0), 0, 0, 0))



def get_memory(file_contents, address, length):
    """Return ``length`` dumped bytes starting at ``address``.

    Regions in ``file_contents['regions']`` are searched in order, and the
    first one that holds the whole span ``[address, address + length)``
    wins. A span cannot cross regions: if it starts inside a region but runs
    past its end, it is not found there.

    Returns:
        A list of ints (one per byte), or ``None`` if no single region
        contains the full span.
    """
    for reg in file_contents['regions']:
        start = reg['start']
        data = reg['data']
        if start <= address and address + length <= start + len(data):
            offset = address - start
            return data[offset:offset + length]
    return None


def _get_scalar(file_contents, address, fmt):
    # Shared body of get_word/get_halfword/get_byte: read calcsize(fmt) bytes
    # and decode them. get_memory yields a list of ints, which struct cannot
    # consume directly, so convert to bytes first.
    memval = get_memory(file_contents, address, struct.calcsize(fmt))
    if memval is None:
        return None
    return struct.unpack(fmt, bytes(memval))[0]


def get_word(file_contents, address):
    """Return the little-endian 32-bit word at ``address`` as an int.

    All 4 bytes must lie in the same memory region; otherwise, or if the
    address is not dumped at all, returns ``None``.
    """
    return _get_scalar(file_contents, address, '<I')


def get_halfword(file_contents, address):
    """Like ``get_word`` but for a little-endian 16-bit half-word."""
    return _get_scalar(file_contents, address, '<H')


def get_byte(file_contents, address):
    """Like ``get_word`` but for a single byte; returns an int or ``None``."""
    return _get_scalar(file_contents, address, '<B')


def _read_dump(file_contents, address, length, what):
    # get_memory() returns None for spans outside the dump; turn that into a
    # clear error instead of a TypeError on None further down.
    data = get_memory(file_contents, address, length)
    if data is None:
        raise ValueError(
            f"{what} at 0x{address:08X} ({length} bytes) is not in any dumped memory region")
    return bytes(data)


def _unpack_words(raw):
    # Decode little-endian 32-bit words; callers always pass a multiple of 4 bytes.
    return list(struct.unpack(f'<{len(raw) // 4}I', raw))


def pull_state_from_tcb(file_contents, address):
    """Replace the dumped processor state with a FreeRTOS task's saved context.

    Reads the task control block (TCB) at *address*, then reads the context
    stored at the TCB's ``pxTopOfStack``. From the stack top upwards that is
    the layout saved by the port's PendSV handler:

    * R4-R11 and the exception LR (EXC_RETURN),
    * if EXC_RETURN bit 4 is clear (FPU context): S16-S31,
    * the hardware exception frame R0-R3, R12, LR, PC, xPSR,
    * for an FPU frame: S0-S15, FPSCR and one reserved word.

    ``regs``, ``psr`` and ``psp`` (plus ``fpu_regs`` and ``fpscr`` for FPU
    frames) are overwritten so a debugger sees this task as the running
    thread. ``msp`` and ``epsr`` are left unchanged.

    Args:
        file_contents: Dict produced by ``load_error_file``; modified in place.
        address: Target address of the task's TCB.

    Returns:
        ``(file_contents, task_name)``: the same, mutated dict and the task
        name decoded from ``pcTaskName``, with undecodable bytes replaced.

    Raises:
        ValueError: The TCB or the saved stack frame is not inside any
            dumped memory region.
    """
    # TCB layout assumed here (92 bytes, little-endian, byte offsets):
    #   0 pxTopOfStack (4)   4 xStateListItem (20)   24 xEventListItem (20)
    #  44 uxPriority (4)    48 pxStack (4)           52 pcTaskName (16)
    #  68 uxTCBNumber, uxTaskNumber, uxBasePriority, uxMutexesHeld,
    #     ulNotifiedValue (4 each)    88 ucNotifyState (1) + 3 bytes padding
    raw = _read_dump(file_contents, address, 92, "task control block")
    (pxTopOfStack,) = struct.unpack_from('<I', raw, 0)
    xStateListItem = ListItem_t(*struct.unpack_from('<5I', raw, 4))
    xEventListItem = ListItem_t(*struct.unpack_from('<5I', raw, 24))
    uxPriority, pxStack = struct.unpack_from('<2I', raw, 44)
    pcTaskName = raw[52:68]
    (uxTCBNumber, uxTaskNumber, uxBasePriority, uxMutexesHeld,
     ulNotifiedValue, ucNotifyState) = struct.unpack_from('<5IB', raw, 68)

    my_tcb = tskTCB(pxTopOfStack, xStateListItem, xEventListItem, uxPriority, pxStack, pcTaskName,
                    uxTCBNumber, uxTaskNumber, uxBasePriority, uxMutexesHeld, ulNotifiedValue,
                    ucNotifyState)
    # A bogus TCB address can yield non-ASCII garbage; don't crash on it.
    pretty_task_name = my_tcb.pcTaskName.split(b'\x00')[0].decode('ascii', errors='replace')

    top = my_tcb.pxTopOfStack
    # R4-R11 followed by the exception LR (EXC_RETURN)
    pendsv_stacked_regs = _unpack_words(
        _read_dump(file_contents, top, 9 * 4, "PendSV-saved registers"))
    # EXC_RETURN bit 4 clear => the task was using the FPU (extended frame)
    fpu_used = (pendsv_stacked_regs[8] & 0x10) == 0

    if fpu_used:
        high_fpu_regs = _unpack_words(
            _read_dump(file_contents, top + 36, 16 * 4, "saved S16-S31"))
        exception_stacked_regs = _unpack_words(
            _read_dump(file_contents, top + 100, 8 * 4, "exception frame"))
        exception_stacked_fpu_regs = _unpack_words(
            _read_dump(file_contents, top + 132, 17 * 4, "exception FPU frame"))
        # 9 PendSV words + 16 high FPU words + 26-word extended hardware frame
        # (8 core, S0-S15, FPSCR, reserved) = 51 words.
        new_sp = top + (50 * 4) + 4
    else:
        exception_stacked_regs = _unpack_words(
            _read_dump(file_contents, top + 36, 8 * 4, "exception frame"))
        # 9 PendSV words + 8-word basic hardware frame
        new_sp = top + (17 * 4)

    # Stacked xPSR bit 9: the hardware added a 4-byte pad for 8-byte alignment.
    if exception_stacked_regs[-1] & (1 << 9):
        new_sp += 4

    # exception_stacked_regs: [R0, R1, R2, R3, R12, LR, PC, xPSR]
    # pendsv_stacked_regs:    [R4, R5, R6, R7, R8, R9, R10, R11, EXC_RETURN]
    file_contents["regs"] = (exception_stacked_regs[0:4] + pendsv_stacked_regs[0:8]
                             + [exception_stacked_regs[4], new_sp] + exception_stacked_regs[5:7])
    file_contents["psr"] = exception_stacked_regs[-1]
    if fpu_used:
        # S0-S15 from the hardware frame, S16-S31 from the PendSV save area
        file_contents["fpu_regs"] = exception_stacked_fpu_regs[0:16] + high_fpu_regs
        file_contents["fpscr"] = exception_stacked_fpu_regs[16]
    file_contents["psp"] = new_sp

    return (file_contents, pretty_task_name)

def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: convert an error log into a CrashDebug dump.

    The dump is written next to the input file. Its name is
    ``cC_<stem>.txt``, or ``cC_<stem><task name>.txt`` with
    ``--stack-exchange``. ``<stem>`` is the input file name up to its first
    ``.``. Spaces and path separators in the task name become ``_``.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.
    """
    ap = argparse.ArgumentParser(prog="create_crashdebug",
                                 description="Reformats an error handler output from HMDUtility into CrashDebug format")
    ap.add_argument("inputfile", help="Filename of the error log")
    ap.add_argument("-s", "--stack-exchange", metavar='ADDR',
                    help="Exchanges the current processor state with the stacked registers at ADDR")
    args = ap.parse_args(argv)

    file_contents = load_error_file(args.inputfile)
    if file_contents is None:
        ap.error(f"input file not found: {args.inputfile}")

    (pathhead, pathtail) = os.path.split(args.inputfile)
    stem = pathtail.split('.')[0]

    if args.stack_exchange:
        text = args.stack_exchange.strip()
        try:
            # Hex with a 0x/0X prefix (int() accepts the prefix in base 16), else decimal.
            tcb_address = int(text, 16 if text[:2].lower() == '0x' else 10)
        except ValueError:
            ap.error(f"invalid ADDR: {args.stack_exchange!r}")
        try:
            (file_contents, task_name) = pull_state_from_tcb(file_contents, tcb_address)
        except ValueError as err:
            ap.error(str(err))
        # Keep the task name from escaping the output directory.
        outtail = 'cC_' + stem + re.sub(r"[ /\\]", "_", task_name) + '.txt'
    else:
        outtail = 'cC_' + stem + '.txt'

    create_crashdebug_file(os.path.join(pathhead, outtail), file_contents)


if __name__ == '__main__':
    main()

