"""
This module is used to control the spherical parallel joint device
from Skyentific
"""

from threading import Thread, Lock, Event, main_thread
from dataclasses import dataclass
import time
import numpy as np
import serial


class SpjException(Exception):
    """Base exception type for errors specific to the SPJ controller."""


class NotConnected(SpjException):
    """Raised when the device is used while connect() has not succeeded (or after close())."""


@dataclass
class WorldPosition:
    """Target orientation of the joint as three angles in degrees.

    Field order is (Rz, Ry, Rx), so positional construction such as
    ``WorldPosition(-60, 0, 0)`` sets ``rz``.
    """

    rz: float = 0.0  # rotation about the first axis (Rz), in degrees
    ry: float = 0.0  # rotation about the second axis (Ry), in degrees
    rx: float = 0.0  # rotation about the third axis (Rx), in degrees


# Serial port shared by SphericalParallelJoint and WorkerThread.
# None until SphericalParallelJoint.connect() opens it.
_ser: "serial.Serial | None" = None


class WorkerThread:
    """Worker thread to continuously acquire device position

    Reads lines from the module-level ``_ser`` serial port. Each line that
    contains exactly three comma-separated numbers replaces the cached
    position exposed by :attr:`current_pos`. Any other line is ignored.
    ``_ser`` must be open before :meth:`start` is called.
    """

    def __init__(self):
        """Create the worker in a stopped state"""

        self._stop_event = Event()
        self._thread: "Thread | None" = None
        self._lock: Lock = Lock()
        self._current_pos: list[float] = [0, 0, 0]

    def start(self) -> None:
        """Start thread"""

        self._stop_event.clear()
        self._thread = Thread(target=self.__run)
        self._thread.start()

    def stop(self) -> None:
        """Stop thread and join

        Does nothing beyond setting the stop flag if the thread was never
        started.
        """

        self._stop_event.set()
        if self._thread is not None:
            # The thread only notices the flag after its current blocking
            # serial read returns, so this join can take up to the port timeout.
            self._thread.join()

    @property
    def running(self) -> bool:
        """Thread running

        Return
        ------
        True if the thread has been started and is still alive
        """

        return self._thread is not None and self._thread.is_alive()

    @property
    def current_pos(self) -> list[float]:
        """Get current device position

        Return
        ------
        Copy of the latest parsed position as motor steps [Rz, Ry, Rx]
        ([0, 0, 0] until the first valid line has been received)
        """

        # Wait for the lock instead of giving up: the writer only holds it for a
        # single assignment, and reporting a made-up zero position would
        # mislead callers.
        with self._lock:
            return list(self._current_pos)

    @staticmethod
    def _parse_line(raw: bytes) -> "list[float] | None":
        """Parse one line received from the device

        Parameters
        ----------
        raw: bytes
            Line as returned by ``serial.Serial.read_until``

        Return
        ------
        The three values as floats, or None if the line is empty, not valid
        UTF-8, or not exactly three comma-separated numbers
        """

        try:
            fields = raw.decode("utf-8").strip().split(",")
            values = [float(x) for x in fields]
        except (UnicodeDecodeError, ValueError):
            return None

        if len(values) != 3:
            return None

        return values

    def __run(self):
        """Thread function: publish positions until stopped or the main thread exits"""

        while not self._stop_event.is_set() and main_thread().is_alive():

            new_pos = _ser.read_until()

            if new_pos:
                steps = self._parse_line(new_pos)
                # Malformed lines (partial reads, status text, stray bytes) are
                # skipped; raising here would kill the thread and freeze the
                # reported position for good.
                if steps is not None:
                    with self._lock:
                        self._current_pos = steps

            time.sleep(0.001)


# Background position reader shared by SphericalParallelJoint. It is created
# and started by connect() and stopped by close(). None until the first connect().
_worker_thread: "WorkerThread | None" = None


