import time
from pytrinamic.modules import TMCM1161
from pytrinamic.tmcl import TMCLCommand
from pytrinamic.connections import ConnectionManager
from typing import Optional
from gui_log import *
from serial import SerialException
import threading

# using USB interface
# Calibot designed for use with TMCM1161 via USB connection


class Calibot:
    """Two-axis positioner driven by a pair of TMCM1161 modules over USB serial.

    The two modules are told apart by the value stored in global parameter 1
    (bank 2), which is expected to be 159 or 160. After construction,
    ``module1`` is the module reporting 159.

    Targets passed to :meth:`go_position` are in turns. They are converted to
    absolute microsteps as ``(turns + offset) * wormratio * 2**mrs * 200``.
    """

    def __init__(self, o1, o2):
        """Find, connect to and configure both axis modules.

        :param o1: offset (in turns) added to every axis-1 target in go_position.
        :param o2: offset (in turns) added to every axis-2 target in go_position.
        :raises RuntimeError: if axis 159 or 160 is not found on any COM port.
        """
        self.wormratio = 180  # reduction factor used in the turns -> steps conversion
        # Microstep resolution index written to axis parameter 140. The step
        # conversion below uses 2 ** mrs microsteps per full step (5 -> 32).
        self.mrs = 5
        self.__offset1 = o1
        self.__offset2 = o2

        log_debug(f"Calibot initialized with offsets {self.__offset1} & {self.__offset2}")
        id1 = find_interface(159)
        id2 = find_interface(160)
        log_debug(f"Found interfaces {id1} & {id2}")
        if id1 is None:
            log_error("Failed to connect to axis 159")
            raise RuntimeError("Failed to connect to axis 159")
        if id2 is None:
            log_error("Failed to connect to axis 160")
            raise RuntimeError("Failed to connect to axis 160")

        log_debug("Connecting to axis 159")
        self.connection_manager1 = ConnectionManager(f"--interface serial_tmcl --port COM{id1} --data-rate 115200")
        log_debug("Connecting to axis 160")
        self.connection_manager2 = ConnectionManager(f"--interface serial_tmcl --port COM{id2} --data-rate 115200")
        log_debug("Connecting to module 1")
        self.my_interface1 = self.connection_manager1.connect()
        log_debug("Connecting to module 2")
        self.my_interface2 = self.connection_manager2.connect()

        log_debug("Initializing module 1")
        self.module1 = TMCM1161(self.my_interface1)  # light axis --- limit travel {0 to 0.75}
        log_debug("Initializing module 2")
        self.module2 = TMCM1161(self.my_interface2)  # heavy axis --- limit travel {0 to 0.5}

        # Global parameter 1 / bank 2 carries the axis ID (find_interface matches on it).
        mod1val = self.module1.connection.send(TMCLCommand.GGP, 1, 2, 0).value
        mod2val = self.module2.connection.send(TMCLCommand.GGP, 1, 2, 0).value
        log_debug(f"Module 1 value: {mod1val}")
        log_debug(f"Module 2 value: {mod2val}")
        if mod1val != 159:
            # Keep module1 bound to axis 159 whatever the port order turned out to be.
            log_debug("Swapping modules")
            self.module1, self.module2 = self.module2, self.module1

        log_debug("Configuring module 1")
        self._configure_module(self.module1, acceleration=50)
        log_debug("Configuring module 2")
        self._configure_module(self.module2, acceleration=20)

    def _configure_module(self, module, acceleration):
        """Send the start-up motion/IO configuration to one module."""
        send = module.connection.send
        send(TMCLCommand.SAP, 5, 0, acceleration)  # acceleration
        send(TMCLCommand.SAP, 6, 0, 80)  # current
        send(TMCLCommand.SAP, 4, 0, 1600)  # speed max
        send(TMCLCommand.SAP, 140, 0, self.mrs)  # microstep resolution index (2**mrs microsteps)
        send(TMCLCommand.SAP, 153, 0, 7)  # ramp divisor
        send(TMCLCommand.SAP, 154, 0, 3)  # pulse divisor
        send(TMCLCommand.SIO, 0, 0, 1)  # pull-up resistor
        send(TMCLCommand.RSGP, 1, 2, 0)  # restore global user variable 1

    @staticmethod
    def _configure_reference_search(module):
        """Send the reference-search (homing) configuration to one module."""
        send = module.connection.send
        # NOTE(review): TMCL firmware manuals usually name axis parameters 12/13
        # "right/left limit switch disable", in which case the value 1 disables
        # the switch rather than activating it -- confirm against the TMCM-1161 manual.
        send(TMCLCommand.SAP, 12, 0, 1)  # right limit switch
        send(TMCLCommand.SAP, 13, 0, 1)  # left limit switch
        send(TMCLCommand.SAP, 149, 0, 1)  # soft stop
        send(TMCLCommand.SAP, 193, 0, 4)  # search mode
        send(TMCLCommand.SAP, 194, 0, 900)  # search speed
        send(TMCLCommand.SAP, 195, 0, 100)  # search switch speed
        send(TMCLCommand.SIO, 0, 2, 1)  # OUT_0 HIGH

    def _turns_to_steps(self, turns):
        """Convert a position in turns to an absolute microstep count.

        The result is truncated toward zero. The factor 200 is presumably the
        motor's full steps per revolution (1.8 deg motor) -- confirm.
        """
        return int(turns * self.wormratio * (2 ** self.mrs) * 200)

    def go_home(self, timeout=60):
        """Run a reference search on both axes, then zero their positions.

        Polls the reference-search status (RFS type 2) of both modules every
        0.5 s until both report 0. If that has not happened within *timeout*
        seconds, both motors are stopped (MST) and an error is logged.

        :param timeout: seconds to wait for both reference searches.
        :return: None
        """
        for module in (self.module1, self.module2):
            self._configure_reference_search(module)

        self.module1.connection.send(TMCLCommand.RFS, 0, 0, 0)  # start reference search
        self.module2.connection.send(TMCLCommand.RFS, 0, 0, 0)  # start reference search
        deadline = time.time() + timeout

        while (self.module1.connection.send(TMCLCommand.RFS, 2, 0, 0).value != 0
               or self.module2.connection.send(TMCLCommand.RFS, 2, 0, 0).value != 0):
            if time.time() > deadline:
                log_error("Homing timed out; stopping motors")
                self.module1.connection.send(TMCLCommand.MST, 0, 0, 0)  # stop motor when timed out
                self.module2.connection.send(TMCLCommand.MST, 0, 0, 0)  # stop motor when timed out
                break
            time.sleep(0.5)
        time.sleep(2)
        # NOTE(review): positions are zeroed even after a timeout, so a failed
        # homing leaves a false origin for subsequent go_position calls.
        self.module1.connection.send(TMCLCommand.SAP, 1, 0, 0)  # reset current position to 0
        self.module1.connection.send(TMCLCommand.SAP, 0, 0, 0)  # reset target position to 0
        self.module2.connection.send(TMCLCommand.SAP, 1, 0, 0)  # reset current position to 0
        self.module2.connection.send(TMCLCommand.SAP, 0, 0, 0)  # reset target position to 0
        time.sleep(0.5)
        log_debug("Homing complete")
        return None

    def go_position(self, turn1, turn2, timeout=30):
        """Move both axes to absolute positions given in turns.

        Targets must satisfy ``-0.05 < turn1 < 0.8`` and
        ``-0.05 < turn2 < 0.55``. Otherwise the request is logged and nothing
        moves. The per-axis offset given to the constructor is added before
        converting to microsteps.

        The method polls "target position reached" (GAP 8) on both modules every
        0.1 s. If both have not arrived within *timeout* seconds, each target
        position is overwritten with the motor's current actual position and an
        error is logged.

        :param turn1: axis-1 target in turns (anything accepted by ``float``).
        :param turn2: axis-2 target in turns (anything accepted by ``float``).
        :param timeout: seconds to wait for both moves to complete.
        :return: None
        """
        # The original discarded these conversions; keep the converted values.
        turn1 = float(turn1)
        turn2 = float(turn2)
        if not (-0.05 < turn1 < 0.8 and -0.05 < turn2 < 0.55):
            log_error(f"Requested position out of range: turn1={turn1}, turn2={turn2}")
            return None

        motor1 = self.module1.motors[0]
        motor2 = self.module2.motors[0]
        steps1 = self._turns_to_steps(turn1 + self.__offset1)
        steps2 = self._turns_to_steps(turn2 + self.__offset2)

        self.module1.connection.send(TMCLCommand.MVP, 0, 0, steps1)
        self.module2.connection.send(TMCLCommand.MVP, 0, 0, steps2)
        deadline = time.time() + timeout

        while (self.module1.connection.send(TMCLCommand.GAP, 8, 0, 0).value == 0
               or self.module2.connection.send(TMCLCommand.GAP, 8, 0, 0).value == 0):
            if time.time() > deadline:
                log_error("Move timed out; holding current positions")
                # Retarget each axis to where it currently is to end the move.
                self.module1.connection.send(TMCLCommand.SAP, 0, 0, motor1.actual_position)
                self.module2.connection.send(TMCLCommand.SAP, 0, 0, motor2.actual_position)
                break
            time.sleep(0.1)
        time.sleep(0.1)
        return None

    def e_stop(self):
        """Immediately send the motor-stop command (MST) to both modules."""
        self.module1.connection.send(TMCLCommand.MST, 0, 0, 0)  # stop motor
        self.module2.connection.send(TMCLCommand.MST, 0, 0, 0)  # stop motor
        return None


