import os, sys

# When sys.frozen is set (NOTE(review): presumably by the bundler that built the
# executable - confirm), put the executable's directory at the front of PATH before
# the remaining imports run.
if getattr(sys, "frozen", False):
    exe_dir = os.path.dirname(sys.executable)
    os.environ["PATH"] = exe_dir + os.pathsep + os.environ.get("PATH", "")

import builtins
import os, sys
import threading

#import imgui

# If lighthouse_console.exe is present in the current working directory, publish './'
# as builtins.valve_tools_path and builtins.steam_vr_path before `steamvr` is imported.
# NOTE(review): presumably the steamvr package reads these builtins to locate its
# tools - confirm against that package.
if os.path.exists("./lighthouse_console.exe"):
    setattr(builtins, 'valve_tools_path', './')
    setattr(builtins, 'steam_vr_path', './')

from concurrent.futures import ThreadPoolExecutor
import steamvr
import hid
import subprocess
from steamvr.unbuffered.DeviceMonitor import instance
from dataclasses import dataclass
from pytrinamic.tmcl import TMCLCommand
from typing import Any
from calibot import Calibot
from calibot_settings import config
from random import Random
from steamvr.paths import LHCALIB_PATH
import time
import load_init_json
import logging
from gui_log import *
from imgui_bundle import imgui, immapp, hello_imgui

# NOTE(review): the logger returned by this call is discarded, so the line has no
# visible effect in this file - confirm whether it can be removed.
logging.getLogger()

#tracking_dir = r'C:\Users\facto\calibot\venv\Lib\site-packages\steamvr\tools\bin\win64'
#init_json = r"C:\Users\facto\calibot\venv\Lib\site-packages\steamvr\init_json_evt.json"
# Init JSON path handed to load_init_json.load_init_config() and to lhcalib after /bodycal.
init_json = r".\init_json_config.json"

# Random source used by perform_random_motion_path().
RAND = Random()
# USB vendor/product IDs used by HeadsetHIDMatching to recognise a headset's two HID
# devices: the ATMEL pair is matched together with product string 'Beyond', the
# TUNDRA pair together with product string 'Controller'.
VIDS_ATMEL = [0x35bd]
PIDS_ATMEL = [0x0101]
VIDS_TUNDRA = [0x28de]
PIDS_TUNDRA = [0x2300]
KNOWN_VIDS = VIDS_ATMEL + VIDS_TUNDRA
KNOWN_PIDS = PIDS_ATMEL + PIDS_TUNDRA

# Settings read once from calibot_settings.config at import time.
BASESTATION_COUNT = config.basestation_count  # passed to lhcalib after /bodycalmulti
CLEAR_HISTORY = config.clear_device_monitor_history  # when truthy, reset_run() calls instance.reset_history()
COUNTDOWN_MINUTES = config.timeout.length_minutes  # length of the timeout countdown, in minutes
COUNTDOWN_THRESHOLD = config.timeout.min_calibrations  # successes required before main() starts the countdown
DELAY_MCU = config.delays.mcu_seconds  # seconds calibrate_headset() sleeps before its first LED command
DELAY_LH = config.delays.lighthouse_seconds  # seconds calibrate_headset() sleeps before loading the init config
DELAY_BEGIN_MOTION = config.delays.begin_motion_seconds  # seconds main() sleeps before starting arm motion
LHCALIB_INTERVAL = config.delays.lhcalib_interval  # ticks between consecutive lhcalib start slots

# Mutable run state shared between the main thread and worker threads.
success_count = 0  # incremented by calibrate_headset() on success
interrupt_count = 0  # incremented by calibrate_headset() when a calibration times out
lhcalib_timeslot_current = 0  # most recent lhcalib start slot handed out, in ticks
tick = 0  # incremented by tick_loop() after each one-second sleep
tick_lock = threading.Lock()  # held by tick_loop() while it increments tick

# Substrings to look for in lhcalib output, one per line of the optional ./warning.txt.
# Blank lines (including the empty entry produced by a trailing newline) are skipped:
# an empty string is a substring of every line and would otherwise make every output
# line get reported as a warning.
warning_strings = []
if os.path.exists("./warning.txt"):
    with open("./warning.txt") as f:
        warning_strings = [entry for entry in f.read().split('\n') if entry.strip()]


