import time
import subprocess
import threading
import cv2
import numpy as np
from polling2 import poll, TimeoutException
import tempfile
import os
import atexit
from loguru import logger
from pathlib import Path
import weakref

from extended_mode_display.resources import resource_path
from functools import lru_cache


class DirectModeDisplay:
    def __init__(self):
        self._start_time = time.time()
        self.proc = None
        self.temp_dir = tempfile.TemporaryDirectory()
        self._image_cache = {}
        self._thread = None
        self._is_closed = False
        logger.info("Initializing DirectModeDisplay")

        # Register cleanup both for normal exit and when garbage collected
        atexit.register(self._cleanup)
        self._finalizer = weakref.finalize(self, self._cleanup)

        self._start_display_image_process()

    def _start_display_image_process(self):
        assert os.path.exists(
            resource_path("resources/display_image.exe")
        ), "display_image.exe not found at " + resource_path(
            "resources/display_image.exe"
        )
        im = np.ones((100, 100, 3), dtype=np.uint8) * 127
        cv2.imwrite(os.path.join(self.temp_dir.name, "temp_image.png"), im)
        self.proc = subprocess.Popen(
            [
                resource_path("resources/display_image.exe"),
                "-image",
                Path(os.path.join(self.temp_dir.name, "temp_image.png")).resolve(),
            ],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        time.sleep(6)
        self._start_output_reader()

    def _output_reader(self):
        try:
            for line in iter(self.proc.stdout.readline, b""):
                pass
                # print(f"{line}", end="")
        except ValueError:
            pass

    def _start_output_reader(self):
        if self._thread is None:
            self._thread = threading.Thread(target=self._output_reader)
            self._thread.daemon = True
            self._thread.start()

    def convert_im_to_temp_path(self, im: np.ndarray) -> str:
        im_hash = hash(im.tobytes())

        if im_hash in self._image_cache:
            return self._image_cache[im_hash]
        temp_im_path = os.path.join(self.temp_dir.name, f"temp_image_{im_hash}.png")
        if im.ndim == 2:
            im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
        cv2.imwrite(temp_im_path, im)
        self._image_cache[im_hash] = temp_im_path
        return temp_im_path

    def im_show(self, im: np.ndarray | str):
        if isinstance(im, np.ndarray):
            im = self.convert_im_to_temp_path(im)
        full_path = str(Path(im).resolve()).replace("\\", "/")

        assert os.path.exists(full_path), "Image file not found"
        logger.info(f"Displaying image from {full_path}")
        self.proc.stdin.write("F" + full_path + "\n")
        self.proc.stdin.flush()

    def _cleanup(self):
        if self._is_closed:
            return

        try:
            if self.proc and self.proc.poll() is None:
                logger.info("Cleaning up DirectModeDisplay process")
                try:
                    self.proc.communicate(
                        "q\n", timeout=1
                    )  # Try graceful shutdown first
                except subprocess.TimeoutExpired:
                    self.proc.kill()  # Force kill if graceful shutdown fails

            if hasattr(self, "temp_dir"):
                self.temp_dir.cleanup()

            self._is_closed = True
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")

    def close(self):
        self._cleanup()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def rotate_image(image, angle_degrees):
    """Rotate *image* about its centre by *angle_degrees*, keeping its size.

    Args:
        image: Array accepted by ``cv2.warpAffine``; its first two dimensions
            are taken as (height, width).
        angle_degrees: Rotation angle in degrees. Per OpenCV's
            ``getRotationMatrix2D``, positive values rotate counter-clockwise.

    Returns:
        The rotated image with the same width and height as the input. Content
        rotated outside the frame is cropped, and uncovered areas are filled
        according to ``warpAffine``'s default border handling.
    """
    height, width = image.shape[:2]
    center = (width // 2, height // 2)

    # getRotationMatrix2D takes the angle in DEGREES. Converting to radians
    # here (as before) rotated by only ~1/57 of the requested angle.
    rotation_matrix = cv2.getRotationMatrix2D(center, angle_degrees, 1.0)

    # dsize = original (width, height) keeps the output size unchanged.
    return cv2.warpAffine(image, rotation_matrix, (width, height))


if __name__ == "__main__":
    # Example 1: Simple display with context manager
    with DirectModeDisplay() as dmd:
        for _ in range(10):
            dmd.im_show(resource_path("test_images/green.png"))
            time.sleep(0.01)

    # Example 2: Image rotation with context manager
    with DirectModeDisplay() as dmd:
        base_im = cv2.imread(resource_path("test_images/green.png"))
        if base_im is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("Could not read test_images/green.png")
        for i in range(1, 31):
            # Rotate the pristine image each time instead of re-rotating the
            # previous result, which would compound the rotation and clip the
            # corners a little more on every step.
            angle = (20 * i) % 360
            dmd.im_show(rotate_image(base_im, angle))
            time.sleep(0.5)
