import cocotb
from cocotb.triggers import Timer, RisingEdge, FallingEdge, ClockCycles, Join, First
from cocotb.clock import Clock
from random import getrandbits

async def reset_pulse(reset_wire, clk, delay_pre, reset_len, delay_post):
    """Drive a single high pulse on ``reset_wire``, timed in cycles of ``clk``.

    Sequence: wait ``delay_pre`` cycles, drive 1 and hold it for ``reset_len``
    cycles, drive 0, then wait a further ``delay_post`` cycles before returning.
    """
    # each phase: (level to drive before waiting, or None to leave it alone;
    #              number of clk cycles to wait afterwards)
    phases = (
        (None, delay_pre),
        (1, reset_len),
        (0, delay_post),
    )
    for level, cycles in phases:
        if level is not None:
            reset_wire.value = level
        await ClockCycles(clk, cycles)

async def ram_buffer(dut, data):
    """Behavioural model of a synchronous-read RAM serving ``data`` to the DUT.

    Loops forever (the caller is responsible for stopping the task): on every
    rising edge of ``dut.clk`` it reads ``dut.buffer_address`` and drives
    ``dut.video_buffer_data`` with ``data[address]``, or 0 when the address is
    past the end of ``data``. Roughly the Verilog::

        always @(posedge clk)
            video_buffer_data <= (buffer_address < len(data)) ? data[buffer_address] : 0;

    Args:
        dut: cocotb handle exposing ``clk``, ``buffer_address`` and
            ``video_buffer_data``.
        data: sequence of byte values indexed by ``buffer_address``.
    """
    while True:
        await RisingEdge(dut.clk)
        # read/convert the address handle once per cycle
        address = int(dut.buffer_address.value)
        # out-of-range addresses read back as 0 rather than raising IndexError
        dut.video_buffer_data.value = data[address] if address < len(data) else 0

async def usb_function_controller(dut):
    """Emulate the USB function-controller side of one transaction.

    Raises ``txact`` and waits ``txact_to_txpop_delay`` clocks. If the DUT
    holds ``txcork`` high (NAK), drops ``txact`` and returns immediately.
    Otherwise pops ``txdat_len`` bytes from ``txdat`` using ``txpop``, waits,
    pulses ``pktfin`` for one clock and drops ``txact``.

    ``txpop`` first follows ``txpop_startup_pattern`` cycle by cycle (so there
    are idle cycles between pops) and is then held high until every byte has
    been popped.

    Returns:
        list[int]: the values sampled from ``txdat`` after each popped clock
        edge, or an empty list if the transaction was NAKed.
    """
    txact_to_txpop_delay = 20
    last_txpop_to_pktfin_delay = 20
    pktfin_to_txact_delay = 0
    # startup pattern - uses this array to set txpop for each cycle
    # after the end of the array, holds txpop high continuously until end of packet
    txpop_startup_pattern = [True, False, False, True, True, True, False, False, False, True, False]

    dut.txact.value = 1
    await ClockCycles(dut.clk, txact_to_txpop_delay)
    if dut.txcork.value == 1:
        # dont do anything on a NAK, when txcork is high
        dut.txact.value = 0
        return []

    my_usb_data = []

    # grab the number of bytes we are sending
    bytes_in_packet = int(dut.txdat_len.value)
    popped = 0
    # go through the startup pattern. The length check happens *before* each
    # cycle so a zero-length packet never pops; for non-zero lengths this is
    # cycle-for-cycle identical to breaking right after the last pop.
    for pop_val in txpop_startup_pattern:
        if popped >= bytes_in_packet:
            break
        if pop_val:
            dut.txpop.value = 1
            popped += 1
            await RisingEdge(dut.clk)
            my_usb_data.append(int(dut.txdat.value))
        else:
            dut.txpop.value = 0
            await RisingEdge(dut.clk)

    # proceed with all remaining bytes, holding txpop high
    dut.txpop.value = 1
    while popped < bytes_in_packet:
        popped += 1
        await RisingEdge(dut.clk)
        my_usb_data.append(int(dut.txdat.value))

    dut.txpop.value = 0
    await ClockCycles(dut.clk, last_txpop_to_pktfin_delay)
    # one-clock pktfin pulse marks the end of the packet
    dut.pktfin.value = 1
    await RisingEdge(dut.clk)
    dut.pktfin.value = 0
    await ClockCycles(dut.clk, pktfin_to_txact_delay)
    dut.txact.value = 0

    return my_usb_data


