import cv2
import time
from typing import List, Optional
from fpga_commands import reset_fpga, reconfig_fpga, set_cam_reg, set_pwm_duty
from loguru import logger

# Register "capture.log" as a loguru sink so log records are also written to
# that file. This runs at import time, so importing this module has this side effect.
logger.add("capture.log")

class FPSCounter:
    """
    Measure frames per second, averaged over a fixed update interval.

    Call update() once per received frame. The reported FPS value is only
    recomputed once at least `update_interval` seconds have elapsed since the
    previous recomputation; in between, the last computed value is returned.

    Args:
        update_interval: Minimum number of seconds between FPS recomputations.
    """
    def __init__(self, update_interval: float = 1.0):
        # A monotonic clock is used so that wall-clock adjustments (NTP, DST,
        # manual changes) cannot produce negative or inflated intervals.
        self.prev_time = time.monotonic()
        self.curr_fps = 0.0
        self.frame_count = 0
        self.update_interval = update_interval

    def update(self) -> float:
        """
        Register one frame and refresh the FPS value if the interval elapsed.

        Returns:
            float: Most recently computed FPS value (0 until the first
            interval has elapsed)
        """
        self.frame_count += 1
        curr_time = time.monotonic()
        time_diff = curr_time - self.prev_time

        # The `time_diff > 0` guard prevents a division by zero when the
        # interval is configured as 0 and two calls land on the same clock tick.
        if time_diff >= self.update_interval and time_diff > 0:
            self.curr_fps = self.frame_count / time_diff
            self.frame_count = 0
            self.prev_time = curr_time

        return self.curr_fps

    def is_stream_healthy(self, min_fps: float = 30) -> bool:
        """
        Check if the stream is healthy (last computed FPS strictly above min_fps).

        Args:
            min_fps: FPS threshold the stream must exceed to count as healthy

        Returns:
            bool: True if stream is healthy, False otherwise
        """
        return self.curr_fps > min_fps

def enumerate_cameras() -> List[int]:
    """
    Probe camera indices 0, 1, 2, ... and collect those that open.

    Probing stops at the first index that cannot be opened, so any device
    sitting behind a gap in the index sequence is not reported.

    Returns:
        List[int]: Consecutive camera indices, starting at 0, that opened
    """
    logger.info("Starting camera enumeration")
    available_cameras = []

    # Indices are tried in order, so the next index to probe always equals
    # the number of cameras found so far.
    probe = cv2.VideoCapture(0)
    while probe.isOpened():
        available_cameras.append(len(available_cameras))
        probe.release()
        probe = cv2.VideoCapture(len(available_cameras))

    logger.info(f"Found {len(available_cameras)} camera(s): {available_cameras}")
    return available_cameras

def enumerate_800x400_camera() -> cv2.VideoCapture:
    """
    Open the first enumerated camera that reports an 800x400 frame size.

    Every index returned by enumerate_cameras() is reopened with the MSMF
    backend and its reported frame width/height are inspected. Cameras that
    do not match are released before moving on.

    Returns:
        cv2.VideoCapture: Opened capture whose reported size is 800x400;
        the caller is responsible for releasing it

    Raises:
        RuntimeError: If no camera with 800x400 resolution is found
    """
    logger.info("Searching for camera with 800x400 resolution")

    for camera_index in enumerate_cameras():
        cap = cv2.VideoCapture(camera_index, cv2.CAP_MSMF)
        if cap.isOpened():
            reported_w = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
            reported_h = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

            # Properties come back as floats; compare their integer parts.
            is_match = int(reported_w) == 800 and int(reported_h) == 400
            if is_match:
                logger.success(f"Found 800x400 camera at index {camera_index} with backend {cap.getBackendName()}")
                return cap

            # Wrong size: give the device back before trying the next one.
            cap.release()

    error_msg = "No camera with 800x400 resolution found"
    logger.error(error_msg)
    raise RuntimeError(error_msg)

def set_analog_gain():
    """
    Write fixed values to camera registers 0x350A and 0x350B via set_cam_reg.

    NOTE(review): presumably these two registers hold the sensor's analog
    gain (per this function's name) -- confirm against the sensor datasheet.
    """
    register_writes = (
        (0x350a, 0x0a),
        (0x350b, 0x00),
    )
    # Writes are issued in the listed order.
    for register, value in register_writes:
        set_cam_reg(register, value)

