import cocotb
from cocotb.triggers import Timer, RisingEdge, FallingEdge, ClockCycles, Join, First
from cocotb.clock import Clock
from random import getrandbits
import struct
import itertools

class FrameData:
    """Mutable holder for data captured during a test.

    ``dat`` collects ``(y, uv)`` pixel tuples when filled by the camera
    drivers, or one list of received bytes per packet when filled by the
    USB controller model.
    """

    def __init__(self):
        # Each instance gets its own independent list.
        self.dat = list()

async def fake_camera_data(dut, data_to_send, line_width):
    """Drive a caller-supplied pixel stream onto the fake camera interface.

    After 50 idle MIPI_CLK cycles PIXEL_VS is raised. After 6 more cycles,
    one ``(y, uv)`` pair is presented per MIPI_CLK rising edge with PIXEL_DE
    high. After every ``line_width`` pixels PIXEL_DE is dropped for 6 clocks,
    as if the camera were between lines. Once all data has been sent,
    PIXEL_DE and PIXEL_VS are both de-asserted to end the frame.

    Args:
        dut: cocotb DUT handle.
        data_to_send: iterable of ``(y, uv)`` pairs.
        line_width: number of pixels per line.

    Returns:
        FrameData whose ``dat`` records every ``(y, uv)`` pair driven.
    """
    retvals = FrameData()
    dut.PIXEL_VS.value = 0
    dut.PIXEL_DE.value = 0
    # Idle a few clocks here
    await ClockCycles(dut.MIPI_CLK, 50)
    dut.PIXEL_VS.value = 1
    await ClockCycles(dut.MIPI_CLK, 6)
    # Counts pixels (not bytes) emitted on the current line.
    pixels_in_line = 0
    for pxl_y, pxl_uv in data_to_send:
        dut.PIXEL_DE.value = 1
        dut.PIXEL_Y.value = pxl_y
        dut.PIXEL_UV.value = pxl_uv
        retvals.dat.append((pxl_y, pxl_uv))
        await RisingEdge(dut.MIPI_CLK)
        pixels_in_line += 1
        if pixels_in_line == line_width:
            pixels_in_line = 0
            dut.PIXEL_DE.value = 0
            await ClockCycles(dut.MIPI_CLK, 6)
    # End of frame: drop DE (the data may not end on a line boundary) and VS,
    # and hold them low for a few clocks so the DUT observes the transition.
    dut.PIXEL_DE.value = 0
    dut.PIXEL_VS.value = 0
    await ClockCycles(dut.MIPI_CLK, 6)
    return retvals

async def fake_camera_frame(dut, hsize, vsize, *, vblank=100, hblank=20):
    """Generate one frame of random pixels and drive it onto the camera interface.

    Sequence (all counts in MIPI_CLK cycles):

    1. All pixel inputs are zeroed and held idle for 50 cycles.
    2. PIXEL_VS is raised and held for ``vblank`` cycles.
    3. Each of ``vsize`` lines waits ``hblank`` cycles. It then drives
       ``hsize`` random 8-bit ``(y, uv)`` pairs, one per rising edge, with
       PIXEL_DE high, and drops PIXEL_DE at the end of the line.
    4. PIXEL_VS is lowered, followed by another ``vblank`` cycles.

    Args:
        dut: cocotb DUT handle.
        hsize: pixels per line.
        vsize: number of lines.
        vblank: cycles of vertical blanking before the first line and after
            the last one (default 100).
        hblank: cycles of horizontal blanking before each line (default 20).

    Returns:
        FrameData whose ``dat`` lists every ``(y, uv)`` pair driven, in order.
    """
    generated_frame = FrameData()
    dut.PIXEL_VS.value = 0
    dut.PIXEL_DE.value = 0
    dut.PIXEL_Y.value = 0
    dut.PIXEL_UV.value = 0
    # Idle a few clocks here
    await ClockCycles(dut.MIPI_CLK, 50)
    dut.PIXEL_VS.value = 1
    await ClockCycles(dut.MIPI_CLK, vblank)
    for _line in range(vsize):
        await ClockCycles(dut.MIPI_CLK, hblank)
        dut.PIXEL_DE.value = 1
        for _pixel in range(hsize):
            newydata = getrandbits(8)
            newuvdata = getrandbits(8)
            dut.PIXEL_Y.value = newydata
            dut.PIXEL_UV.value = newuvdata
            generated_frame.dat.append((newydata, newuvdata))
            await RisingEdge(dut.MIPI_CLK)
        dut.PIXEL_DE.value = 0
    dut.PIXEL_VS.value = 0
    await ClockCycles(dut.MIPI_CLK, vblank)

    return generated_frame

