import os.path
import cv2
import numpy as np
import struct
import atexit
import subprocess
import win32pipe
import win32file
import steamvr
import logging
from scipy.spatial.transform import Rotation
import time
import matplotlib.pyplot as plt
from skimage.draw import line
from extended_mode_display import DirectModeDisplay

def find_camera_index():
    """Return the index of the first capture device reporting a 640x480 frame.

    Device indices 0 through 4 are probed in order; every probed capture is
    released before moving on. Falls back to index 0 when nothing matches.
    """
    for index in range(5):
        capture = cv2.VideoCapture(index)
        opened = capture.isOpened()
        if opened:
            frame_size = (
                capture.get(cv2.CAP_PROP_FRAME_WIDTH),
                capture.get(cv2.CAP_PROP_FRAME_HEIGHT),
            )
        capture.release()

        # The VAP camera is identified purely by its 640x480 resolution.
        if opened and frame_size == (640, 480):
            print(f"Found VAP camera at index {index}")
            return index
    return 0

camera_index = find_camera_index()
video = cv2.VideoCapture(camera_index)

# Server end of the "vap_pipe" named pipe; the display subprocess launched in
# align_init presumably connects to it as the client.
pipe = win32pipe.CreateNamedPipe(r'\\.\pipe\vap_pipe',
                                 win32pipe.PIPE_ACCESS_DUPLEX,
                                 win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_WAIT,
                                 1, 65536, 65536, 300, None)

lh = steamvr.LighthouseConsole()
lh.open()
try:
    config = lh.download_config()
except Exception:
    # Leave config unset so confirm_tracking_calibration() reports the failure;
    # a bare except would also swallow KeyboardInterrupt/SystemExit.
    config = None
    print("Could not download config")

class Subprocess:
    """Process-wide holder for the display subprocess handle.

    ``subprocess`` holds the ``subprocess.Popen`` object launched by
    ``align_init`` / ``align_restart_steamvr``, or ``None`` when no display
    subprocess is running.
    """

    subprocess = None

# In a PyInstaller build, bundled files live under "_internal/"; resolve every
# bundled resource relative to it when that folder exists.
_BUNDLE_PREFIX = "_internal/" if os.path.exists("_internal/") else ""

subprocess_path = _BUNDLE_PREFIX + "resources/DirectX12Seed.exe"
subprocess_dir_path = _BUNDLE_PREFIX + "resources/"
asset_path = _BUNDLE_PREFIX + "assets/"

charuco_board_asset_path_left = asset_path + "images/ChArUco_Marker_left.png"
charuco_board_asset_path_right = asset_path + "images/ChArUco_Marker_right.png"

logger = logging.getLogger("ALIGN")
logger.setLevel(logging.INFO)

# Initialize parameters
left_center = (0, 0)   # replaced by align_init with the left optic centre (x, y)
right_center = (0, 0)  # replaced by align_init with the right optic centre (x, y)
DOWNWARD_CANT = -5       # added to each eye's X rotation (degrees) when writing the config
HORIZONTAL_CANT = 6.17   # Y rotation (degrees); +ve for eye 0, -ve for eye 1
ANGLE_TEST_LIMIT = 7     # horizon sweep runs from +limit down towards -limit
ANGLE_TEST_INCREMENT = 0.1  # step between tested angles in the horizon sweep
is_first_run = True

def is_pipe_connected(pipe_handle):
    """Return True if ``win32pipe.GetNamedPipeInfo`` succeeds on ``pipe_handle``.

    Any exception raised by the query is logged as an error and reported as
    False rather than propagated.

    NOTE(review): GetNamedPipeInfo queries the handle itself and may succeed on
    a valid server handle even when no client is connected — confirm whether
    this is an adequate connectivity check.
    """
    try:
        win32pipe.GetNamedPipeInfo(pipe_handle)
    except Exception as e:
        logger.error(f"Pipe not connected: {str(e)}")
        return False
    return True

