import os, sys

# Frozen (bundled) executable: prepend the executable's own folder to PATH,
# presumably so helper binaries/DLLs shipped next to it are found first.
if getattr(sys, "frozen", False):
    exe_dir = os.path.dirname(sys.executable)
    os.environ["PATH"] = os.pathsep.join((exe_dir, os.environ.get("PATH", "")))

import builtins
import os, sys, ctypes, subprocess
import threading

def is_admin():
    """Return True if the current process has Windows administrator rights.

    Returns False when the check cannot be performed, e.g. ``ctypes.windll`` does
    not exist (non-Windows platforms) or the shell32 call fails. Only ordinary
    exceptions are caught, so Ctrl+C / SystemExit still propagate.
    """
    try:
        return bool(ctypes.windll.shell32.IsUserAnAdmin())
    except Exception:
        return False
    
# Not elevated: relaunch this program through the UAC "runas" verb, then exit the
# unelevated instance.
if not is_admin():
    ctypes.windll.shell32.ShellExecuteW(
        None,
        u"runas",
        str(sys.executable),
        # Quote the script path so a path containing spaces arrives as one argument.
        subprocess.list2cmdline([str(__file__)]),
        # Keep the current directory explicitly: this script relies on relative
        # paths (./warning.txt, ./lighthouse_console.exe, the init JSON, ...).
        os.getcwd(),
        1,
    )
    # sys.exit() rather than the site-provided exit(), which frozen builds may lack.
    sys.exit()

# If lighthouse_console.exe sits in the working directory, publish './' as the
# Valve tools / SteamVR path via builtins. This runs before `import steamvr` below,
# presumably because steamvr reads these names at import time — confirm.
if os.path.exists("./lighthouse_console.exe"):
    builtins.valve_tools_path = './'
    builtins.steam_vr_path = './'

from concurrent.futures import ThreadPoolExecutor
import steamvr
import hid
from steamvr.unbuffered.DeviceMonitor import instance
from dataclasses import dataclass
from pytrinamic.tmcl import TMCLCommand
from typing import Any
from calibot import Calibot
from calibot_settings import config
from random import Random
from steamvr.paths import LHCALIB_PATH
import time
import load_init_json
import bs_usb_tools
import logging
from gui_log import *
from imgui_bundle import imgui, immapp, hello_imgui

logging.getLogger()  # NOTE(review): the returned logger is discarded, so this line has no lasting effect

#tracking_dir = r'C:\Users\facto\calibot\venv\Lib\site-packages\steamvr\tools\bin\win64'
#init_json = r"C:\Users\facto\calibot\venv\Lib\site-packages\steamvr\init_json_evt.json"
# Init JSON (relative to the working directory) given to load_init_json and lhcalib.
init_json = r".\init_json_config.json"

CALIBOT_VERSION = '1.1.5'

RAND = Random()

# USB IDs of the two HID interfaces paired per headset:
# the Atmel MCU ("Beyond") and the Tundra lighthouse SIP ("Controller").
VIDS_ATMEL = [0x35bd]
PIDS_ATMEL = [0x0101]
VIDS_TUNDRA = [0x28de]
PIDS_TUNDRA = [0x2300]
KNOWN_VIDS = [*VIDS_ATMEL, *VIDS_TUNDRA]
KNOWN_PIDS = [*PIDS_ATMEL, *PIDS_TUNDRA]

# Values derived from the yaml config (refreshed by config_reload)
BASESTATION_COUNT = config.basestation_count
CLEAR_HISTORY = config.clear_device_monitor_history
TOTAL_SENSOR_HITS = config.total_sensor_hits
HITS_PER_SENSOR = config.hits_per_sensor
COUNTDOWN_MINUTES = config.timeout.length_minutes
COUNTDOWN_THRESHOLD = config.timeout.min_calibrations
WRAP_UP_SECONDS = config.timeout.wrap_up_seconds
DELAY_MCU = config.delays.mcu_seconds
DELAY_LH = config.delays.lighthouse_seconds
DELAY_BEGIN_MOTION = config.delays.begin_motion_seconds
LHCALIB_INTERVAL = config.delays.lhcalib_interval