def static_vars(**kwargs):
    """Build a decorator that stores each keyword argument as an attribute on a function.

    The decorated function object itself is returned (no wrapper is created), so the
    values are reachable afterwards as ``func.<name>`` and behave like per-function
    "static" variables.
    """
    def attach(target):
        for attr_name, initial_value in kwargs.items():
            setattr(target, attr_name, initial_value)
        return target
    return attach


@static_vars(prior_angle=0)
def perform_random_motion_path(arm: Calibot):
    """Send the arm to its next pose.

    The first coordinate alternates between the two stored end values (0, then 0.79,
    then 0.0, ...) using the function attribute ``prior_angle`` as state; the second
    coordinate is drawn uniformly from [0, 0.9] on every call.

    The previous implementation also drew a random value for the first coordinate and
    immediately overwrote it; that dead draw has been removed.

    Args:
        arm: the Calibot whose ``go_position`` is called with the two coordinates.
    """
    # -0.05 < turn1 < 0.8 and -0.05 < turn2 < 0.55
    # NOTE(review): the second coordinate can reach 0.9, above the 0.55 quoted above -
    # confirm whether go_position() clamps or whether the range should be narrowed.
    first_turn = perform_random_motion_path.prior_angle
    second_turn = RAND.uniform(0, 0.9)

    # Flip the stored value so the next call goes to the opposite end.
    perform_random_motion_path.prior_angle = 0.79 if first_turn < 0.5 else 0.0
    arm.go_position(first_turn, second_turn)


@dataclass
class HeadsetHID:
    """The two HID halves of one headset, filled in as they are discovered."""

    # HID device-info mapping of the 'Beyond' device (indexed with 'serial_number',
    # 'path', ... elsewhere in this file); None until that device has been seen.
    atmel: Any = None
    # Serial number taken from the 'Controller' HID device; '' until it has been seen.
    lh_serial: str = ''

    def __repr__(self):
        """Return ``"<atmel serial> - <lh_serial>"``.

        A half-built or default instance (``atmel is None``) is rendered with a
        placeholder instead of raising ``TypeError`` on ``None['serial_number']``.
        """
        atmel_serial = '<no atmel>' if self.atmel is None else self.atmel['serial_number']
        return f"{atmel_serial} - {self.lh_serial}"