async def usb_function_controller(dut, usb_dat):
    """Model the USB function side that pulls packets out of the DUT, forever.

    Each iteration:

    1. Raises USB_TXACT and waits ``txact_to_txpop_delay`` USB_CLK cycles.
    2. If USB_TXCORK is high, the attempt is treated as a NAK: USB_TXACT is
       dropped and the model backs off for ``nak_delay`` cycles.
    3. Otherwise it reads USB_TXDAT_LEN and pops exactly that many bytes.
       USB_TXPOP follows ``txpop_startup_pattern`` cycle by cycle at first,
       then is held high until the packet is complete. Each popped byte
       (USB_TXDAT sampled after the rising edge) is collected. The packet's
       byte list is appended to ``usb_dat.dat``, then USB_PKTFIN is pulsed
       for one cycle and USB_TXACT is released.

    Args:
        dut: cocotb DUT handle.
        usb_dat: FrameData whose ``dat`` list receives one list of ints per
            packet. The attribute is re-read every packet, so callers may
            rebind ``usb_dat.dat`` between frames.
    """
    txact_to_txpop_delay = 20
    last_txpop_to_pktfin_delay = 20
    pktfin_to_txact_delay = 0
    nak_delay = 1000
    between_packet_delay = 200
    # Per-cycle USB_TXPOP values applied at the start of every packet. After
    # the pattern is exhausted, USB_TXPOP is held high until the packet ends.
    txpop_startup_pattern = [True, False, False, True, True, True, False, False, False, True, False]

    while True:
        dut.USB_TXACT.value = 1
        await ClockCycles(dut.USB_CLK, txact_to_txpop_delay)
        if dut.USB_TXCORK.value == 1:
            # TXCORK high: treat as a NAK -- release TXACT and back off.
            dut.USB_TXACT.value = 0
            await ClockCycles(dut.USB_CLK, nak_delay)
            continue

        packet_data = []
        # Number of bytes the DUT says this packet contains.
        bytes_in_packet = int(dut.USB_TXDAT_LEN.value)
        popped = 0

        for pop_val in txpop_startup_pattern:
            # Check before popping so a zero-length packet pops nothing.
            if popped == bytes_in_packet:
                break
            dut.USB_TXPOP.value = 1 if pop_val else 0
            await RisingEdge(dut.USB_CLK)
            if pop_val:
                popped += 1
                packet_data.append(int(dut.USB_TXDAT.value))

        # Pop any remaining bytes back to back.
        dut.USB_TXPOP.value = 1
        while popped < bytes_in_packet:
            popped += 1
            await RisingEdge(dut.USB_CLK)
            packet_data.append(int(dut.USB_TXDAT.value))

        usb_dat.dat.append(packet_data)
        dut.USB_TXPOP.value = 0
        await ClockCycles(dut.USB_CLK, last_txpop_to_pktfin_delay)
        dut.USB_PKTFIN.value = 1
        await RisingEdge(dut.USB_CLK)
        dut.USB_PKTFIN.value = 0
        await ClockCycles(dut.USB_CLK, pktfin_to_txact_delay)
        dut.USB_TXACT.value = 0
        await ClockCycles(dut.USB_CLK, between_packet_delay)