# Mutable run state shared between the main loop and worker threads.
window_title = ""
success_count = 0
interrupt_count = 0
lhcalib_timeslot_current = 0  # last tick handed out as an lhcalib start slot
tick = 0  # seconds counter advanced by tick_loop
tick_lock = threading.Lock()
calib_timestamps = []  # [{"note": ..., "tick": ...}], first entry is the run start
calib_timestamps_lock = threading.Lock()

# Optional substrings (one per line in ./warning.txt). When one appears in lhcalib
# output it is logged as a warning. Blank/whitespace-only lines are dropped: an
# empty pattern is contained in every string and would flag every output line
# (a trailing newline in the file used to produce exactly that).
warning_strings = []
if os.path.exists("./warning.txt"):
    with open("./warning.txt") as f:
        warning_strings = [s for s in f.read().split('\n') if s.strip()]


def update_window_title():
    """Rebuild the global ``window_title`` from the app version and current config values."""
    global window_title
    title_parts = (
        f"Calibot App v{CALIBOT_VERSION}",
        f"x{BASESTATION_COUNT}lh (bscv {TOTAL_SENSOR_HITS}.{HITS_PER_SENSOR})",
    )
    window_title = " -- ".join(title_parts)

update_window_title()


def config_reload():
    """Reload the YAML config and refresh the module-level settings derived from it.

    Settings are only refreshed when ``config.reload()`` returns a truthy value
    (presumably "reloaded successfully / changed" — confirm in calibot_settings).
    The window title is rebuilt through ``update_window_title`` so it keeps the
    same format as at startup; the former inline f-string dropped the
    base-station count.
    """
    global BASESTATION_COUNT, CLEAR_HISTORY, TOTAL_SENSOR_HITS, HITS_PER_SENSOR
    global COUNTDOWN_MINUTES, COUNTDOWN_THRESHOLD, WRAP_UP_SECONDS
    global DELAY_MCU, DELAY_LH, DELAY_BEGIN_MOTION, LHCALIB_INTERVAL

    if not config.reload():
        return

    BASESTATION_COUNT = config.basestation_count
    CLEAR_HISTORY = config.clear_device_monitor_history
    TOTAL_SENSOR_HITS = config.total_sensor_hits
    HITS_PER_SENSOR = config.hits_per_sensor
    COUNTDOWN_MINUTES = config.timeout.length_minutes
    COUNTDOWN_THRESHOLD = config.timeout.min_calibrations
    WRAP_UP_SECONDS = config.timeout.wrap_up_seconds
    DELAY_MCU = config.delays.mcu_seconds
    DELAY_LH = config.delays.lighthouse_seconds
    DELAY_BEGIN_MOTION = config.delays.begin_motion_seconds
    LHCALIB_INTERVAL = config.delays.lhcalib_interval
    update_window_title()
    log_info(f"Configuration reloaded --> bscv {TOTAL_SENSOR_HITS}.{HITS_PER_SENSOR}")


def static_vars(**kwargs):
    """Decorator factory: attach each keyword argument as an attribute of the decorated function.

    Emulates C-style function-level static variables, e.g.
    ``@static_vars(counter=0)`` makes ``func.counter`` available inside ``func``.
    """
    def decorate(func):
        for attr_name, initial_value in kwargs.items():
            setattr(func, attr_name, initial_value)
        return func
    return decorate


@static_vars(prior_angle=0)
def perform_random_motion_path(arm: Calibot):
    """Move the arm to the next point of the calibration motion path.

    The first axis alternates between 0.0 and 0.79 on successive calls; the
    current value is kept in the ``prior_angle`` function attribute. The second
    axis is drawn uniformly from [0, 0.9] with the module-level ``RAND``.
    (The old code also drew a random first-axis value and then overwrote it
    unused; that dead draw is gone.)
    """
    # Reachable range note: -0.05 < turn1 < 0.8 and -0.05 < turn2 < 0.55
    # NOTE(review): b can reach 0.9, above the 0.55 noted for turn2 — confirm
    # whether go_position uses different units or the bound is stale.
    a = perform_random_motion_path.prior_angle
    b = RAND.uniform(0, 0.9)
    perform_random_motion_path.prior_angle = 0.79 if a < 0.5 else 0.0
    arm.go_position(a, b)