class HeadsetHIDMatching:
    """Pair up the two HID devices of a headset and report each completed pair.

    ``on_new_hid_device`` is meant to be registered as a new-HID-device callback. It
    collects one 'Beyond' device (ATMEL VID/PID) and one 'Controller' device (TUNDRA
    VID/PID) into ``self.prospect``; once both are present, an lhcalib time slot is
    reserved and ``self.on_match(headset, lhcalib_timeslot)`` is invoked.
    """

    def __init__(self, on_match):
        """Store the match callback.

        Args:
            on_match: callable invoked as ``on_match(headset, lhcalib_timeslot)`` for
                every completed pair. The attribute may be reassigned later.
        """
        self.prospect = None  # HeadsetHID currently being assembled, or None
        # Previously only assigned in _clear_matching_attempt(); defined here so the
        # attribute always exists.
        self.timer = None
        self.on_match = on_match

    def __on_match(self, dev):
        """Reserve the next lhcalib time slot for ``dev`` and forward it to ``on_match``."""
        global lhcalib_timeslot_current

        # Read the shared tick once so the comparison, the assignment and the log line
        # all use the same value even if tick_loop() advances it meanwhile.
        now = tick
        if lhcalib_timeslot_current < now:
            # The last reserved slot is already in the past: restart from the current tick.
            lhcalib_timeslot_current = now - LHCALIB_INTERVAL
        lhcalib_timeslot_current += LHCALIB_INTERVAL

        log_info(f"Found match between {dev.atmel['serial_number']} & {dev.lh_serial}")
        print(f"Found match between {dev.atmel['serial_number']} & {dev.lh_serial}")
        log_info(f"[TICK:{now}] Scheduling lhcalib for {dev.lh_serial} at tick {lhcalib_timeslot_current}.")
        self.on_match(dev, lhcalib_timeslot_current)  # calibrate_headset will offset this

    def on_new_hid_device(self, dev, hid_index, new_devs, prev_devs):
        """Handle one newly seen HID device.

        Args:
            dev: device-info mapping with 'vendor_id', 'product_id', 'product_string',
                'path' and (for the Controller) 'serial_number'.
            hid_index: position of ``dev`` within ``new_devs``.
            new_devs: sequence of device-info mappings being reported together.
            prev_devs: container tested with ``path not in prev_devs``.
                NOTE(review): presumably the previously known device paths - confirm.

        Returns:
            ``[path]`` of the device that completed a pair, otherwise ``None``.
            NOTE(review): how the device monitor uses this value is not visible here.
        """
        vid = dev['vendor_id']
        pid = dev['product_id']
        product = dev['product_string']
        path = dev['path']

        if not (vid in KNOWN_VIDS and pid in KNOWN_PIDS and path not in prev_devs):
            return None

        if self.prospect is None:
            self.prospect = HeadsetHID(None, "")

        pair_product_str = None  # product string of the still-missing half, if any
        if self.prospect.atmel is None and vid in VIDS_ATMEL and pid in PIDS_ATMEL and product == 'Beyond':
            print(f"Found dev atmel: {dev}")
            self.prospect.atmel = dev
            pair_product_str = 'Controller'
        if self.prospect.lh_serial == "" and vid in VIDS_TUNDRA and pid in PIDS_TUNDRA and product == 'Controller':
            print(f"Found lh_serial: {dev}")
            self.prospect.lh_serial = dev['serial_number']
            pair_product_str = 'Beyond'

        if self.prospect.atmel and self.prospect.lh_serial:
            # Detach the finished pair before calling out: if the callback raises, the
            # stale pair must not be reported again for the next unrelated device.
            matched, self.prospect = self.prospect, None
            self.__on_match(matched)
            return [path]

        if pair_product_str:  # we found one of the pair, now find the other
            # Only look at devices after this one in the same batch.
            for pair_index, pair_device in enumerate(new_devs):
                if pair_index > hid_index and pair_device['product_string'] == pair_product_str:
                    return self.on_new_hid_device(pair_device, pair_index, new_devs, prev_devs)

        # print(f"NEW HID DEVICE: VID={vid:04x}, PID={pid:04x}, Product={product}")  # debug
        return None

    def _clear_matching_attempt(self):
        """Drop any half-assembled prospect and clear the timer attribute."""
        self.prospect = None
        self.timer = None


class TimerState:
    """Pair of boolean flags (``activated``, ``finished``) describing a countdown.

    Both flags start out False. ``activate`` and ``end`` each raise one flag and
    ``reset`` lowers both again.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear both flags."""
        self.finished = self.activated = False

    def activate(self):
        """Mark the countdown as running."""
        self.activated = True

    def end(self):
        """Mark the countdown as having run out."""
        self.finished = True

# Shared countdown state: driven by start_timeout_countdown(), reset by main(), and
# polled by calibrate_headset() and main() via its `finished` / `activated` flags.
calib_timer = TimerState()


def start_timeout_countdown():
    """Run the calibration timeout countdown on the shared ``calib_timer``.

    Activates the timer, then sleeps in one-second steps until either
    ``COUNTDOWN_MINUTES`` minutes have been counted or someone else clears
    ``calib_timer.activated``. If the timer is still activated at that point it is
    marked finished; otherwise it is reset.
    """
    log_info("START calibration timeout countdown")
    calib_timer.activate()

    elapsed = 0
    while calib_timer.activated and elapsed < COUNTDOWN_MINUTES * 60:
        elapsed += 1
        time.sleep(1)

    if not calib_timer.activated:
        # Deactivated from outside while we were counting.
        log_info("RESET calibration timeout countdown")
        calib_timer.reset()
        return

    log_info("FINISH calibration timeout countdown")
    calib_timer.end()