class SphericalParallelJoint:
    """Class to control the SPJ device

    Target orientations are given as :class:`WorldPosition` angles in degrees.
    :meth:`compute_steps` converts them to three motor step values. These are
    sent to the device as one ``"<step1>,<step2>,<step3>"`` line over the
    module-level serial port.
    """

    @staticmethod
    def home_pos() -> WorldPosition:
        """Get the home position (Rz = Ry = Rx = 0)"""

        return WorldPosition(0, 0, 0)

    @staticmethod
    def store_pos() -> WorldPosition:
        """Get the store position (Rz = -60, Ry = Rx = 0)"""

        return WorldPosition(-60, 0, 0)

    def __init__(self) -> None:
        """Constructor

        Example
        -------
        .. code-block:: python

            robot = SphericalParallelJoint()
            robot.connect()

            print("Move sync")
            if not robot.move_sync(robot.home_pos()):
                print("Move failed")

            print("Move async")
            robot.move_async(robot.store_pos())
            time.sleep(4)

            print(f"Current position: {robot.get_current_position()}")

            robot.close()
        """

        # Unit rotation axes, one per row (index 0 = Rz, 1 = Ry, 2 = Rx).
        self._rot_axis = np.array(
            [
                [-1.0 / np.sqrt(3.0), -1.0 / np.sqrt(3.0), -1.0 / np.sqrt(3.0)],
                [-1.0 / np.sqrt(2.0), 1.0 / np.sqrt(2.0), 0.0],
                # ZYX version: cross product Z x Y of the two rows above.
                # (The ZYZ' version would repeat the first row here, Z' = Z.)
                [1.0 / np.sqrt(6.0), 1.0 / np.sqrt(6.0), -2.0 / np.sqrt(6.0)],
            ]
        )

        # Last known motor position in steps. It must be a step list because
        # __move falls back to it. [0, 0, 0] is (numerically) what
        # compute_steps(store_pos()) yields. Replaced with the device reading
        # on connect().
        self._current_pos: list[float] = [0.0, 0.0, 0.0]
        self._connected = False

    def connect(self, port: str = "/dev/ttyACM0", baud: int = 115200):
        """Connect to the device

        Steps:

        - open the serial port;
        - check that the device answers;
        - flush pending input;
        - start the background position reader.

        If this instance is already connected, the existing connection is
        closed first.

        Parameters
        ----------
        port: str, optional
            COM port for serial communication. For linux /dev/tty<name>, for windows COM<X>
        baud: int, optional
            Serial baudrate

        Raise
        -----
        RuntimeError
            If the device does not answer (the port is closed again)
        serial.SerialException
            If the port cannot be opened
        """

        global _worker_thread
        global _ser

        # Do not leak the previous port and reader thread on reconnect.
        if self._connected:
            self.close()

        _ser = serial.Serial(port, baudrate=baud, timeout=1)

        online = False
        try:
            online = self.check_online()
        finally:
            # Release the port if the handshake failed or raised.
            if not online:
                _ser.close()

        if not online:
            raise RuntimeError("Failed to communicate with device")

        self._connected = True

        # Empty the read buffer
        _ser.read_all()

        # Start the thread
        _worker_thread = WorkerThread()
        _worker_thread.start()
        time.sleep(0.01)
        self._current_pos = self.get_current_position()

    def __check_connected(self):
        """Check if the connect method was called

        Raise
        -----
        NotConnected
        """

        if not self._connected:
            raise NotConnected(
                "Use the connect method to establish a connection to the device"
            )

    def __get_rotational_matrix(self, index: int, angle: float) -> np.ndarray:
        """Compute the rotational matrix for one axis

        Uses Rodrigues' rotation formula for a rotation by ``angle`` about the
        unit axis ``self._rot_axis[index]``.

        Parameters
        ----------
        index: int
            Axis index (Rz = 0, Ry = 1, Rx = 2)
        angle: float
            Axis angle in degree

        Return
        ------
        3x3 rotational matrix for that axis (orthogonal, determinant +1)
        """

        angle = np.deg2rad(angle)
        ux, uy, uz = self._rot_axis[index]
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)
        one_minus_cos = 1.0 - cos_a

        # R = cos(a) * I + sin(a) * [u]x + (1 - cos(a)) * u u^T
        mat = np.array(
            [
                [
                    cos_a + ux**2 * one_minus_cos,
                    ux * uy * one_minus_cos - uz * sin_a,
                    ux * uz * one_minus_cos + uy * sin_a,
                ],
                [
                    uy * ux * one_minus_cos + uz * sin_a,
                    cos_a + uy**2 * one_minus_cos,
                    uy * uz * one_minus_cos - ux * sin_a,
                ],
                [
                    uz * ux * one_minus_cos - uy * sin_a,
                    uz * uy * one_minus_cos + ux * sin_a,
                    cos_a + uz**2 * one_minus_cos,
                ],
            ]
        )

        return mat

    def compute_steps(self, new_pos: WorldPosition) -> list[float]:
        """Compute the steps for each motor based on Rz, Ry and Rx coordinates

        Parameters
        ----------
        new_pos: WorldPosition
            Position to compute the steps for

        Return
        ------
        Steps for each motor
        """

        first_rot_mat = self.__get_rotational_matrix(0, new_pos.rz)
        second_rot_mat = self.__get_rotational_matrix(1, new_pos.ry)
        third_rot_mat = self.__get_rotational_matrix(2, new_pos.rx)

        # rot_mat = third * (second * first): Rz acts first, then Ry, then Rx
        rot_mat = np.matmul(third_rot_mat, np.matmul(second_rot_mat, first_rot_mat))

        # np.arctan rather than np.atan: the np.atan alias only exists in NumPy >= 2.0.
        thetas = (
            np.arctan(rot_mat[2, 1] / rot_mat[1, 1]),
            np.arctan(rot_mat[1, 0] / rot_mat[0, 0]),
            np.arctan(rot_mat[0, 2] / rot_mat[2, 2]),
        )

        # Convert each joint angle, measured from a pi/4 offset, to motor steps.
        # NOTE(review): 108/20 looks like a gear ratio and 200 like full steps
        # per motor revolution. Confirm against the hardware.
        return [
            (108.0 / 20.0) * (theta - np.pi / 4.0) * (200.0 / (2 * np.pi))
            for theta in thetas
        ]

    def __move(self, pos: WorldPosition) -> list[float]:
        """Do move the device

        Parameters
        ----------
        pos: WorldPosition
            Position to move to

        Return
        ------
        End position in steps (the last known position if the write failed)

        Raise
        -----
        NotConnected
        """

        self.__check_connected()

        steps = self.compute_steps(pos)
        written = _ser.write(f"{steps[0]},{steps[1]},{steps[2]}\n".encode("utf-8"))
        # `not written` covers both 0 and a None return
        if not written:
            print("Failed to write")
            steps = self._current_pos

        return [float(steps[0]), float(steps[1]), float(steps[2])]

    def move_sync(
        self, pos: WorldPosition, *, timeout: "float | None" = None
    ) -> bool:
        """Move the device synchronously. This call will return when the movement is finished

        The target steps are truncated to int and polled against the position
        reported by the device until they match.

        Parameters
        ----------
        pos: WorldPosition
            Position to move to
        timeout: float, optional
            Maximum time in seconds to wait for the target to be reached.
            None (default) waits indefinitely.

        Return
        ------
        True if move was successful. False in any of these cases:

        - the target could not be converted to int;
        - 100 unreadable positions were received;
        - the timeout expired.

        Raise
        -----
        NotConnected
        """

        max_errors = 100
        end_pos = self.__move(pos)

        try:
            end_pos = [int(step) for step in end_pos]
        except (ValueError, OverflowError):
            # NaN / inf steps cannot be reached (int() rejects them)
            print(f"Failed to convert end position to int: {end_pos}")
            return False

        deadline = None if timeout is None else time.monotonic() + timeout

        while max_errors > 0:

            if deadline is not None and time.monotonic() >= deadline:
                return False

            try:
                current_pos = [int(step) for step in self.get_current_position()]
            except (ValueError, OverflowError):
                max_errors -= 1
                continue

            if current_pos == end_pos:
                return True

        return False

    def move_async(self, pos: WorldPosition) -> list[float]:
        """Move the robot asynchronously. Return before movement ends.

        Parameters
        ----------
        pos: WorldPosition
            Position to move to

        Return
        ------
        End position in steps

        Raise
        -----
        NotConnected
        """

        return self.__move(pos)

    def check_online(self) -> bool:
        """Check if the device is online and working

        Reads one line from the port. If it is "Wait", a newline is sent and
        "Ready" must come back within 5 s. Any other non-empty line also
        counts as online.

        Return
        ------
        True if online

        Raise
        -----
        NotConnected
            If no serial port has been opened yet
        """

        if _ser is None:
            raise NotConnected(
                "Use the connect method to establish a connection to the device"
            )

        # errors="replace": stray non-UTF-8 bytes must not crash the check
        msg = _ser.read_until().decode("utf-8", errors="replace").strip()

        if msg == "Wait":
            _ser.timeout = 5
            try:
                _ser.write("\n".encode("utf-8"))
                msg = _ser.read_until().decode("utf-8", errors="replace").strip()
            finally:
                # Restore the normal timeout even if the exchange raised
                _ser.timeout = 1
            return msg == "Ready"

        return len(msg) > 0

    def get_current_position(self) -> list[float]:
        """Get the current position of the motors in steps

        Return
        ------
        A list in the format (stepZ, stepY, stepX)

        Raise
        -----
        NotConnected
        """

        self.__check_connected()

        # The worker thread publishes each reading by rebinding its list
        # (never mutating it in place), so reading the reference directly
        # always yields a complete reading.
        self._current_pos = _worker_thread._current_pos

        # Brief pause so polling loops such as move_sync() do not busy-spin
        time.sleep(0.001)

        return self._current_pos

    def close(self):
        """End communication with the device

        Stops the background reader and closes the serial port. Does nothing
        when not connected.
        """

        if not self._connected:
            return

        try:
            if _worker_thread is not None and _worker_thread.running:
                _worker_thread.stop()
        finally:
            _ser.close()
            self._connected = False

    def __del__(self):
        """Destructor: close the connection if it is still open"""

        # getattr: __del__ can run on an instance whose __init__ did not finish
        if getattr(self, "_connected", False):
            self.close()


if __name__ == "__main__":

    import os

    # abspath guarantees a non-empty directory name for chdir
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    robot = SphericalParallelJoint()

    try:
        robot.connect()

        print("Move sync")
        if not robot.move_sync(SphericalParallelJoint.home_pos()):
            print("Move failed")

        print("Move async")
        robot.move_async(SphericalParallelJoint.store_pos())
        time.sleep(4)

        print(f"Current position: {robot.get_current_position()}")
    finally:
        # Always stop the reader thread and release the port, even on errors
        # or Ctrl+C. close() is a no-op if connect() failed.
        robot.close()