@dataclass
class HeadsetHID:
    """One physical headset: its Atmel HID device record plus the paired lighthouse serial.

    Built by HeadsetHIDMatching once an Atmel ("Beyond") interface and a Tundra
    ("Controller") interface are found on the same USB hub.
    """
    atmel: Any = None  # HID device dict; keys used here: 'serial_number', 'port_chain'
    lh_serial: str = ''  # serial number of the paired Tundra lighthouse device

    def __repr__(self):
        # Tolerate the default atmel=None instead of raising TypeError.
        atmel_serial = self.atmel['serial_number'] if self.atmel is not None else None
        return f"{atmel_serial} - {self.lh_serial}"

    def safety_check(self):
        """Raise AssertionError unless the pairing is complete enough to act on.

        Explicit raises (not ``assert``) so the check still runs under ``python -O``.
        """
        if self.atmel is None:
            raise AssertionError("HeadsetHID has no Atmel device")
        if self.lh_serial == '':
            raise AssertionError("HeadsetHID has no lighthouse serial")
        if 'port_chain' not in self.atmel:
            raise AssertionError("Atmel device has no USB port chain")

    def restart_hub(self):
        """Power-cycle the USB hub two hops above the Atmel device via bs_usb_tools."""
        self.safety_check()
        message = f"Restarting USB hub for device {self.atmel['serial_number']} & {self.lh_serial}"
        log_info(message)
        print(message)
        # The parent hub is identified by the port chain minus its last two hops.
        parent_hub_chain = '-'.join(self.atmel['port_chain'][:-2])
        bs_usb_tools.restart_port(parent_hub_chain)

    def reenqueue(self):
        """Schedule this headset for another calibration attempt.

        Allocates the next lhcalib timeslot (LHCALIB_INTERVAL ticks after the
        previous one, never in the past) and hands it to ``matching.on_match``.
        """
        self.safety_check()
        global matching
        global lhcalib_timeslot_current

        if lhcalib_timeslot_current < tick:
            lhcalib_timeslot_current = tick - LHCALIB_INTERVAL
        lhcalib_timeslot_current += LHCALIB_INTERVAL
        matching.on_match(self, lhcalib_timeslot_current)