def set_light_color(dev, rgb):
    """Set the LED colour of a headset through a HID feature report.

    Args:
        dev: HID device-info mapping with 'vendor_id', 'product_id', 'path' and
            'serial_number'; must carry an ATMEL vendor/product ID.
        rgb: sequence of three byte values (red, green, blue).

    Raises:
        AssertionError: if ``dev`` is not an ATMEL device (kept as an assertion so
            the exception type seen by existing callers does not change).
    """
    assert dev['vendor_id'] in VIDS_ATMEL and dev['product_id'] in PIDS_ATMEL
    device = hid.device()
    device.open_path(dev['path'])
    try:
        print(f'setting led on serial {dev["serial_number"]} to {rgb}')
        # Report layout used here: a leading 0 byte, the command byte 'L', then R, G, B.
        device.send_feature_report(bytes([0, ord('L'), rgb[0], rgb[1], rgb[2]]))
    finally:
        # This function runs several times per headset; release the handle every
        # time instead of leaking one per call.
        device.close()


def calibrate_headset(headset: HeadsetHID, lhcalib_timeslot: int = 0):
    """Calibrate one headset with lhcalib and upload the resulting lighthouse config.

    Steps: set the LED white, load the init config, wait for the granted lhcalib
    time slot, run lhcalib while relaying its output to the log, then merge the
    generated ``lighthouse_config`` into the device's config and upload it.

    The LED ends green on success, yellow when the shared calibration timer ran out,
    and red on any other failure. Failures are logged and not re-raised (unless
    setting the LED itself fails). ``success_count`` / ``interrupt_count`` are
    updated accordingly.

    Changes relative to the previous implementation:
      * ``assert`` is no longer used for control flow (assertions vanish under ``-O``);
        explicit exceptions with the same messages are raised instead.
      * A still-running lhcalib process is killed when the calibration is abandoned.
      * The final lhcalib output line is read once, so print and log show the same line.

    Note: ``readline()`` blocks, so a timeout is only noticed when lhcalib emits
    another line of output.

    Args:
        headset: matched headset (ATMEL device info plus lighthouse serial).
        lhcalib_timeslot: tick at which this headset may start lhcalib; it is shifted
            by the ticks that elapse during the preparation steps.
    """
    global success_count
    global interrupt_count
    start_tick = tick
    timed_out = False
    p = None  # lhcalib process, once it has been launched

    try:
        time.sleep(DELAY_MCU)  # give the MCU time to post
        set_light_color(headset.atmel, rgb=(255, 255, 255))

        print('loading init config')
        log_info(headset.lh_serial + ": loading init config")
        time.sleep(DELAY_LH)  # give the lighthouse SIP time to post
        load_init_json.load_init_config(headset.lh_serial, init_json)

        print("Calibrating " + headset.lh_serial)
        log_info(headset.lh_serial + ": Calibrating " + headset.lh_serial)
        # 800 200

        # stagger lhcalib calls
        lhcalib_timeslot += tick - start_tick  # offset granted timeslot by the time it took to get here
        while lhcalib_timeslot > tick:
            time.sleep(1)
        cmd = [LHCALIB_PATH, '/bodyserial', headset.lh_serial, "/bodycal", init_json, '800', '200', '/bodycalmulti', f"{BASESTATION_COUNT}", '/deletemissingsensors']
        print(cmd)
        log_info(headset.lh_serial + ": Opening command: " + ' '.join(cmd))
        p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE)

        rslt = None
        while rslt is None:
            if calib_timer.finished:
                timed_out = True
                log_info(headset.lh_serial + ": " + "Calibration did not finish in time")
                raise TimeoutError("Calibration timed out")
            rslt = p.poll()
            line = p.stdout.readline().decode().rstrip()
            if line == "":
                continue
            print(line)
            log_debug(headset.lh_serial + ": " + line)
            if "Unexpected error 15?" in line:
                log_error(headset.lh_serial + ": Unexpected error 15?")
                raise RuntimeError("Unexpected error 15?")
            for warning in warning_strings:
                if warning in line:
                    log_warning(headset.lh_serial + ": found warning " + warning + " in line " + line)
            if "Ready to run" in line:
                # lhcalib waits for Enter before it starts; send it.
                print("\r\n")
                p.stdin.write("\n".encode())
                p.stdin.flush()
                print("Calibration Started")
                log_info(headset.lh_serial + ": Calibration Started")
            if "avg z" in line:
                time.sleep(1)
        if rslt != 0:
            raise RuntimeError(p.stderr.readline().decode().rstrip())
        final_line = p.stdout.readline().decode().rstrip()
        print(final_line)
        log_info(headset.lh_serial + ": " + final_line)

        tracking_file_path = f"./auto_{headset.lh_serial}.json"
        if not os.path.exists(tracking_file_path):
            raise FileNotFoundError("No tracking file found for this headset")
        calibrated_config = steamvr.lighthouse_console.load_steamvr_json(tracking_file_path)

        log_info(headset.lh_serial + ": " + "Uploading config")
        lh = steamvr.LighthouseConsole()
        lh.open()
        lh.select_device(headset.lh_serial)
        # Named device_config so the module-level `config` settings object is not shadowed.
        device_config = lh.download_config()
        device_config['lighthouse_config'] = calibrated_config["lighthouse_config"]
        lh.upload_config(device_config)

        set_light_color(headset.atmel, rgb=(0, 255, 0))
        log_info(headset.lh_serial + ": " + "Finished calibrating headset ")
        print(f"Finished calibrating headset {headset.lh_serial}")
        success_count += 1
    except BaseException as e:
        # Deliberately broad (as before): whatever stops the calibration, the failure
        # must be logged and shown on the headset LED.
        print(e)
        log_error(headset.lh_serial + ": " + str(e))
        if timed_out:
            interrupt_count += 1
            set_light_color(headset.atmel, rgb=(255, 255, 0))
        else:
            set_light_color(headset.atmel, rgb=(255, 0, 0))
    finally:
        # Do not leave an abandoned lhcalib running against the device.
        if p is not None and p.poll() is None:
            p.kill()
            p.wait()