async def test_packet_send(dut, data_len, num_usb_packets, last=False, reset=True):
    """Send one random video buffer through the DUT and check the USB payload.

    Drives the DUT inputs to idle values, starts a 60 MHz clock on ``dut.clk``,
    optionally pulses ``areset``, serves ``data_len`` random bytes through
    :func:`ram_buffer`, and runs ``num_usb_packets`` transactions via
    :func:`usb_function_controller`. It then performs the
    ``buffer_ready`` / ``buffer_complete`` handshake and finally checks the
    packets:

    * every packet starts with a length byte of 2;
    * the end-of-image bit (``0x02`` of byte 1) is set only on the final
      packet, and only when ``last`` is true;
    * the concatenated payloads (bytes 2..) equal the buffer.

    The clock and RAM-model tasks started here are killed before returning,
    so repeated calls (e.g. with ``reset=False``) do not stack duplicate
    drivers on ``dut.clk`` / ``dut.video_buffer_data``. Callers must not
    expect ``dut.clk`` to keep toggling after this returns.

    Args:
        dut: cocotb handle to the DUT.
        data_len: number of random bytes in the buffer.
        num_usb_packets: number of USB transactions to run.
        last: flag this buffer as the last one of the frame.
        reset: pulse ``areset`` before sending.
    """
    # initial values
    dut.areset.value = 0
    dut.txact.value = 0
    dut.txpop.value = 0
    dut.pktfin.value = 0
    dut.video_buffer_data.value = 0
    dut.video_frame_length.value = 0
    dut.last_buffer_of_frame.value = 0
    dut.buffer_ready.value = 0

    clock_task = cocotb.start_soon(Clock(dut.clk, 16.666, units="ns").start())  # usb clock at 60MHz
    ram_task = None
    try:
        if reset:
            await reset_pulse(dut.areset, dut.clk, 10, 4, 10)

        # create some random data
        buffer_dat = [getrandbits(8) for _ in range(data_len)]
        # setup a block of ram for the buffer
        ram_task = cocotb.start_soon(ram_buffer(dut, buffer_dat))
        # begin the send!
        dut.video_frame_length.value = len(buffer_dat)
        dut.buffer_ready.value = 1
        if last:
            dut.last_buffer_of_frame.value = 1

        # need to latch in the buffer ready while USB is inactive,
        # so wait a little while before starting the usb function controller
        await ClockCycles(dut.clk, 10)

        # send USB packets
        packet_data = []
        for i in range(num_usb_packets):
            packet_data.append(await usb_function_controller(dut))
            if i < num_usb_packets - 1:
                await ClockCycles(dut.clk, 30)  # some interpacket delay required

        # see if we get a handshake
        await First(RisingEdge(dut.buffer_complete), ClockCycles(dut.clk, 100))
        if dut.buffer_complete.value == 1:
            await RisingEdge(dut.clk)
            dut.buffer_ready.value = 0
            await First(FallingEdge(dut.buffer_complete), ClockCycles(dut.clk, 100))
        else:
            dut._log.warning("buffer_complete was not asserted within 100 cycles")

        await ClockCycles(dut.clk, 100)  # little finishing time
    finally:
        # stop this call's drivers so the next call starts from a single clock
        # and a single RAM model
        if ram_task is not None:
            ram_task.kill()
        clock_task.kill()

    # data mushing
    full_data = []
    for num, packet in enumerate(packet_data):
        # an empty packet means the transaction was NAKed (txcork high)
        assert len(packet) >= 2, (
            f"USB packet {num} is too short for the 2-byte header "
            f"({len(packet)} bytes; empty means NAK)"
        )
        assert packet[0] == 2, "First byte was not packet length (2)"
        if last and num == len(packet_data) - 1:
            assert packet[1] & 0xFE == 0x02, "End of image bit not set"
        else:
            assert packet[1] & 0xFE == 0x00, "End of image bit not cleared"

        full_data.extend(packet[2:])

    assert len(full_data) == len(buffer_dat), "Sent data length did not match"
    assert full_data == buffer_dat, "Sent data did not match"

    if packet_data:
        dut._log.info("Packet frame ID was {}".format(packet_data[-1][1] & 0x01))

## Tests ##
# The tests below are registered with skip=True instead of being hidden in a
# string literal: they stay syntax-checked and show up as skipped in the
# regression results. Remove skip=True to run one.


@cocotb.test(skip=True)
async def single_packet_send(dut):
    """510-byte buffer sent in one USB packet."""
    await test_packet_send(dut, 510, 1)


@cocotb.test(skip=True)
async def single_packet_send_last(dut):
    """510-byte buffer in one packet, flagged as last of frame."""
    await test_packet_send(dut, 510, 1, last=True)


@cocotb.test(skip=True)
async def single_buffer_multi_packet_send(dut):
    """1020-byte buffer split over two USB packets."""
    await test_packet_send(dut, 1020, 2)


@cocotb.test(skip=True)
async def single_buffer_multi_packet_send_last(dut):
    """1020-byte buffer over two packets, flagged as last of frame."""
    await test_packet_send(dut, 1020, 2, last=True)


@cocotb.test(skip=True)
async def single_packet_send_uneven_size(dut):
    """4-byte buffer in one packet, flagged as last of frame."""
    await test_packet_send(dut, 4, 1, last=True)


@cocotb.test(skip=True)
async def single_buffer_multi_packet_send_uneven_size(dut):
    """1021-byte buffer over three packets, flagged as last of frame."""
    await test_packet_send(dut, 1021, 3, last=True)


@cocotb.test(skip=True)
async def multi_buffer_multi_packet_send_uneven_size(dut):
    """Four consecutive buffers without reset; the final one is last of frame."""
    await test_packet_send(dut, 2040, 4, last=False, reset=True)
    await test_packet_send(dut, 2040, 4, last=False, reset=False)
    await test_packet_send(dut, 2040, 4, last=False, reset=False)
    await test_packet_send(dut, 1055, 3, last=True, reset=False)

@cocotb.test()
async def frame_increment(dut):
    """Send a sequence of buffers across frame boundaries.

    Only the first buffer pulses reset; the rest run back-to-back. Buffers
    flagged ``last`` end a frame, so the sequence spans several frames.
    """
    # (data_len, num_usb_packets, last) for each consecutive buffer
    buffer_sequence = (
        (2030, 4, True),
        (2040, 4, False),
        (2040, 4, False),
        (110, 1, True),
        (2040, 4, False),
    )
    for index, (length, packets, is_last) in enumerate(buffer_sequence):
        await test_packet_send(dut, length, packets, last=is_last, reset=(index == 0))