class HeadsetHIDMatching:
    """Pair Atmel ("Beyond") and Tundra ("Controller") HID interfaces into HeadsetHIDs.

    Candidates are collected from device-monitor callbacks. An Atmel and a Tundra
    device are treated as the same headset when their USB port chains agree on
    everything except the last two hops (same parent hub). Each new pair is passed
    to ``on_match(headset, lhcalib_timeslot)``.
    """

    def __init__(self, on_match):
        self.prospect = None
        self.atmel_candidates = []
        self.tundra_candidates = []
        self.matched_paths = []  # HID paths already consumed by a match
        self.on_match = on_match

    def __on_match(self, dev):
        """Allocate the next staggered lhcalib timeslot and report the match."""
        global tick
        global lhcalib_timeslot_current

        # Slots are LHCALIB_INTERVAL ticks apart and never scheduled in the past.
        if lhcalib_timeslot_current < tick:
            lhcalib_timeslot_current = tick - LHCALIB_INTERVAL
        lhcalib_timeslot_current += LHCALIB_INTERVAL

        log_info(f"Found match between {dev.atmel['serial_number']} & {dev.lh_serial}")
        print(f"Found match between {dev.atmel['serial_number']} & {dev.lh_serial}")
        self.on_match(dev, lhcalib_timeslot_current)  # calibrate_headset will offset this

    @staticmethod
    def _hubs_match(a_chain, t_chain):
        """Return True when both port chains (len >= 3) share all but their last two hops."""
        if len(a_chain) < 3 or len(t_chain) < 3:
            return False
        return a_chain[:-2] == t_chain[:-2]

    @staticmethod
    def _get_usb_info(serial, retries=3):
        """Query bs_usb_tools for *serial*, retrying 1 s apart until 'PortChain' is present.

        Raises AssertionError if no PortChain could be obtained.
        """
        usb_info = bs_usb_tools.get_info(f"*{serial}*")
        while (not usb_info or 'PortChain' not in usb_info) and retries > 0:
            time.sleep(1)
            usb_info = bs_usb_tools.get_info(f"*{serial}*")
            retries -= 1
        assert usb_info and 'PortChain' in usb_info, f"Could not get PortChain for device with serial {serial}"
        return usb_info

    def _pop_matching_pair(self):
        """Remove and return the first unmatched (atmel, tundra) pair on a shared hub, else None.

        Returns immediately after deleting, so the lists are never mutated mid-iteration.
        """
        for a_i, a_dev in enumerate(self.atmel_candidates):
            if a_dev['path'] in self.matched_paths:
                continue
            for t_i, t_dev in enumerate(self.tundra_candidates):
                if t_dev['path'] in self.matched_paths:
                    continue
                if self._hubs_match(a_dev['port_chain'], t_dev['port_chain']):
                    del self.atmel_candidates[a_i]
                    del self.tundra_candidates[t_i]
                    return a_dev, t_dev
        return None

    def on_new_hid_device(self, dev, hid_index, new_devs, prev_devs):
        """Device-monitor callback: register Atmel/Tundra candidates and pair them.

        Args:
            dev: HID device dict. Reads 'vendor_id', 'product_id', 'product_string',
                'serial_number' and 'path'; a 'port_chain' entry is added to candidates.
            hid_index, new_devs, prev_devs: part of the callback signature; unused here.

        Returns:
            Always an empty list.

        Raises:
            AssertionError: if a candidate's USB PortChain cannot be determined.
        """
        return_devs = []
        vid = dev['vendor_id']
        pid = dev['product_id']
        path = dev['path']
        if vid not in KNOWN_VIDS or pid not in KNOWN_PIDS or path in self.matched_paths:
            return return_devs

        product = dev['product_string']
        if vid in VIDS_ATMEL and pid in PIDS_ATMEL and product == 'Beyond':
            candidates, kind = self.atmel_candidates, "Atmel"
        elif vid in VIDS_TUNDRA and pid in PIDS_TUNDRA and product == 'Controller':
            candidates, kind = self.tundra_candidates, "Tundra"
        else:
            # Known IDs but not an interface we pair; skip the (slow) USB query.
            return return_devs
        if any(c['path'] == path for c in candidates):
            return return_devs  # re-announced device: already a candidate

        serial = (dev['serial_number'] or '').strip()  # guard against a missing (None) serial
        usb_info = self._get_usb_info(serial)
        dev['port_chain'] = usb_info['PortChain']
        log_debug(f"Adding {kind} candidate: Serial={serial}")  # debug
        candidates.append(dev)

        if self.atmel_candidates and self.tundra_candidates:
            log_debug("Checking for matches")  # debug
            while True:
                pair = self._pop_matching_pair()
                if pair is None:
                    break
                a_dev, t_dev = pair
                log_debug("    Found match!")  # debug
                print(f"Found dev atmel: {a_dev}")
                print(f"Found lh_serial: {t_dev}")
                self.prospect = HeadsetHID(a_dev, t_dev['serial_number'])
                self.__on_match(self.prospect)
                self.matched_paths.extend([a_dev['path'], t_dev['path']])

        return return_devs

    def reset(self):
        """Forget all candidates, matches and the last prospect."""
        self.prospect = None
        self.atmel_candidates = []
        self.tundra_candidates = []
        self.matched_paths = []


class TimerState:
    """Flags shared between the main loop and the calibration timeout countdown.

    ``activated``: a countdown has been started and not reset.
    ``finished``: the countdown ran to completion. ``reset`` clears both.
    """

    def __init__(self):
        self.reset()

    def activate(self):
        """Mark the countdown as running."""
        self.activated = True

    def end(self):
        """Mark the countdown as elapsed; ``activated`` is left as it is."""
        self.finished = True

    def reset(self):
        """Clear both flags."""
        self.finished = self.activated = False

