import os
import time
import hid
import glfw
from OpenGL.GL import *
import imgui
import logging
from imgui.integrations.glfw import GlfwRenderer
from image import load_image
from image import load_image_rgba
from image import render_image
from motor_controller import MotorController
from proximity import reset_proximity
from proximity import proximity_read_proximity
from proximity import proximity_write_to_headset
from proximity import proximity_threshold_write_to_headset
from alignment import align_init
from alignment import align_destroy
from alignment import align_find_horizon_left
from alignment import align_find_horizon_right
from alignment import align_find_center_left
from alignment import align_find_center_right
from alignment import align_write_config_values
from alignment import confirm_distortion
from alignment import confirm_tracking_calibration
from alignment import align_find_transition_point
from hmd_config import hmd_config_write_serial
from mes_logger import mes_add
from extended_mode_display import enable_direct_mode, prepare_headset_for_measurement

# Eye indices.
LEFT_EYE, RIGHT_EYE = 0, 1

# Fixed pixel size of the (non-resizable) tool window.
window_width, window_height = 1000, 800
# Module-wide HID handle for the headset; the state machine opens/reopens it.
hid_device = hid.device()

# Confirmation screens auto-advance once idle_time exceeds max_idle_time.
idle_time = 0.0
max_idle_time = 3.0  # Seconds

serial_number = ""
# When True, CONNECT additionally requires confirm_distortion() to succeed.
should_check_red_distortion = False

# MES reporting: tool software version plus unit-under-test start/stop timestamps.
MES_SW_VER = "0.0.1"
uut_start = uut_stop = 0


class STATES:
    """Integer codes for the steps of the VAP state machine run by the main loop.

    PROXIMITY_NO_PAD and PROXIMITY_PAD are not assigned anywhere in this file
    (as far as this file shows); the other codes are.
    """

    (
        PROMPT,
        CONNECT,
        PROXIMITY_RESET,
        PROXIMITY_NO_PAD,
        PROXIMITY_NO_PAD_CALC,
        PROXIMITY_RESET_BEFORE_PAD,
        PROXIMITY_PAD,
        PROXIMITY_PAD_CALC,
        PROXIMITY_REJECTED,
        PROXIMITY_WRITE_THRESHOLD,
        ALIGN_INIT,
        ALIGN_ACTIVE,
        ALIGN_SHOW_RESULTS,
        CALIBRATION_REJECTED,
    ) = range(14)


def centered(width):
    """Move the ImGui cursor so the next item, `width` pixels wide, is horizontally centered."""
    left_margin = (window_width - width) / 2
    imgui.set_cursor_pos_x(left_margin)


def text_centered(text):
    """Draw `text` as one ImGui text item, horizontally centered in the window."""
    measured_width = imgui.calc_text_size(text).x
    imgui.set_cursor_pos_x((window_width - measured_width) / 2)
    imgui.text(text)


def render_ui():
    """Present the ImGui frame being built and immediately start the next one.

    States that block inside a loop (progress bars) call this once per
    iteration so each step is actually drawn. On return a fresh frame with a
    cleared white background and the full-window "Bigscreen VAP Tool" panel is
    open, matching what the main loop sets up at the top of every iteration.
    """
    # Close out and draw the current frame.
    imgui.end()
    imgui.end_frame()
    imgui_renderer.process_inputs()
    imgui.render()
    imgui_renderer.render(imgui.get_draw_data())
    glfw.poll_events()
    glfw.swap_buffers(glfw_window)

    # Open the next frame.
    glClearColor(1.0, 1.0, 1.0, 1.0)
    glClear(GL_COLOR_BUFFER_BIT)
    imgui.new_frame()
    io = imgui.get_io()
    imgui.set_next_window_size(io.display_size.x, io.display_size.y)
    imgui.set_next_window_position(0, 0)
    panel_flags = imgui.WINDOW_NO_COLLAPSE | imgui.WINDOW_NO_RESIZE
    imgui.begin("Bigscreen VAP Tool", True, panel_flags)