def wait_for_unity():
    """Block until the pipe peer writes a message, then discard it.

    Reads up to 256 bytes from the module-level ``pipe``.

    Raises:
        RuntimeError: if ``is_pipe_connected`` reports the pipe unusable.
    """
    if is_pipe_connected(pipe):
        # ReadFile blocks until a message arrives; its payload is not needed.
        win32file.ReadFile(pipe, 256)
        return
    raise RuntimeError("Pipe is not connected")


def align_init(serial, ipd, left_optic_center, right_optic_center):
    """Prepare camera, headset config and display subprocess for an alignment run.

    Stores the optic centres as integer pixel coordinates in the module
    globals ``left_center`` / ``right_center``. Applies ``ipd`` to the
    lighthouse config and uploads it. On later runs, or after a failed
    connection attempt, it recreates the camera capture, named pipe and
    lighthouse console. It then resets the eye transforms (which restarts
    SteamVR), launches the display subprocess and blocks until it connects
    to the pipe.

    Args:
        serial: Headset serial number; currently unused.
        ipd: Value passed to ``steamvr.steamvr.set_ipd_default_mm`` (millimetres,
            judging by that function's name).
        left_optic_center: Array-like holding the left eye's (x, y) pixel
            position; extra singleton dimensions (e.g. shape (1, 2)) are allowed.
        right_optic_center: Same as ``left_optic_center`` for the right eye.

    Raises:
        RuntimeError: if the display subprocess fails to connect to the pipe.
    """
    global is_first_run, pipe, lh, config, video, left_center, right_center

    # Flatten e.g. a (1, 2) detection array into an integer (x, y) pair.
    left_center = np.asarray(left_optic_center).squeeze().astype(int)
    right_center = np.asarray(right_optic_center).squeeze().astype(int)

    cv2.namedWindow("Bigscreen VAP Tool")

    steamvr.steamvr.set_ipd_default_mm(config, ipd)
    lh.upload_config(config)

    # After a previous run, or after a failed connection attempt below (which
    # releases the camera and drops the pipe), build fresh resources.
    if not is_first_run or pipe is None:
        video.release()
        camera_index = find_camera_index()
        video = cv2.VideoCapture(camera_index)
        pipe = win32pipe.CreateNamedPipe(r'\\.\pipe\vap_pipe',
                                         win32pipe.PIPE_ACCESS_DUPLEX,
                                         win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_WAIT,
                                         1, 65536, 65536, 300, None)
        lh = steamvr.LighthouseConsole()
        lh.open()
        config = lh.download_config()

    align_reset_config_values()
    ret, img = video.read()
    if ret:
        cv2.imshow('Bigscreen VAP Tool', img)
    else:
        logger.warning("Could not read an initial frame from the camera")

    Subprocess.subprocess = subprocess.Popen(subprocess_path, stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=subprocess_dir_path)
    # unregister() removes every earlier registration, so repeated runs leave
    # exactly one exit hook instead of stacking duplicates.
    atexit.unregister(align_destroy)
    atexit.register(align_destroy)
    try:
        logger.info("Attempting to connect to named pipe...")
        win32pipe.ConnectNamedPipe(pipe, None)
        logger.info("Successfully connected to named pipe")
    except Exception as e:
        logger.error(f"Failed to connect to named pipe: {str(e)}")
        # Clean up resources
        if Subprocess.subprocess is not None:
            Subprocess.subprocess.kill()
            Subprocess.subprocess = None
        win32file.CloseHandle(pipe)
        pipe = None  # closed; forces recreation on the next call
        video.release()
        cv2.destroyAllWindows()
        raise RuntimeError(f"Failed to establish pipe connection: {str(e)}") from e
    is_first_run = False