calib_timer = TimerState()


def start_timeout_countdown():
    """Run the calibration timeout countdown on the shared ``calib_timer`` (blocking).

    Activates the timer, then polls once per second for up to COUNTDOWN_MINUTES
    minutes. If the timer is still active afterwards, it is marked finished. If
    it was deactivated meanwhile (someone called ``reset``), it is reset again
    and the countdown is reported as reset.
    """
    global calib_timer

    log_info(f"START calibration timeout countdown: {COUNTDOWN_MINUTES} minutes to go")
    calib_timer.activate()
    elapsed = 0
    while calib_timer.activated and elapsed < COUNTDOWN_MINUTES * 60:
        elapsed += 1
        time.sleep(1)

    if not calib_timer.activated:
        log_info("RESET calibration timeout countdown")
        calib_timer.reset()
        return
    log_info("FINISH calibration timeout countdown")
    calib_timer.end()


def set_light_color(dev, rgb):
    """Set the LED colour of an Atmel headset with an 'L' HID feature report.

    Args:
        dev: HID device dict of the Atmel interface (reads 'vendor_id',
            'product_id', 'path', 'serial_number').
        rgb: (r, g, b) sequence of ints in 0..255 (``bytes`` rejects anything else).

    The HID handle is always closed afterwards, even if sending the report fails;
    previously it was never closed.
    """
    assert dev['vendor_id'] in VIDS_ATMEL and dev['product_id'] in PIDS_ATMEL
    device = hid.device()
    device.open_path(dev['path'])
    try:
        print(f'setting led on serial {dev["serial_number"]} to {rgb}')
        # Leading 0 is presumably the HID report ID, followed by the 'L' command byte.
        device.send_feature_report(bytes([0, ord('L'), rgb[0], rgb[1], rgb[2]]))
    finally:
        device.close()


def _terminate_process(proc):
    """Kill *proc* if it is still running. Best-effort: logs a warning instead of raising."""
    if proc is None or proc.poll() is not None:
        return
    try:
        proc.kill()
        proc.wait(timeout=5)
    except Exception as kill_error:
        log_warning(f"Could not terminate lhcalib process: {kill_error}")