def restart_vap():
    """Reset per-unit state so the tool can process the next headset.

    Zeroes the proximity accumulators and the idle timer, clears the stored
    error/result messages and cached raw config, then closes the module-level
    HID handle.
    """
    global calibration, threshold, raw_config, idle_time, last_error, alignment_result_text
    print("Restarting VAP...")
    calibration = threshold = idle_time = 0.0
    last_error = alignment_result_text = ""
    raw_config = []
    hid_device.close()


# Setup logging: one named logger at INFO; the main loop attaches a per-unit
# file handler when a run starts.
logger = logging.getLogger("VAP")
logger.setLevel("INFO")

# Setup glfw window and IMGUI.
#
# Fatal startup errors raise SystemExit(1) rather than calling `exit()`:
# `exit` is a convenience the `site` module adds for the interactive shell, it
# is not guaranteed to exist in every runtime, and `exit()` reports status 0
# even though startup failed.
if not glfw.init():
    print("Failed to start glfw.")
    raise SystemExit(1)
glfw.window_hint(glfw.RESIZABLE, 0)
glfw_window = glfw.create_window(window_width, window_height,
                                 "Bigscreen VAP Tool", None, None)
if not glfw_window:
    print("Failed to create glfw window.")
    glfw.terminate()  # undo the successful glfw.init() above
    raise SystemExit(1)
glfw.make_context_current(glfw_window)
glViewport(0, 0, window_width, window_height)
glEnable(GL_BLEND)
glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)
# Pixel-space projection with the origin at the top-left (y grows downward);
# presumably the space render_image() draws in — confirm in image.py.
glMatrixMode(GL_PROJECTION)
glLoadIdentity()
glOrtho(0.0, float(window_width), float(window_height), 0.0, 0.0, 1.0)
imgui.create_context()
imgui.get_io().display_size = [window_width, window_height]

# Load assets.
# Paths are relative to the working directory; a PyInstaller build keeps the
# bundled assets under "_internal/". asset_path always ends with "/".
asset_path = "assets/"
if os.path.isdir("_internal/"):
    asset_path = "_internal/" + asset_path

# 48 px font including the Chinese glyph range used by the bilingual prompts.
imgui.get_io().fonts.add_font_from_file_ttf(asset_path + "fonts/arialuni.ttf", 48.0, None,
                                            imgui.get_io().fonts.get_glyph_ranges_chinese())
imgui.get_io().fonts.get_tex_data_as_rgba32()
imgui_renderer = GlfwRenderer(glfw_window)

# Load images. asset_path already ends in "/", so no extra separator here
# (previously this produced "assets//images/...").
connect_image = load_image(asset_path + "images/connect.png")
good_image = load_image_rgba(asset_path + "images/good.png")
bad_image = load_image_rgba(asset_path + "images/bad.png")

# Main loop.
progress = 0.0
state = STATES.PROMPT
raw_config = []
calibration = 0.0
threshold = 0.0
last_error = ""
alignment_result_text = ""
last_key_states = [glfw.RELEASE] * 350
rig_controller = MotorController()

should_run_alignment = [True]
should_run_proximity = [True]

keyboard_focus_set = False  # Prevent repetitive focus calls in PROMPT state

mes_add(serial_number, "PASS", uut_start, uut_stop, MES_SW_VER)