def reset_run():
    """Reset the per-run counters and scheduling state before the next batch.

    When ``CLEAR_HISTORY`` is truthy the device monitor's history is reset first.
    Then ``success_count``, ``interrupt_count``, ``tick`` and
    ``lhcalib_timeslot_current`` are set back to 0.
    """
    global success_count
    global interrupt_count
    global tick
    global lhcalib_timeslot_current

    if CLEAR_HISTORY:
        instance.reset_history()
    success_count = 0
    interrupt_count = 0
    # tick_loop() increments `tick` while holding tick_lock; take the same lock so a
    # concurrent read-modify-write cannot overwrite this reset.
    with tick_lock:
        tick = 0
    lhcalib_timeslot_current = 0


def log_window() -> None:
    """GUI callback passed to ``immapp.run``: draw a separator and the hello_imgui log widget."""
    imgui.separator()
    hello_imgui.log_gui()
    # Sleeps 0.1 s on every call.
    # NOTE(review): presumably to throttle how often the window is redrawn - confirm.
    time.sleep(0.1)



def log_loop() -> None:
    """Run the "log" window forever; started by main() on a daemon thread.

    ``immapp.run`` is called again every time it returns.
    NOTE(review): presumably it returns when the window is closed, so the window
    reopens - confirm.
    """
    while True:
        immapp.run(log_window, "log", with_markdown=True)


def tick_loop() -> None:
    """Increment the global ``tick`` once per one-second sleep, forever.

    Started by main() on a daemon thread. The increment is done while holding
    ``tick_lock``.
    """
    global tick
    while True:
        time.sleep(1)
        with tick_lock:
            tick += 1