def calibrate_headset(headset: HeadsetHID, lhcalib_timeslot: int = 0, retry_count: int = 0):
    """Calibrate one headset end to end (submitted as a thread-pool task by main).

    Steps:
    1. Set the LED white and load the init JSON onto the lighthouse device.
    2. Wait for the granted lhcalib timeslot and run lhcalib, answering its
       "Ready to run" prompt.
    3. Upload the resulting ``lighthouse_config`` plus a ``bscv`` tag to the device.
    4. Set the LED green and record the finish tick.

    Args:
        headset: paired Atmel/lighthouse device.
        lhcalib_timeslot: tick at which lhcalib may start. It is pushed back by
            the time spent in the setup steps so lhcalib runs stay staggered.
        retry_count: setup retries already made for this headset.

    Every exception from the flow is caught (BaseException) and handled:
    - If the global timeout fired, the headset is counted as interrupted (LED red).
    - If setup failed, or lhcalib printed "Unexpected error 15?", it is retried:
      three attempts in total, then the USB hub is restarted and the headset is
      re-enqueued.
    - Any other failure simply re-enqueues the headset.
    A still-running lhcalib process is killed before any of this.
    """
    global tick
    global success_count
    global interrupt_count
    global calib_timer
    global calib_timestamps
    global calib_timestamps_lock
    start_tick = tick
    retry_calibrate = True  # failures during setup (before lhcalib runs) are retried
    proc = None

    try:
        time.sleep(DELAY_MCU)  # give the MCU time to post
        set_light_color(headset.atmel, rgb=(255, 255, 255))

        print('loading init config')
        log_info(headset.lh_serial + ": loading init config")
        time.sleep(DELAY_LH)  # give the lighthouse SIP time to post
        load_init_json.load_init_config(headset.lh_serial, init_json)
        retry_calibrate = False

        print("Calibrating " + headset.lh_serial)
        log_info(headset.lh_serial + ": Calibrating " + headset.lh_serial)

        # stagger lhcalib calls: offset the granted timeslot by the time it took to get here
        lhcalib_timeslot += tick - start_tick
        while lhcalib_timeslot > tick:
            time.sleep(1)
        cmd = [LHCALIB_PATH, '/bodyserial', headset.lh_serial, "/bodycal", init_json,
               f"{TOTAL_SENSOR_HITS}", f"{HITS_PER_SENSOR}",
               '/bodycalmulti', f"{BASESTATION_COUNT}", '/deletemissingsensors']
        print(cmd)
        log_debug(headset.lh_serial + ": Opening command: " + ' '.join(cmd))
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE)

        rslt = None
        while rslt is None:
            if calib_timer.finished:
                log_info(headset.lh_serial + ": " + "Calibration did not finish in time")
                # explicit raise (not assert) so the timeout still fires under python -O
                raise TimeoutError("Calibration timed out")
            rslt = proc.poll()
            # errors="replace": a stray non-UTF-8 byte must not abort the calibration
            line = proc.stdout.readline().decode(errors="replace").rstrip()
            if line == "" or "Data request is out of range" in line or "LYX" in line or "RYB" in line:
                continue
            print(line)
            log_debug(headset.lh_serial + ": " + line)
            if "Unexpected error 15?" in line:
                retry_calibrate = True
                log_error(headset.lh_serial + ": Unexpected error 15?")
                raise RuntimeError("Unexpected error 15?")
            for warning in warning_strings:
                if warning in line:
                    log_warning(headset.lh_serial + ": found warning " + warning + " in line " + line)
            if "Ready to run" in line:
                print("\r\n")
                proc.stdin.write("\n".encode())
                proc.stdin.flush()
                print("Calibration Started")
                log_info(headset.lh_serial + ": Calibration Started")
            if "avg z" in line:
                time.sleep(1)
        if rslt != 0:
            raise RuntimeError(proc.stderr.readline().decode(errors="replace").rstrip())
        # Read the trailing line once so the console and the log show the same text.
        summary = proc.stdout.readline().decode(errors="replace").rstrip()
        print(summary)
        log_info(headset.lh_serial + ": " + summary)

        tracking_file_path = f"./auto_{headset.lh_serial}.json"
        if not os.path.exists(tracking_file_path):
            raise FileNotFoundError("No tracking file found for this headset")
        calibrated_config = steamvr.lighthouse_console.load_steamvr_json(tracking_file_path)

        log_info(headset.lh_serial + ": " + f"Uploading config (bscv {TOTAL_SENSOR_HITS}.{HITS_PER_SENSOR})")
        lh = steamvr.LighthouseConsole()
        lh.open()
        lh.select_device(headset.lh_serial)
        # Named device_config so it is not mistaken for the module-level settings `config`.
        device_config = lh.download_config()
        device_config['lighthouse_config'] = calibrated_config['lighthouse_config']
        device_config['bscv'] = f"{TOTAL_SENSOR_HITS}.{HITS_PER_SENSOR}"
        lh.upload_config(device_config)

        # add current tick as next index of calib_timestamps
        with calib_timestamps_lock:
            calib_timestamps.append({"note": headset.lh_serial, "tick": tick})
        set_light_color(headset.atmel, rgb=(0, 255, 0))
        log_info(headset.lh_serial + ": " + "Finished calibrating headset ")
        print(f"Finished calibrating headset {headset.lh_serial}")
        success_count += 1
    except BaseException as e:
        # Never leave an orphaned lhcalib running (timeout, error 15, upload failure...).
        _terminate_process(proc)
        print(e)
        if calib_timer.finished:
            log_error(headset.lh_serial + ": " + str(e))
            interrupt_count += 1
            set_light_color(headset.atmel, rgb=(255, 0, 0))
        elif retry_calibrate and retry_count < 3:
            log_warning(headset.lh_serial + ": " + str(e))
            if retry_count < 2:
                log_warning(headset.lh_serial + ": Retrying calibration setup for this device...")
                calibrate_headset(headset, lhcalib_timeslot + tick - start_tick, retry_count + 1)
            else:
                log_error(headset.lh_serial + ": 3 consecutive attempts failed. Restart USB before continuing...")
                headset.restart_hub()
                headset.reenqueue()
        else:
            log_error(headset.lh_serial + ": " + str(e))
            log_error(headset.lh_serial + ": Reenqueueing to attempt again...")
            headset.reenqueue()