def _measure_eye_offset(gray_img, eye_x, mid_y, line_length):
    """Locate the strongest step in brightness along a vertical line through ``eye_x``.

    Samples ``gray_img`` along column ``eye_x`` from row ``int(mid_y - line_length)``
    to row ``int(mid_y + line_length)``, removes the mean, and convolves the
    profile with a +1/-1 step kernel as long as the profile on each side.
    Returns ``argmax(response) - line_length``, which the caller treats as the
    vertical pixel offset of the horizon.
    """
    top = (int(mid_y + line_length), eye_x)  # (row, col) as expected by skimage
    bottom = (int(mid_y - line_length), eye_x)
    rows, cols = line(*bottom, *top)
    intensities = gray_img[rows, cols].astype('float64')
    intensities -= np.average(intensities)  # Normalize
    n = len(intensities)
    step_filter = np.hstack((np.ones(n), -np.ones(n)))
    response = np.convolve(intensities, step_filter, mode='valid')
    return np.argmax(response) - line_length


def _find_zero_crossing(prev_angle, angle, prev_offset, offset, max_jump=20):
    """Return the interpolated angle where the offset crosses zero, or None.

    A crossing is accepted when two consecutive offsets have opposite signs
    (or either is zero) and differ by less than ``max_jump`` pixels; larger
    jumps are rejected (presumably detection glitches).
    """
    if prev_offset * offset > 0 or abs(offset - prev_offset) >= max_jump:
        return None
    if offset == prev_offset:
        # Both samples are exactly zero: interpolation would be 0/0 (NaN),
        # so report the first zero sample instead.
        return prev_angle
    return prev_angle + (0 - prev_offset) * (angle - prev_angle) / (offset - prev_offset)


def _average_slope(angles, offsets, center_index, margin):
    """Return the mean d(offset)/d(angle) over consecutive pairs near ``center_index``.

    Uses the pairs ending at indices ``center_index - margin`` up to
    ``center_index + margin - 1``, clamped to the available samples.
    Returns None when no pair is available.
    """
    start = max(center_index - margin, 1)
    end = min(center_index + margin, len(offsets))
    slopes = [
        (offsets[k] - offsets[k - 1]) / (angles[k] - angles[k - 1])
        for k in range(start, end)
    ]
    return sum(slopes) / len(slopes) if slopes else None