# @cocotb.test()
async def single_frame(dut):
    """Send one 120x120 camera frame and check the USB payload matches it.

    Every received USB packet is treated as a 2-byte header followed by
    interleaved (Y, UV) pixel bytes. The test fails if the expected number of
    payload bytes has not arrived when a 6000-USB_CLK wait (with no
    ``buffer_complete_i`` rising edge) elapses. It also fails if the received
    pixels differ from the ones driven.
    """
    # Initialize all inputs
    dut.PIXEL_Y.value = 0
    dut.PIXEL_UV.value = 0
    dut.PIXEL_DE.value = 0
    dut.PIXEL_VS.value = 0

    dut.USB_TXACT.value = 0
    dut.USB_TXPOP.value = 0
    dut.USB_PKTFIN.value = 0

    cocotb.start_soon(Clock(dut.MIPI_CLK, 50, units="ns").start()) # mipi clock at 20MHz
    cocotb.start_soon(Clock(dut.USB_CLK, 12, units="ns").start()) # usb clock at 83.33MHz

    await ClockCycles(dut.MIPI_CLK, 10)
    dut.ARESET.value = 1
    await ClockCycles(dut.MIPI_CLK, 4)
    dut.ARESET.value = 0
    await ClockCycles(dut.MIPI_CLK, 20)

    usb_rec_data = FrameData()
    # Start the USB receiver
    cocotb.start_soon(usb_function_controller(dut, usb_rec_data))

    def payload_byte_count():
        # Each packet starts with a 2-byte header; the rest is pixel data.
        return sum(len(pkt_dat) - 2 for pkt_dat in usb_rec_data.dat)

    # Send a frame
    frame_data = await fake_camera_frame(dut, 120, 120)

    # Each pixel is carried as one Y byte plus one UV byte.
    sent_pxl_count = len(frame_data.dat)
    expected_bytes = 2 * sent_pxl_count
    received_data_count = payload_byte_count()
    dut._log.info("At end of frame, sent {} pixels, received {} bytes".format(sent_pxl_count, received_data_count))
    dut._log.info("Expect {} more bytes ({} more pixels)".format(expected_bytes - received_data_count, sent_pxl_count - received_data_count // 2))

    while received_data_count < expected_bytes:
        fret = await First(ClockCycles(dut.USB_CLK, 6000), RisingEdge(dut.buffer_complete_i))
        # Recount before judging a timeout: the last packet may have been
        # appended after the final buffer_complete_i edge.
        received_data_count = payload_byte_count()
        if isinstance(fret, ClockCycles) and received_data_count < expected_bytes:
            dut._log.warning("Timed out waiting for USB transfer to complete!")
            dut._log.info("Didn't receive enough pixel bytes sent:{}, received:{}".format(expected_bytes, received_data_count))
            raise AssertionError(
                "Timed out waiting for USB data: expected {} bytes, received {}".format(expected_bytes, received_data_count)
            )

    dut._log.info("Received all pixel bytes sent:{} == received:{}".format(expected_bytes, received_data_count))

    # Compare sent and received data
    received_pixels = []
    for pkt in usb_rec_data.dat:
        received_pixels.extend(zip(pkt[2::2], pkt[3::2]))

    assert len(received_pixels) == len(frame_data.dat), "Data lengths did not match"
    for num_pxl, (usb_pxl, mipi_pxl) in enumerate(zip(received_pixels, frame_data.dat)):
        assert usb_pxl == mipi_pxl, "Pixel #{} did not match. Sent {}, received {}".format(num_pxl, mipi_pxl, usb_pxl)

@cocotb.test()
async def multi_frame(dut):
    """Send two 120x120 camera frames back to back and verify each over USB.

    For each frame, every received packet must have a header length byte of
    2. The header flag byte must equal the frame's id bit (bit 0) on all but
    the last packet, and the last packet must also have bit 1 (end of frame)
    set. The payload pixels must match what was driven. The frame id bit must
    differ between the two frames.

    The test fails if the expected payload has not arrived when a
    6000-USB_CLK wait (with no ``buffer_complete_i`` rising edge) elapses.
    """
    # Initialize all inputs
    dut.PIXEL_Y.value = 0
    dut.PIXEL_UV.value = 0
    dut.PIXEL_DE.value = 0
    dut.PIXEL_VS.value = 0

    dut.USB_TXACT.value = 0
    dut.USB_TXPOP.value = 0
    dut.USB_PKTFIN.value = 0

    cocotb.start_soon(Clock(dut.MIPI_CLK, 50, units="ns").start()) # mipi clock at 20MHz
    cocotb.start_soon(Clock(dut.USB_CLK, 12, units="ns").start()) # usb clock at 83.33MHz

    await ClockCycles(dut.MIPI_CLK, 10)
    dut.ARESET.value = 1
    await ClockCycles(dut.MIPI_CLK, 4)
    dut.ARESET.value = 0
    await ClockCycles(dut.MIPI_CLK, 20)

    usb_rec_data = FrameData()
    # Start the USB receiver
    cocotb.start_soon(usb_function_controller(dut, usb_rec_data))

    def payload_byte_count():
        # Each packet starts with a 2-byte header; the rest is pixel data.
        return sum(len(pkt_dat) - 2 for pkt_dat in usb_rec_data.dat)

    async def wait_for_payload(sent_pxl_count):
        # Block until 2 bytes (Y + UV) per sent pixel have arrived, or fail.
        expected_bytes = 2 * sent_pxl_count
        received_data_count = payload_byte_count()
        dut._log.info("At end of frame, sent {} pixels, received {} bytes".format(sent_pxl_count, received_data_count))
        dut._log.info("Expect {} more bytes ({} more pixels)".format(expected_bytes - received_data_count, sent_pxl_count - received_data_count // 2))

        while received_data_count < expected_bytes:
            fret = await First(ClockCycles(dut.USB_CLK, 6000), RisingEdge(dut.buffer_complete_i))
            # Recount before judging a timeout: the last packet may have been
            # appended after the final buffer_complete_i edge.
            received_data_count = payload_byte_count()
            if isinstance(fret, ClockCycles) and received_data_count < expected_bytes:
                dut._log.warning("Timed out waiting for USB transfer to complete!")
                dut._log.info("Didn't receive enough pixel bytes sent:{}, received:{}".format(expected_bytes, received_data_count))
                raise AssertionError(
                    "Timed out waiting for USB data: expected {} bytes, received {}".format(expected_bytes, received_data_count)
                )

        dut._log.info("Received all pixel bytes sent:{} == received:{}".format(expected_bytes, received_data_count))

    def check_frame(frame_data, previous_frame_id=None):
        # Validate headers and payload of usb_rec_data against frame_data and
        # return this frame's id bit. If previous_frame_id is given, also
        # require the id bit to have changed.
        assert usb_rec_data.dat, "No USB packets were received"
        received_pixels = []
        hdr_lengths = []
        hdr_bytes = []
        for pkt in usb_rec_data.dat:
            received_pixels.extend(zip(pkt[2::2], pkt[3::2]))
            hdr_lengths.append(pkt[0])
            hdr_bytes.append(pkt[1])

        # check the headers
        for ii, bb in enumerate(hdr_lengths):
            assert bb == 2, "Header length byte wasn't 2 for packet {}".format(ii)

        frame_id = hdr_bytes[0] & 0x01
        if previous_frame_id is not None:
            assert previous_frame_id != frame_id, "Frame ID did not increment between frames"
        for ii, bb in enumerate(hdr_bytes[:-1]):
            assert bb == frame_id, f"Header frame id inconsistent for packet {ii}"
        assert hdr_bytes[-1] == (frame_id | 0x02), "End of frame not set for last packet"

        # Compare sent and received data
        assert len(received_pixels) == len(frame_data.dat), "Data lengths did not match"
        for num_pxl, (usb_pxl, mipi_pxl) in enumerate(zip(received_pixels, frame_data.dat)):
            assert usb_pxl == mipi_pxl, "Pixel #{} did not match. Sent {}, received {}".format(num_pxl, mipi_pxl, usb_pxl)
        return frame_id

    # Send a frame
    frame_data = await fake_camera_frame(dut, 120, 120)
    await wait_for_payload(len(frame_data.dat))
    frame_id = check_frame(frame_data)

    # Frame complete!
    dut._log.info("Frame 1 completed successfully")

    await ClockCycles(dut.MIPI_CLK, 400)

    # Send another frame!
    usb_rec_data.dat = [] # clear the received data
    frame_data = await fake_camera_frame(dut, 120, 120)
    await wait_for_payload(len(frame_data.dat))
    check_frame(frame_data, previous_frame_id=frame_id)

    # Frame complete!
    dut._log.info("Frame 2 completed successfully")