while not glfw.window_should_close(glfw_window):
    glClearColor(1.0, 1.0, 1.0, 1.0)
    glClear(GL_COLOR_BUFFER_BIT)
    imgui.new_frame()
    imgui.set_next_window_size(imgui.get_io().display_size.x,
                               imgui.get_io().display_size.y)
    imgui.set_next_window_position(0, 0)
    imgui.begin("Bigscreen VAP Tool", True, imgui.WINDOW_NO_COLLAPSE | imgui.WINDOW_NO_RESIZE)

    # Process input
    enter_key_just_pressed = False

    enter_key_state = glfw.get_key(glfw_window, glfw.KEY_ENTER)
    kp_enter_key_state = glfw.get_key(glfw_window, glfw.KEY_KP_ENTER)

    if (enter_key_state == glfw.PRESS and last_key_states[glfw.KEY_ENTER] == glfw.RELEASE) or \
            (kp_enter_key_state == glfw.PRESS and last_key_states[glfw.KEY_KP_ENTER] == glfw.RELEASE):
        enter_key_just_pressed = True

    last_key_states[glfw.KEY_ENTER] = enter_key_state
    last_key_states[glfw.KEY_KP_ENTER] = kp_enter_key_state

    # UI to display.
    match state:
        case STATES.PROMPT:
            prox_checked, should_run_proximity[0] = imgui.checkbox("Run Proximity Calibration", should_run_proximity[0])
            va_checked, should_run_alignment[0] = imgui.checkbox("Run Alignment", should_run_alignment[0])
            imgui.dummy(0.0, 150.0)
            prompt_text = "Enter Serial #: "
            text_centered(prompt_text)
            imgui.push_item_width(400)
            centered(400)
            input_field_cleared_this_frame = False
            if serial_number != "" and serial_number[-1] == 'c':
                serial_number = ""
                input_field_cleared_this_frame = True
            else:
                _, serial_number = imgui.input_text('##sn', serial_number, 256, imgui.INPUT_TEXT_ENTER_RETURNS_TRUE)
            imgui.pop_item_width()
            if keyboard_focus_set is False:  # Set focus to text input on the first frame
                imgui.set_keyboard_focus_here(0)
                keyboard_focus_set = True
            if va_checked or prox_checked:  # If checkboxes were checked, bring focus back to text input
                imgui.set_keyboard_focus_here(0)
            if input_field_cleared_this_frame:
                keyboard_focus_set = False
            if enter_key_just_pressed:
                state = STATES.CONNECT
                rig_controller.go_home()
                rig_controller.go_right_eye()
                logger.handlers = []
                file_handler = logging.FileHandler(time.strftime(f"VAP_%Y_%m_%d") + f"_%s.log" % serial_number)
                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
                logger.info("VAP process started. Serial number: %s" % serial_number)
        case STATES.CONNECT:
            try:
                hid_device.open(0x35bd, 257)
            except OSError:
                imgui.dummy(0.0, 20.0)
                text_centered("Connect linkbox to headset...")
                imgui.dummy(0.0, 20.0)
                text_centered("将串流盒连接到耳机")
                imgui.dummy(0.0, 20.0)
            else:
                hid_device.set_nonblocking(1)
                raw_config = hmd_config_write_serial(hid_device, serial_number)
                if confirm_distortion() is False and should_check_red_distortion is True:
                    state = STATES.CALIBRATION_REJECTED
                    last_error = "Failed to detect optical calibration, run calibration upload again"
                elif confirm_tracking_calibration() is False:
                    state = STATES.CALIBRATION_REJECTED
                    last_error = "Failed to detect tracking calibration, run calibration upload again"
                elif should_run_proximity[0] is False:
                    state = STATES.ALIGN_INIT
                else:
                    state = STATES.PROXIMITY_RESET
        case STATES.CALIBRATION_REJECTED:
            err_str = last_error
            text_centered(err_str)
            logger.error(err_str)
            imgui.dummy(0.0, 60.0)
            text_centered("无法检测到校准，请重新运行校准上传")

            # If enter is pressed after rejection, restart VAP
            if enter_key_just_pressed:
                logger.info("Failed red distortion detection.")
                break
        case STATES.PROXIMITY_RESET:
            uut_start = time.time()
            raw_config = reset_proximity(hid_device)
            logger.info("Proximity values reset.")
            progress = 0.0
            # Reboot.
            hid_device.send_feature_report([0, 0x43])
            logger.info("Rebooting headset.")
            for p in range(10):
                imgui.dummy(0.0, 300.0)
                text_centered("Proximity calibration has begun, please wait...")
                imgui.dummy(0.0, 40.0)
                text_centered("接近校准已开始")
                text_centered("请稍等")
                imgui.dummy(0.0, 20.0)
                imgui.dummy(0.0, 40.0)
                centered(300)
                imgui.progress_bar(progress, (300, 20))
                progress += 0.1
                render_ui()
            hid_device.close()
            hid_device = hid.device()
            device_open = False
            while device_open is False:
                try:
                    hid_device.open(0x35bd, 257)
                except OSError:
                    time.sleep(0.1)
                else:
                    device_open = True
                    hid_device.set_nonblocking(1)
                    state = STATES.PROXIMITY_NO_PAD_CALC
        case STATES.PROXIMITY_NO_PAD_CALC:
            # Flush out built up states.
            while True:
                data = hid_device.read(64)
                if not data:
                    break
            progress = 0
            min_value = float('inf')  # Tracks minimum read prox value for range calc
            max_value = float('-inf')  # Tracks maximum read prox value for range calc
            for p in range(10):
                imgui.dummy(0.0, 300.0)
                text_centered("Calculating...")
                text_centered("计算...")
                imgui.dummy(0.0, 40.0)
                read_proximity = proximity_read_proximity(hid_device)
                calibration += read_proximity
                print(read_proximity, flush=True)
                min_value = min(min_value, read_proximity)
                max_value = max(max_value, read_proximity)
                progress += 0.1
                centered(300)
                imgui.progress_bar(progress, (300, 20))
                render_ui()
            calibration /= 10.0
            print("Calibration: ", calibration, flush=True)
            logger.info("Calibration value: %f" % calibration)
            integer_calibration = int(calibration)

            if max_value - min_value > 150:
                state = STATES.PROXIMITY_REJECTED
                logger.error("Calibration value rejected. Range of read values exceeds 150.")
            else:
                raw_config = proximity_write_to_headset(hid_device, integer_calibration, raw_config)
                state = STATES.PROXIMITY_RESET_BEFORE_PAD
                logger.info("Calibration value written to headset: %f" % calibration)
        case STATES.PROXIMITY_RESET_BEFORE_PAD:
            progress = 0
            # Reboot.
            hid_device.send_feature_report([0, 0x43])
            logger.info("Rebooting headset.")
            for p in range(20):
                imgui.dummy(0.0, 300.0)
                text_centered("Calibrating, please wait...")
                imgui.dummy(0.0, 40.0)
                text_centered("正在校准，请稍候...")
                text_centered("请稍等")
                imgui.dummy(0.0, 60.0)
                centered(300)
                imgui.progress_bar(progress, (300, 20))
                progress += 0.05
                render_ui()
                time.sleep(0.1)
            hid_device.close()
            hid_device = hid.device()
            device_open = False
            while not device_open:
                try:
                    hid_device.open(0x35bd, 257)
                except OSError:
                    time.sleep(0.1)
                else:
                    device_open = True
                    hid_device.set_nonblocking(1)
            rig_controller.go_left_eye()
            state = STATES.PROXIMITY_PAD_CALC
        case STATES.PROXIMITY_PAD_CALC:
            # Flush out built up states.
            while True:
                data = hid_device.read(64)
                if not data:
                    break
                time.sleep(0.1)
            calibrated_proximity = 0.0
            progress = 0
            min_value = float('inf')  # Tracks minimum read prox value for range calc
            max_value = float('-inf')  # Tracks maximum read prox value for range calc
            for p in range(10):
                imgui.dummy(0.0, 300.0)
                text_centered("Calculating...")
                text_centered("计算")
                imgui.dummy(0.0, 40.0)
                read_proximity = proximity_read_proximity(hid_device)
                threshold += read_proximity
                print(read_proximity, flush=True)
                min_value = min(min_value, read_proximity)
                max_value = max(max_value, read_proximity)
                progress += 0.1
                centered(300)
                imgui.progress_bar(progress, (300, 20))
                render_ui()
            threshold /= 10.0
            threshold_offset = 250 # Offset to ensure the threshold is high enough
            threshold += threshold_offset
            print(threshold, flush=True)
            logger.info("Threshold value: %f" % threshold)

            # Determine if the proximity threshold is acceptable
            if threshold < 500.0:
                state = STATES.PROXIMITY_REJECTED
                logger.error("Threshold value rejected. Value is below minimum acceptable threshold.")
            elif max_value - min_value > 150:
                logger.error("Threshold value rejected. Range of read values exceeds 150.")
            else:
                integer_threshold = int(threshold)
                raw_config = proximity_threshold_write_to_headset(hid_device, integer_threshold,
                                                                  raw_config)
                state = STATES.PROXIMITY_WRITE_THRESHOLD
                logger.info("Threshold value written to headset: %f" % threshold)
        case STATES.PROXIMITY_REJECTED:
            calibration_str = "Calibration: %i" % int(calibration)
            threshold_str = "Threshold: %i" % int(threshold)
            imgui.dummy(0.0, 20.0)
            text_centered(calibration_str)
            imgui.dummy(0.0, 20.0)
            text_centered(threshold_str)
            imgui.dummy(0.0, 60.0)
            text_centered("HMD did not pass the proximity calibration")
            imgui.dummy(0.0, 40.0)
            text_centered("没有通过接近校准")

            # If enter is pressed after rejection, restart VAP
            if enter_key_just_pressed:
                uut_stop = time.time()
                formatted_string = mes_add(serial_number, "FAIL", uut_start, uut_stop, MES_SW_VER)
                logger.info("Writing proximity result to MES...")
                logger.info(formatted_string)
                logger.info("Failed proximity calibration.")
                break
        case STATES.PROXIMITY_WRITE_THRESHOLD:
            uut_stop = time.time()
            calibration_str = "Calibration: %i" % int(calibration)
            threshold_str = "Threshold: %i" % int(threshold)
            imgui.dummy(0.0, 10.0)
            text_centered(calibration_str)
            imgui.dummy(0.0, 10.0)
            text_centered(threshold_str)
            imgui.dummy(0.0, 10.0)
            text_centered("Successfully calibrated proximity")
            imgui.dummy(0.0, 10.0)
            centered(200)

            # Increment idle_time by the time since the last frame
            idle_time += imgui.get_io().delta_time

            if imgui.button("  OK 好的  ") or enter_key_just_pressed or idle_time > max_idle_time:
                idle_time = 0.0
                formatted_string = mes_add(serial_number, "PASS", uut_start, uut_stop, MES_SW_VER)
                logger.info("Writing proximity result to MES...")
                logger.info(formatted_string)

                if should_run_alignment[0]:
                    state = STATES.ALIGN_INIT
                else:
                    restart_vap()
                    keyboard_focus_set = False
                    state = STATES.PROMPT
                    logger.info("VAP restarted after proximity calibration.")
            imgui.dummy(0.0, 40.0)
            text_centered("成功校准接近度")
        case STATES.ALIGN_INIT:
            imgui.dummy(0.0, 300.0)
            text_centered("Starting alignment...")
            text_centered("开始对齐...")
            text_centered("请稍等")
            render_ui()

            # Make sure the HID device is actively connected
            while True:
                try:
                    hid_device.open(0x35bd, 257)
                    break
                except:
                    hid_device.close()
                    time.sleep(0.1)

            hid_device.send_feature_report(bytes([0, 0x70]))  # Temporarily disable proximity sensor

            # Determine optical center of each eye
            rig_controller.go_left_eye()
            left_coords = align_find_center_left()
            rig_controller.go_right_eye()
            right_coords = align_find_center_right()

            # If either center is not found, report fail
            if (left_coords == (0, 0)).all() or (right_coords == (0, 0)).all():
                alignment_result_text = "Alignment failed. Couldn't find optical center. Please try again."
                uut_stop = time.time()
                formatted_string = mes_add(serial_number, "FAIL", uut_start, uut_stop, MES_SW_VER)
                logger.info("Writing MES...")
                logger.info(formatted_string)

            logger.info("Starting alignment phase of VAP.")
            align_init(serial_number, 64.0, left_coords, right_coords)
            hid_device.send_feature_report(bytes([0, 0x70]))  # Temporarily disable proximity sensor
            state = STATES.ALIGN_ACTIVE
        case STATES.ALIGN_ACTIVE:
            imgui.dummy(0.0, 300.0)
            text_centered("Finding best pitch...")
            text_centered("寻找最佳音高...")
            text_centered("请稍等")
            render_ui()

            rig_controller.go_left_eye()
            # Make sure the HID device is actively connected
            while True:
                try:
                    hid_device.open(0x35bd, 257)
                    break
                except:
                    hid_device.close()
                    time.sleep(0.1)
            hid_device.send_feature_report(bytes([0, 0x70]))  # Temporarily disable proximity sensor

            left = align_find_horizon_left()
            rig_controller.go_right_eye()
            right = align_find_horizon_right()
            print("Horizon lines: ", left, right)
            align_write_config_values(-left, -right) # Restarts SteamVR

            # Make sure the HID device is actively connected
            while True:
                try:
                    hid_device.open(0x35bd, 257)
                    break
                except:
                    hid_device.close()
                    time.sleep(0.1)
            hid_device.send_feature_report(bytes([0, 0x70]))  # Temporarily disable proximity sensor

            # Find the new horizon lines of both eyes (transition points) to confirm alignment
            tp = align_find_transition_point()
            rig_controller.go_left_eye()
            tp2 = align_find_transition_point()
            print("Transition points: ", tp, tp2)

            # Make sure the transition points are relatively close to each other; fail if they are not
            if abs(tp[1] - tp2[1]) > 10 or tp == (0, 0) or tp2 == (0, 0):
                # report fail
                alignment_result_text = "Alignment failed. Please try again."
                uut_stop = time.time()
                formatted_string = mes_add(serial_number, "FAIL", uut_start, uut_stop, MES_SW_VER)
                logger.info("Writing MES...")
                logger.info(formatted_string)
                mes_string_sent = True
            else:
                # report pass
                alignment_result_text = "Alignment completed successfully."
                uut_stop = time.time()
                formatted_string = mes_add(serial_number, "PASS", uut_start, uut_stop, MES_SW_VER)
                logger.info("Writing MES...")
                logger.info(formatted_string)
                mes_string_sent = True

            align_destroy()
            state = STATES.ALIGN_SHOW_RESULTS
        case STATES.ALIGN_SHOW_RESULTS:
            imgui.dummy(0.0, 300.0)
            text_centered(alignment_result_text)
            if enter_key_just_pressed:
                restart_vap()
                keyboard_focus_set = False
                state = STATES.PROMPT
                logger.info("VAP restarted after alignment.")

    imgui.end()
    imgui.end_frame()
    imgui_renderer.process_inputs()
    imgui.render()
    imgui_renderer.render(imgui.get_draw_data())

    # Image to display.
    match state:
        case STATES.CONNECT:
            render_image(connect_image.tex_id,
                         window_width / 2 - connect_image.width / 2,
                         window_height / 2 - connect_image.height / 2 + 150,
                         connect_image.width, connect_image.height)
        case STATES.PROXIMITY_REJECTED:
            render_image(bad_image.tex_id,
                         window_width / 2 - bad_image.width / 2,
                         window_height / 2 - bad_image.height / 2 + 150,
                         bad_image.width, bad_image.height)
        case STATES.CALIBRATION_REJECTED:
            render_image(bad_image.tex_id,
                         window_width / 2 - bad_image.width / 2,
                         window_height / 2 - bad_image.height / 2 + 150,
                         bad_image.width, bad_image.height)
        case STATES.PROXIMITY_WRITE_THRESHOLD:
            render_image(good_image.tex_id,
                         window_width / 2 - good_image.width / 2,
                         window_height / 2 - good_image.height / 2 + 150,
                         good_image.width, good_image.height)
        case STATES.ALIGN_SHOW_RESULTS:
            if alignment_result_text == "Alignment failed. Please try again.":
                render_image(bad_image.tex_id,
                             window_width / 2 - bad_image.width / 2,
                             window_height / 2 - bad_image.height / 2 + 150,
                             bad_image.width, bad_image.height)
            else:
                render_image(good_image.tex_id,
                             window_width / 2 - good_image.width / 2,
                             window_height / 2 - good_image.height / 2 + 150,
                             good_image.width, good_image.height)

    glfw.poll_events()
    if glfw.get_key(glfw_window, glfw.KEY_ESCAPE):
        break
    glfw.swap_buffers(glfw_window)

# Cleanup.
align_destroy()
hid_device.close()  # release the headset's HID handle
# Free the renderer's GL objects while the GL context still exists.
imgui_renderer.shutdown()
glfw.terminate()