def _wait_for_healthy_stream(cap: cv2.VideoCapture, window: float = 5.0) -> Optional[float]:
    """
    Read frames from `cap` until the measured FPS is healthy or time runs out.

    A fresh FPSCounter is used for every call so that time spent before the
    check (FPGA reset, camera enumeration) and frames counted during earlier
    attempts do not distort the measurement.

    Args:
        cap: Opened capture to read frames from
        window: Maximum number of seconds to keep checking

    Returns:
        Optional[float]: The FPS value at which the stream was judged healthy,
        or None if a frame read failed or the window elapsed first
    """
    fps_counter = FPSCounter()
    deadline = time.monotonic() + window

    while time.monotonic() < deadline:
        success, _ = cap.read()  # read() == grab() + retrieve(); frame itself is not needed
        if not success:
            return None

        fps = fps_counter.update()
        if fps_counter.is_stream_healthy():
            return fps

    return None


def initialize_camera(timeout: int = 30) -> cv2.VideoCapture:
    """
    Initialize the camera and ensure stream is healthy (>30 fps).

    Each attempt resets the FPGA, sets the PWM duty to 5, opens the 800x400
    camera and then checks the frame rate for up to 5 seconds. Attempts are
    repeated until one succeeds or `timeout` seconds have passed. The timeout
    is only evaluated between attempts, so the total time can exceed it by
    the duration of one attempt.

    Args:
        timeout: Maximum time in seconds to attempt initialization

    Returns:
        cv2.VideoCapture: Initialized camera with healthy stream; the caller
        is responsible for releasing it

    Raises:
        RuntimeError: If initialization fails or healthy stream cannot be achieved
    """
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        cap = None
        try:
            logger.info("Resetting FPGA")
            reset_fpga()
            time.sleep(1)  # Allow FPGA to stabilize

            set_pwm_duty(5)
            logger.info("Set PWM duty to 5")
            # set_analog_gain()
            cap = enumerate_800x400_camera()

            logger.info("Checking stream health...")
            fps = _wait_for_healthy_stream(cap)
            if fps is not None:
                logger.success(f"Stream is healthy with {fps:.1f} FPS")
                # Hand ownership to the caller: clear the local reference so
                # the finally-clause below does not release the healthy capture.
                healthy_cap, cap = cap, None
                return healthy_cap

            logger.warning("Stream not healthy, retrying initialization...")

        except Exception as e:
            # Any failure in one attempt is treated as retryable until timeout.
            logger.warning(f"Initialization attempt failed: {str(e)}")
            time.sleep(1)

        finally:
            # Release the capture on every unsuccessful path, including when
            # an exception was raised after the camera had been opened.
            if cap is not None:
                cap.release()

    error_msg = f"Failed to achieve healthy stream within {timeout} seconds"
    logger.error(error_msg)
    raise RuntimeError(error_msg)


def main():
    """
    Capture frames from the camera, display them and record them to output.mp4.

    For each frame, channel index 1 is extracted, min-max normalized to the
    0..255 range, written to the video file and shown in a window named
    'frame'. The loop ends when a frame read fails or the user presses 'q'.

    The camera, the video writer and the display windows are released in a
    finally-clause, so they are also cleaned up when an exception is raised
    or the process is interrupted (e.g. Ctrl+C).

    Raises:
        Exception: Any error raised during initialization or capture is
        logged and re-raised
    """
    logger.info("Starting capture application")
    # NOTE(review): the frames written below are single-channel while the
    # writer is created with default settings -- confirm the selected codec
    # accepts grayscale input (otherwise pass isColor=False or convert to BGR).
    video_writer = cv2.VideoWriter('output.mp4', -1, 60, (800, 400))
    cap = None
    try:
        cap = initialize_camera()
        fps_counter = FPSCounter()

        while True:
            success, frame = cap.read()
            if not success:
                logger.error("Failed to read frame")
                break

            # Keep only channel index 1 (green, assuming the default BGR
            # frame layout) instead of doing a full grayscale conversion.
            frame = frame[:, :, 1]
            fps = fps_counter.update()

            # Stretch the frame's intensity range to the full 0..255 span.
            frame = cv2.normalize(frame, None, 0, 255, cv2.NORM_MINMAX)
            video_writer.write(frame)

            logger.info(f"FPS: {fps:.1f}")
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    except Exception as e:
        logger.exception(f"Error during capture: {str(e)}")
        raise
    finally:
        # Always release resources; the writer in particular must be released
        # for the output file to be finalized.
        logger.info("Cleaning up resources")
        if cap is not None:
            cap.release()
        video_writer.release()
        cv2.destroyAllWindows()

    # Only reached when the loop ended without an exception.
    logger.success("Application terminated successfully")

if __name__ == "__main__":
    # Script entry point: run the capture application and make sure any
    # uncaught failure is recorded in the log before it propagates.
    try:
        main()
    except Exception as exc:
        logger.exception(f"Application crashed: {str(exc)}")
        raise