def _find_horizon(eye_label, eye_center):
    """Sweep the test angle and return the angle where this eye's horizon offset is zero.

    For each angle from ``ANGLE_TEST_LIMIT`` down towards ``-ANGLE_TEST_LIMIT``
    in ``ANGLE_TEST_INCREMENT`` steps, the loop does the following:

    - sends the angle over ``pipe`` as a packed float and waits for the peer's reply;
    - grabs a camera frame and measures the horizon offset in column ``eye_center[0]``;
    - stops ``post_crossing_margin`` samples after the first zero crossing.

    The samples are logged and a scatter plot is saved to the working directory.

    Args:
        eye_label: "Left" or "Right"; used in console/log messages, the plot
            legend and the plot file name.
        eye_center: (x, y) pixel position of this eye's optic centre.

    Returns:
        The interpolated zero-crossing angle, the last tested angle if the user
        pressed 's', or -100 if no crossing was found.

    Raises:
        RuntimeError: if the pipe is not connected or the camera returns no frame.
    """
    pivot = 0.0
    line_length = 100
    post_crossing_margin = 5
    start_angle = pivot + ANGLE_TEST_LIMIT
    stop_angle = pivot - ANGLE_TEST_LIMIT
    # Both eyes sample around the same row: the mean y of the two optic centres.
    mid_y = (left_center[1] + right_center[1]) / 2

    angles, offsets = [], []
    zero_found = False
    zero_post_count = 0
    zero_crossing = -100
    zero_crossing_index = -100

    step = 0
    angle = start_angle
    while angle > stop_angle:
        if not is_pipe_connected(pipe):
            raise RuntimeError("Pipe is not connected")
        win32file.WriteFile(pipe, struct.pack("f", angle))
        wait_for_unity()

        # Grab the latest camera frame.
        ret, new_img = video.read()
        if not ret:
            raise RuntimeError("Failed to read frame from camera")
        new_gray = cv2.cvtColor(new_img, cv2.COLOR_BGR2GRAY)
        angle_str = "Angle: {:.2f}".format(angle)
        cv2.putText(new_img, angle_str, (0, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

        pixel_offset = _measure_eye_offset(new_gray, eye_center[0], mid_y, line_length)
        print(f"{eye_label} eye offset: ", pixel_offset)

        cv2.line(new_img, eye_center,
                 (eye_center[0], eye_center[1] + pixel_offset),
                 (0, 255, 0), 2)

        angles.append(angle)
        offsets.append(pixel_offset)

        if zero_found:
            zero_post_count += 1
        elif len(offsets) > 1:
            # Every earlier pair was already rejected, so only the newest pair
            # can contain the first crossing.
            crossing = _find_zero_crossing(angles[-2], angles[-1], offsets[-2], offsets[-1])
            if crossing is not None:
                zero_crossing = crossing
                zero_crossing_index = len(offsets) - 1
                zero_found = True
                print(f"{eye_label} zero crossing found at: ", zero_crossing)

        # Stop once we've collected post_crossing_margin samples past the crossing.
        if zero_found and zero_post_count >= post_crossing_margin:
            break

        # Derive each angle from the step count rather than subtracting
        # repeatedly, so floating-point error does not accumulate.
        step += 1
        angle = start_angle - step * ANGLE_TEST_INCREMENT

        key = cv2.waitKey(1) & 0xFF
        if key == ord('s'):
            return angles[-1]

        cv2.imshow("Bigscreen VAP Tool", new_img)

        time.sleep(0.25)

    if zero_found:
        avg_slope = _average_slope(angles, offsets, zero_crossing_index, post_crossing_margin)
        # avg_slope is offset pixels per unit of test angle, so its reciprocal
        # is angle per pixel; guard against a flat slope.
        degrees_per_pixel = 1 / avg_slope if avg_slope else float("inf")
        print("Degrees per pixel: ", degrees_per_pixel)
        print("Average slope: ", avg_slope)
        logger.info(f"Degrees per pixel: {degrees_per_pixel}")
        logger.info(f"Average slope: {avg_slope}")
    else:
        print(f"No {eye_label.lower()} zero crossing found")
        logger.warning(f"No {eye_label.lower()} zero crossing found")

    print(f"{eye_label} crossing: ", zero_crossing)
    logger.info(f"{eye_label} crossing: {zero_crossing}")

    formatted_angles = ', '.join(f'{num:.2f}' for num in angles)
    formatted_offsets = ', '.join(f'{num:.2f}' for num in offsets)
    logger.info(f"Angles: {formatted_angles}")
    logger.info(f"{eye_label} Offsets: {formatted_offsets}")

    # Plot original data points, highlighting the samples around the crossing.
    plt.scatter(angles, offsets, color='blue', label='Data Points')
    if zero_found:
        lo = max(zero_crossing_index - post_crossing_margin, 0)
        hi = zero_crossing_index + post_crossing_margin
        plt.scatter(angles[lo:hi], offsets[lo:hi], color='green', label=f'{eye_label} Crossing')

    plt.xlabel('Angles')
    plt.ylabel('Offsets')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{eye_label.lower()}_eye_offsets_{time.strftime('%Y%m%d_%H%M%S')}.png")
    plt.close()

    return zero_crossing


def align_find_horizon_left():
    """Run the horizon sweep for the left eye and return its zero-crossing angle.

    See ``_find_horizon`` for the return conventions (-100 when not found,
    last tested angle when the user presses 's').
    """
    return _find_horizon("Left", left_center)


def align_find_horizon_right():
    """Run the horizon sweep for the right eye and return its zero-crossing angle.

    See ``_find_horizon`` for the return conventions (-100 when not found,
    last tested angle when the user presses 's').
    """
    return _find_horizon("Right", right_center)


def _find_charuco_corner(board_image_path, corner_id, frames=30, max_read_failures=5):
    """Show a ChArUco board on the headset and track one of its corners in the camera.

    The board image is displayed through ``DirectModeDisplay`` while ``frames``
    camera frames are analysed. Each frame is run through a 28x14
    DICT_4X4_1000 ChArUco detector. The corner with id ``corner_id`` is marked
    in the preview window. If the camera stops delivering frames, the capture
    is re-initialised via ``find_camera_index``.

    Returns:
        The (k, 2) array of pixel coordinates for ``corner_id`` from the most
        recent frame in which it was detected. If it was never detected, the
        (empty) result of the last frame is returned.

    Raises:
        RuntimeError: if the camera fails to return a frame more than
            ``max_read_failures`` times in a row.
    """
    global video

    aruco_dict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_1000)
    board = cv2.aruco.CharucoBoard(
        [28, 14], 0.045454545454545, 0.022727272727272, aruco_dict
    )
    # The detector configuration never changes between frames, so build it once.
    detector = cv2.aruco.CharucoDetector(board, cv2.aruco.CharucoParameters())

    image = cv2.imread(board_image_path)
    dmd = DirectModeDisplay()

    coords = (0, 0)
    last_found = None
    frame_count = 0
    read_failures = 0
    try:
        while frame_count < frames:
            dmd.im_show(image)
            ret, img = video.read()

            if not ret:
                read_failures += 1
                if read_failures > max_read_failures:
                    raise RuntimeError("Failed to read from video capture after repeated re-initialization")
                # Attempt to re-init the video capture
                print("Failed to read from video capture, re-initializing...")
                video.release()
                video = cv2.VideoCapture(find_camera_index())
                continue
            read_failures = 0

            gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            charuco_corners, charuco_ids, _, _ = detector.detectBoard(image=gray_img)

            # The detector returns None when nothing is found; reshape (rather
            # than squeeze) keeps a single detection one-dimensional.
            corners = np.array([] if charuco_corners is None else charuco_corners).reshape(-1, 2)
            ids = np.array([] if charuco_ids is None else charuco_ids).reshape(-1)

            coords = corners[np.flatnonzero(ids == corner_id)]
            if len(coords) > 0:
                last_found = coords
                cv2.circle(gray_img, tuple(coords[0].astype(int)), 5, (255, 0, 0), -1)

            cv2.imshow('Bigscreen VAP Tool', gray_img)
            cv2.waitKey(1)  # let HighGUI process events so the preview repaints
            time.sleep(0.25)

            frame_count += 1
    finally:
        dmd.close()

    return last_found if last_found is not None else coords


def align_find_center_left():
    """Return the camera pixel position of ChArUco corner 169 on the left-eye board.

    See ``_find_charuco_corner`` for the return value and failure behaviour.
    """
    return _find_charuco_corner(charuco_board_asset_path_left, 169)


def align_find_center_right():
    """Return the camera pixel position of ChArUco corner 181 on the right-eye board.

    See ``_find_charuco_corner`` for the return value and failure behaviour.
    """
    return _find_charuco_corner(charuco_board_asset_path_right, 181)


def align_find_transition_point():
    """Find the first bright pixel in the camera image near column 320.

    Up to 40 attempts are made. Each attempt does the following:

    - sends ``DOWNWARD_CANT`` over the pipe and waits for the peer's reply;
    - grabs a frame and scans column 320 downward from row 240 - 125 using
      ``find_transition_point``;
    - shows the result in the preview window.

    Frames the camera fails to deliver are skipped, but they still use up an
    attempt.

    Returns:
        (x, y) of the transition point, or (0, 0) if none was found.

    Raises:
        RuntimeError: if the pipe is not connected.
    """
    transition_point = (0, 0)
    for _attempt in range(40):
        if not is_pipe_connected(pipe):
            raise RuntimeError("Pipe is not connected")
        win32file.WriteFile(pipe, struct.pack("f", DOWNWARD_CANT))
        wait_for_unity()

        ret, img = video.read()
        if not ret:
            logger.warning("Failed to read a camera frame; retrying")
            time.sleep(0.25)
            continue
        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        transition_point = find_transition_point((320, 240), gray_img)

        cv2.circle(gray_img, transition_point, 5, (255, 255, 255), -1)
        cv2.imshow('Bigscreen VAP Tool', gray_img)
        cv2.waitKey(1)  # let HighGUI process events so the preview repaints
        time.sleep(0.25)

        # Early exit if we've found the transition point
        if transition_point != (0, 0):
            break

    return transition_point


def find_transition_point(coord, img, scan_range=125, white_threshold=50):
    """Return the first pixel brighter than ``white_threshold`` in a vertical scan.

    Scans column ``x`` of the 2-D image ``img`` downward from row
    ``y - scan_range`` to row ``y + scan_range - 1``. The range is clamped to
    the image, so negative rows never wrap around to the bottom.

    Args:
        coord: (x, y) pixel coordinate around which to scan.
        img: Single-channel image indexed as ``img[row, col]``.
        scan_range: Number of rows to scan above (and below) ``y``.
        white_threshold: Intensity a pixel must exceed to count as bright.

    Returns:
        (x, row) of the first bright pixel, or (0, 0) if none is found or ``x``
        lies outside the image.
    """
    x, y = coord
    height, width = img.shape[:2]
    if not 0 <= x < width:
        return (0, 0)
    for row in range(max(y - scan_range, 0), min(y + scan_range, height)):
        if img[row, x] > white_threshold:
            return (x, row)
    return (0, 0)


def align_write_config_values(left_eye_pitch, right_eye_pitch):
    """Store the measured eye pitches in the config, upload it and restart SteamVR.

    Each eye's ``eye_to_head`` matrix becomes an intrinsic "XYZ" Euler rotation
    (degrees). It uses ``pitch + DOWNWARD_CANT`` about X and ``HORIZONTAL_CANT``
    about Y. The Y angle is positive for eye 0 (left) and negated for eye 1
    (right).
    """
    per_eye = (
        (0, left_eye_pitch, HORIZONTAL_CANT),
        (1, right_eye_pitch, -HORIZONTAL_CANT),
    )
    for eye_index, pitch, y_angle in per_eye:
        matrix = Rotation.from_euler(
            "XYZ", [pitch + DOWNWARD_CANT, y_angle, 0], degrees=True
        ).as_matrix()
        config['tracking_to_eye_transform'][eye_index]['eye_to_head'] = matrix.tolist()

    lh.upload_config(config)
    align_restart_steamvr()


def align_destroy():
    """Tear down the alignment session and stop SteamVR.

    The following steps run in order:

    - kill the display subprocess;
    - close the named pipe;
    - release the camera and close all OpenCV windows;
    - force-kill the SteamVR processes.

    The function is registered with ``atexit`` and may run more than once, so
    each handle is reset after release. A pipe that was already closed
    elsewhere is logged instead of aborting the rest of the teardown.
    """
    global pipe

    if Subprocess.subprocess is not None:
        Subprocess.subprocess.kill()
        Subprocess.subprocess = None

    # Close the pipe even when no subprocess is running, so it is never leaked.
    if pipe is not None:
        try:
            win32file.CloseHandle(pipe)
        except Exception as e:
            logger.warning(f"Failed to close named pipe: {e}")
        pipe = None

    video.release()
    cv2.destroyAllWindows()
    subprocess.call(['taskkill', '/IM', 'vrmonitor.exe', '/F'])
    subprocess.call(['taskkill', '/IM', 'vrserver.exe', '/F'])
    subprocess.call(['taskkill', '/IM', 'vrwebhelper.exe', '/F'])
    time.sleep(2)


def align_reset_config_values():
    """Reset both eye transforms to zero X rotation, keeping only the horizontal cant.

    Writes an intrinsic "XYZ" rotation of (0, ±HORIZONTAL_CANT, 0) degrees into
    each eye's ``eye_to_head`` entry (+ for eye 0, - for eye 1). It then uploads
    the config through ``lh`` and restarts SteamVR.
    """
    y_angles = {0: HORIZONTAL_CANT, 1: -HORIZONTAL_CANT}
    transforms = config['tracking_to_eye_transform']
    for eye_index, y_angle in y_angles.items():
        transforms[eye_index]['eye_to_head'] = Rotation.from_euler(
            "XYZ", [0, y_angle, 0], degrees=True
        ).as_matrix().tolist()

    lh.upload_config(config)
    align_restart_steamvr()


def align_restart_steamvr():
    """Force-restart SteamVR, relaunching the display subprocess if it was running.

    If a display subprocess exists, it is killed and the pipe is closed. The
    SteamVR processes are then killed and ``vrmonitor.exe`` is started again
    from the default Steam install path. Finally, if the subprocess had been
    running, a fresh pipe is created, the subprocess is relaunched, and the
    call blocks until it connects.

    Raises:
        RuntimeError: if the relaunched subprocess fails to connect to the pipe
            (camera and windows are released first).
    """
    global pipe

    is_unity_open = Subprocess.subprocess is not None
    if is_unity_open:
        Subprocess.subprocess.kill()
        Subprocess.subprocess = None
        if pipe is not None:
            win32file.CloseHandle(pipe)
            pipe = None  # never leave a closed handle in the global

    subprocess.call(['taskkill', '/IM', 'vrmonitor.exe', '/F'])
    subprocess.call(['taskkill', '/IM', 'vrserver.exe', '/F'])
    subprocess.call(['taskkill', '/IM', 'vrwebhelper.exe', '/F'])
    time.sleep(2)

    subprocess.Popen(["C:\\Program Files (x86)\\Steam\\steamapps\\common\\SteamVR\\bin\\win64\\vrmonitor.exe"])
    time.sleep(2)

    if is_unity_open:
        pipe = win32pipe.CreateNamedPipe(r'\\.\pipe\vap_pipe',
                                         win32pipe.PIPE_ACCESS_DUPLEX,
                                         win32pipe.PIPE_TYPE_MESSAGE | win32pipe.PIPE_WAIT,
                                         1, 65536, 65536, 300, None)

        Subprocess.subprocess = subprocess.Popen(subprocess_path, stdin=subprocess.PIPE, stdout=subprocess.PIPE, cwd=subprocess_dir_path)
        # unregister() drops all earlier registrations, so each restart does not
        # add another copy of the exit hook.
        atexit.unregister(align_destroy)
        atexit.register(align_destroy)
        try:
            logger.info("Attempting to connect to named pipe...")
            win32pipe.ConnectNamedPipe(pipe, None)
            if not is_pipe_connected(pipe):
                raise RuntimeError("Pipe connection failed")
            logger.info("Successfully connected to named pipe")
        except Exception as e:
            logger.error(f"Failed to connect to named pipe: {str(e)}")
            # Clean up resources
            if Subprocess.subprocess is not None:
                Subprocess.subprocess.kill()
                Subprocess.subprocess = None
            if pipe is not None:
                win32file.CloseHandle(pipe)
                pipe = None
            video.release()
            cv2.destroyAllWindows()
            raise RuntimeError(f"Failed to establish pipe connection: {str(e)}") from e


def confirm_distortion():
    """Check that the config carries a ``distortion_red`` entry for eyes 0 and 1.

    The two values are printed and logged when present.

    Returns:
        True if both ``tracking_to_eye_transform[0]`` and ``[1]`` contain
        ``distortion_red``. False if the config is missing (e.g. the download
        failed), has fewer than two eye transforms, or lacks either entry.
    """
    print("Confirming distortion values...")
    logger.info("Confirming distortion values...")
    try:
        transforms = config['tracking_to_eye_transform']
        distortion_values = [
            transforms[0]['distortion_red'],
            transforms[1]['distortion_red'],
        ]
    except (KeyError, IndexError, TypeError):
        # TypeError covers config being None after a failed download.
        print("Failed to detect calibration, run calibration upload again")
        return False

    print(f"distortion_red values: {distortion_values}")
    logger.info(f"distortion_red values: {distortion_values}")
    return True


def confirm_tracking_calibration() -> bool:
    """Report whether a lighthouse config is loaded (i.e. ``config`` is not None)."""
    return config is not None