def main():
    """Run the calibration station forever.

    Starts the tick and log threads, brings up and homes the Calibot arm, registers
    the HID matcher, then loops over runs: wait for the expected number of headsets
    to be matched (each match queues a ``calibrate_headset`` job), move the arm along
    motion paths until every job is done or the timeout countdown finishes, park the
    arm and reset the run state.

    Changes relative to the previous implementation:
      * waiting uses ``<`` instead of ``!=`` so an extra match cannot make the wait
        loop spin forever;
      * the countdown is marked activated before it is submitted, so it cannot be
        submitted twice while the worker is still starting;
      * after a timeout the ``finished`` flag stays set until all workers have
        returned, so every unfinished calibration gets to observe it;
      * the match callback is pointed back at a harmless printer before the pool
        shuts down, so a late match never submits to a closed executor.
    """
    threading.Thread(target=tick_loop, daemon=True).start()
    threading.Thread(target=log_loop, daemon=True).start()

    log_info("Looking out for the following warning strings:")
    for warning in warning_strings:
        log_warning(warning)

    log_info("Starting Calibot")
    offset_1 = 1/360*2  # offset 1
    offset_2 = 1/360*0.1  #0.0000  # offset 2
    calibot = Calibot(offset_1, offset_2)
    log_info("Calibot Started")

    log_info("Connecting to Calibot")
    calibot.module1.connection.send(TMCLCommand.SAP, 5, 0, 200)  # acceleration
    calibot.module2.connection.send(TMCLCommand.SAP, 5, 0, 200)  # acceleration
    log_info("Connected to Calibot")

    log_info("Homing Calibot")
    calibot.go_home()
    log_info("Calibot Homed")

    expected_headsets = 4  # number of matched headsets that starts a run

    def idle_on_match(headset, lhcalib_timeslot):
        # Used whenever no worker pool is accepting jobs: just show the match.
        print(headset)

    log_info("Starting Headset Calibration Loop")
    matching = HeadsetHIDMatching(idle_on_match)
    instance.register_on_new_hid_device(matching.on_new_hid_device)
    while True:
        futures = []
        with ThreadPoolExecutor(max_workers=5) as tp:
            # Bind this run's pool and list as defaults so the callback never sees a
            # later run's objects.
            def queue_calibration(headset, lhcalib_timeslot, pool=tp, pending=futures):
                pending.append(pool.submit(calibrate_headset, headset, lhcalib_timeslot))

            matching.on_match = queue_calibration

            calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.go_position(0.25, 0.25)
            print("Place headsets")
            log_info("Waiting for headsets to be placed")
            # start = input("press enter once loaded:")
            # print("starting motion")

            while len(futures) < expected_headsets:
                time.sleep(1)

            time.sleep(DELAY_BEGIN_MOTION)  # give the MCU, lighthouse SIP time to post, and lhcalib time to load
            calib_timer.reset()

            while any(not f.done() for f in futures) and not calib_timer.finished:
                if success_count >= COUNTDOWN_THRESHOLD and not calib_timer.activated:
                    log_info(f"Countdown threshold of {COUNTDOWN_THRESHOLD} calibrations reached")
                    # Mark it active right away; the worker sets the same flag again
                    # when it starts, which is harmless.
                    calib_timer.activate()
                    tp.submit(start_timeout_countdown)
                log_info("Waiting for headset calibration to finish")
                calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1800)
                calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1800)
                log_info("Performing motion path")
                perform_random_motion_path(calibot)

            if calib_timer.finished:
                # Keep `finished` set while the pool drains so that every unfinished
                # calibrate_headset() sees the timeout.
                log_info("Wrapping up remaining tasks")
            else:
                # All calibrations are done: clear `activated` so a countdown that is
                # still running stops and does not hold up the pool shutdown.
                calib_timer.reset()

            # From here on the pool stops accepting work.
            matching.on_match = idle_on_match

        calib_timer.reset()

        if interrupt_count > 0:
            log_info(f"Calibration timed out with {interrupt_count} calibrations incomplete")
        else:
            log_info("Finished headset calibration")

        calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
        calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
        calibot.e_stop()
        calibot.go_position(0.25, 0.25)
        log_info("Waiting for headsets to be removed")

        reset_run()


# Script entry point.
if __name__ == '__main__':
    try:
        main()
    except BaseException as e:
        # Catches everything, including KeyboardInterrupt and SystemExit, and prints
        # only the exception's message (no traceback).
        print(e)
    # Wait for Enter before exiting.
    # NOTE(review): presumably so the console window stays open long enough to read
    # the message above - confirm.
    k = input("press close to exit")