def reset_run():
    """Reset per-run counters and shared state before the next batch of headsets.

    Device-monitor history and the HID matcher are only cleared when
    CLEAR_HISTORY is set. ``tick`` and ``calib_timestamps`` are reset while
    holding their locks. Otherwise a concurrent ``tick += 1`` in tick_loop
    could write back the old value and undo the reset.
    """
    global success_count
    global interrupt_count
    global tick
    global lhcalib_timeslot_current
    global calib_timestamps
    global matching

    if CLEAR_HISTORY:
        instance.reset_history()
        matching.reset()

    success_count = 0
    interrupt_count = 0
    with tick_lock:
        tick = 0
    lhcalib_timeslot_current = 0
    with calib_timestamps_lock:
        calib_timestamps = []


def log_window():
    """GUI callback handed to immapp.run: a separator followed by the hello_imgui log widget."""
    imgui.separator()
    hello_imgui.log_gui()
    # Throttle; presumably invoked once per rendered frame by immapp.run.
    time.sleep(0.1)



def log_loop():
    """Show the log window forever, re-running immapp whenever it returns.

    Presumably immapp.run returns when the window is closed, so closing it just
    reopens it. The title is the global ``window_title`` at each (re)launch.
    """
    while True:
        immapp.run(
            log_window,
            window_title,
            with_markdown=True,
            window_restore_previous_geometry=True,
        )


def tick_loop():
    """Forever: sleep one second, then increment the global ``tick`` under ``tick_lock``."""
    global tick
    while True:
        time.sleep(1)
        tick_lock.acquire()
        try:
            tick += 1
        finally:
            tick_lock.release()