def find_interface(axis_id, ports=range(100), timeout=0.2) -> Optional[int]:
    """Return the COM port number of the TMCM1161 module reporting *axis_id*.

    Each port ``COM{n}`` in *ports* is probed in turn, in a daemon thread. The
    probe opens the port, reads global parameter 1 (bank 2) and compares it
    with *axis_id*. A probe that is still running after *timeout* seconds is
    abandoned, and the scan moves on to the next port.

    NOTE(review): an abandoned probe thread may still hold its port open, so a
    module sitting on a slow port can be missed and its port left busy.

    :param axis_id: value expected in global parameter 1 / bank 2 (e.g. 159).
    :param ports: COM port numbers to try, in order (default 0-99).
    :param timeout: seconds allowed per port before giving up on it.
    :return: the matching port number, or None if no port matched.
    """

    def attempt_connection(port, result_holder):
        """Probe one port; store the port in result_holder["port"] on a match."""
        deadline = time.time() + timeout
        try:
            connection_manager = ConnectionManager(
                f"--interface serial_tmcl --port COM{port} --data-rate 115200"
            )
            with connection_manager.connect() as module:
                # If opening the port already used up the budget, the caller has
                # stopped waiting for this thread, so skip the query.
                if time.time() >= deadline:
                    return
                try:
                    mod_val = (
                        TMCM1161(module)
                        .connection.send(TMCLCommand.GGP, 1, 2, 0)
                        .value
                    )
                except RuntimeError as e:
                    print(f"TMCL error on COM{port}: {e}")
                    return
                except SerialException as e:
                    print(f"Serial communication error on COM{port}: {e}")
                    return
                except AttributeError:
                    print(f"Serial port issue on COM{port}. Skipping.")
                    return
                if mod_val == axis_id:
                    print("found axis", mod_val, "on COM", port)
                    result_holder["port"] = port
                else:
                    print(
                        "axis",
                        axis_id,
                        "status: found module unassociated with axis on COM",
                        port,
                        ". axis number:",
                        mod_val,
                    )
        except ConnectionError:
            print("port COM", port, " attempted")
        except TimeoutError:
            print("Timeout on COM", port)
        except SerialException as e:
            print(f"SerialException on COM{port}: {e}")
        except AttributeError:
            print(f"COM{port} may be in an invalid state. Skipping.")

    for port in ports:
        result_holder = {"port": None}
        thread = threading.Thread(
            target=attempt_connection, args=(port, result_holder), daemon=True
        )
        thread.start()
        thread.join(timeout=timeout)  # main timeout control
        if thread.is_alive():
            print("Connection attempt timed out on COM", port)
            continue
        if result_holder["port"] is not None:
            return result_holder["port"]
    return None