def main():
    """Run the Calibot calibration station forever.

    Setup:
    - Start the tick and log-window threads and optionally clear device-monitor history.
    - Connect to and home the arm, then register the HID matcher.

    Per run:
    - Wait until at least four headsets are matched; each match submits
      ``calibrate_headset`` to this run's thread pool.
    - Drive the arm's motion path until all calibrations finish or the timeout
      countdown fires.
    - Wrap up, report per-headset times, clean up Beyond devices and reset state.
    """
    global tick
    global success_count
    global interrupt_count
    global calib_timer
    global calib_timestamps
    global calib_timestamps_lock
    global matching

    update_window_title()
    threading.Thread(target=tick_loop, daemon=True).start()
    threading.Thread(target=log_loop, daemon=True).start()

    if CLEAR_HISTORY:
        instance.reset_history()

    log_info("Looking out for the following warning strings:")
    for warning in warning_strings:
        log_warning(warning)

    log_info("Starting Calibot")
    O1 = 1/360*2  # offset 1
    O2 = 1/360*0.1  # offset 2
    calibot = Calibot(O1, O2)
    log_info("Calibot Started")

    log_info("Connecting to Calibot")
    calibot.module1.connection.send(TMCLCommand.SAP, 5, 0, 200)  # acceleration
    calibot.module2.connection.send(TMCLCommand.SAP, 5, 0, 200)  # acceleration
    log_info("Connected to Calibot")

    log_info("Homing Calibot")
    calibot.go_home()
    log_info("Calibot Homed")

    log_info("Starting Headset Calibration Loop")
    # Placeholder callback until the first pool exists: matches are only printed.
    matching = HeadsetHIDMatching(lambda headset, lhcalib_timeslot: print(headset))
    instance.register_on_new_hid_device(matching.on_new_hid_device)
    while True:
        last_minute_reported = 0
        futures = []
        with ThreadPoolExecutor(max_workers=5) as tp:
            # Every match (and every re-enqueue) becomes a calibration task in this run's pool.
            matching.on_match = lambda headset, lhcalib_timeslot: futures.append(tp.submit(calibrate_headset, headset, lhcalib_timeslot))

            calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.go_position(0.25, 0.25)
            print("Place headsets")
            log_info("Waiting for headsets to be placed")

            # `<` rather than `!=`: several matches can land between two polls,
            # and an exact-equality test would then never be satisfied.
            while len(futures) < 4:
                time.sleep(1)

            config_reload()
            update_window_title()

            time.sleep(DELAY_BEGIN_MOTION)  # give the MCU, lighthouse SIP time to post, and lhcalib time to load
            calib_timer.reset()
            # run start marker; per-headset times are reported relative to it
            with calib_timestamps_lock:
                calib_timestamps.append({"note": "start", "tick": tick})

            while any(not f.done() for f in futures) and not calib_timer.finished:
                with tick_lock:
                    # log once per elapsed minute since the run started
                    if tick - calib_timestamps[0]['tick'] >= (last_minute_reported + 1) * 60:
                        last_minute_reported += 1
                        seconds = (tick - calib_timestamps[0]['tick']) % 60
                        log_info(f"TIME ELAPSED: {last_minute_reported}m{seconds:02d}s")
                if success_count >= COUNTDOWN_THRESHOLD and not calib_timer.activated:
                    log_info(f"Countdown threshold of {COUNTDOWN_THRESHOLD} calibrations reached")
                    # Activate here so the next iteration cannot start a second countdown.
                    # Use a dedicated thread so the countdown never waits for a pool
                    # worker behind calibration tasks (re-enqueues can fill the pool).
                    calib_timer.activate()
                    threading.Thread(target=start_timeout_countdown, daemon=True).start()
                log_info("Waiting for headset calibration to finish")
                calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1800)
                calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1800)
                log_info("Performing motion path")
                perform_random_motion_path(calibot)

            log_info("Wrapping up remaining tasks. PLEASE WAIT...")
            cutoff_tick = tick + WRAP_UP_SECONDS
            # send calibot home to make sure we get more lhcalib output for graceful shutdown
            calibot.module1.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.module2.connection.send(TMCLCommand.SAP, 4, 0, 1600)  # search speed
            calibot.e_stop()
            calibot.go_position(0.25, 0.25)

            if calib_timer.finished:
                while any(not f.done() for f in futures) and tick < cutoff_tick:
                    time.sleep(1)
                if any(not f.done() for f in futures) and tick >= cutoff_tick:
                    # NOTE(review): nothing is force-closed in this function; the
                    # worker tasks are expected to notice calib_timer.finished themselves.
                    log_warning("lhcalib stalled for some devices. Force closing.")

            calib_timer.reset()
            # NOTE: leaving the `with` block still calls shutdown(wait=True), so the
            # run summary below waits for every remaining calibration task.
            tp.shutdown(wait=False)

        if interrupt_count > 0:
            log_info(f"Calibration timed out with {interrupt_count} calibrations incomplete")
        else:
            log_info("Finished headset calibration")

        log_info("HMD calibration times:")
        with calib_timestamps_lock:
            start_tick = calib_timestamps[0]['tick'] if calib_timestamps else 0
            for entry in calib_timestamps[1:]:
                minutes, seconds = divmod(entry['tick'] - start_tick, 60)
                log_info(f"  {entry['note']}: {minutes}m{seconds:02d}s")

        bs_usb_tools.cleanup_beyond_devices(log_info, log_debug)
        reset_run()


if __name__ == '__main__':
    try:
        main()
    except BaseException as e:
        print(e)
        # print(e) alone hides where the failure happened; show the full traceback too.
        import traceback
        traceback.print_exc()
    # Keep the console window open so the error output can be read.
    input("press close to exit")
