import numpy as np
import serial
from dataclasses import dataclass
import time

@dataclass
class WorldPosition:
    """Target orientation of the spherical joint as three rotation angles.

    All three angles are in degrees (they are converted with ``np.deg2rad``
    before use). They are applied in the order ``rz``, then ``ry``, then
    ``rz1``; ``rz`` and ``rz1`` rotate about the same axis.
    """

    rz: float = 0.0   # first rotation angle
    ry: float = 0.0   # second rotation angle
    rz1: float = 0.0  # third rotation angle (same axis as rz)

class SphericalParallelJoint:
    """Serial-port driver for a three-motor spherical parallel joint.

    A target orientation (``WorldPosition``) is converted into a rotation
    matrix by composing three axis-angle rotations. Three joint angles are
    read off that matrix and converted into motor step counts, which are
    sent to the device as one comma-separated text line.
    """

    # Motor-to-joint ratio; presumably a 108:20 gear/pulley reduction — TODO confirm.
    GEAR_RATIO: float = 108.0 / 20.0
    # Motor steps per revolution; presumably a 200 full-step stepper — TODO confirm.
    STEPS_PER_REV: float = 200.0
    # Joint angle (radians) that maps to step count 0.
    ANGLE_OFFSET: float = np.pi / 4.0

    def __init__(self, port: str = "/dev/ttyACM0", baud: int = 115200) -> None:
        """Open the serial port and verify that the device answers.

        Args:
            port: Serial device path.
            baud: Baud rate for the serial link.

        Raises:
            RuntimeError: If the start-up handshake (``check_online``) fails.
                In that case, and if the handshake itself raises, the serial
                port is closed before the exception propagates.
        """
        self._ser = serial.Serial(port, baudrate=baud, timeout=1)
        # Last orientation commanded through move(); starts at all-zero angles.
        self._curr_position = WorldPosition()
        # Unit rotation axes, one per WorldPosition angle: row 0 for rz,
        # row 1 for ry, row 2 for rz1. Rows 0 and 2 are the same axis.
        self._rot_axis = np.array([
            [ -1.0 / np.sqrt(3.0), -1.0 / np.sqrt(3.0), - 1.0 / np.sqrt(3.0) ],
            [ -1.0 / np.sqrt(2.0), 1.0 / np.sqrt(2.0), 0.0 ],
            [ -1.0 / np.sqrt(3.0), -1.0 / np.sqrt(3.0), - 1.0 / np.sqrt(3.0) ]
        ])

        try:
            if not self.check_online():
                raise RuntimeError("Failed to communicate with device")
        except BaseException:
            # Do not leak an open port when construction fails.
            self._ser.close()
            raise

    def __get_rotational_matrix(self, index: int, angle: float) -> np.ndarray:
        """Return the 3x3 matrix rotating by ``angle`` degrees about axis ``index``.

        Implements Rodrigues' rotation formula for the unit axis
        ``self._rot_axis[index]``.
        """
        theta = np.deg2rad(angle)
        c = np.cos(theta)
        s = np.sin(theta)
        t = 1.0 - c
        ux, uy, uz = self._rot_axis[index]

        return np.array([
            [c + ux ** 2 * t,      ux * uy * t - uz * s, ux * uz * t + uy * s],
            [uy * ux * t + uz * s, c + uy ** 2 * t,      uy * uz * t - ux * s],
            [uz * ux * t - uy * s, uz * uy * t + ux * s, c + uz ** 2 * t],
        ])

    def compute_steps(self, new_pos: "WorldPosition") -> list[float]:
        """Convert a target orientation into one step count per motor.

        Args:
            new_pos: Target angles in degrees (``rz``, ``ry``, ``rz1``).

        Returns:
            Three step counts, unrounded (NumPy float64 values).
        """
        first_rot_mat = self.__get_rotational_matrix(0, new_pos.rz)
        second_rot_mat = self.__get_rotational_matrix(1, new_pos.ry)
        third_rot_mat = self.__get_rotational_matrix(2, new_pos.rz1)

        # rz is applied first, then ry, then rz1: R = R3 @ (R2 @ R1).
        rot_mat = np.matmul(third_rot_mat, np.matmul(second_rot_mat, first_rot_mat))

        # np.arctan (portable across NumPy 1.x/2.x, unlike np.atan) of a plain
        # ratio folds each angle into (-pi/2, pi/2); a zero denominator gives
        # +/-inf and hence +/-pi/2 (with a NumPy RuntimeWarning).
        thetas = (
            np.arctan(rot_mat[2, 1] / rot_mat[1, 1]),
            np.arctan(rot_mat[1, 0] / rot_mat[0, 0]),
            np.arctan(rot_mat[0, 2] / rot_mat[2, 2]),
        )

        steps_per_radian = self.STEPS_PER_REV / (2 * np.pi)
        return [
            self.GEAR_RATIO * (theta - self.ANGLE_OFFSET) * steps_per_radian
            for theta in thetas
        ]

    def move(self, pos: "WorldPosition") -> list[float]:
        """Send the step counts for ``pos`` to the device.

        The command is written as ``"s1,s2,s3\\n"`` (UTF-8) and ``pos`` is
        remembered as the current commanded position.

        Returns:
            The three step counts that were sent, as Python floats.
        """
        steps = self.compute_steps(pos)
        self._ser.write(f"{steps[0]},{steps[1]},{steps[2]}\n".encode("utf-8"))
        self._curr_position = pos

        return [float(step) for step in steps]

    def check_online(self) -> bool:
        """Perform the start-up handshake with the device.

        Reads one line. If it is ``"Wait"``, a newline is sent and the device
        has up to 5 seconds to answer ``"Ready"``. Any other non-empty line
        counts as online; an empty read (e.g. a timeout) counts as offline.
        The serial timeout is restored afterwards even if I/O raises.
        """
        msg = self._ser.read_until().decode("utf-8").strip()

        if msg == "Wait":
            previous_timeout = self._ser.timeout
            self._ser.timeout = 5
            try:
                self._ser.write("\n".encode("utf-8"))
                msg = self._ser.read_until().decode("utf-8").strip()
            finally:
                self._ser.timeout = previous_timeout
            return msg == "Ready"

        return len(msg) > 0

    def get_current_position(self) -> list[float]:
        """Read one line from the device and parse it as comma-separated floats.

        Raises:
            ValueError: If nothing is received before the read returns (e.g.
                timeout) or a field is not a number.
        """
        line = self._ser.read_until().decode("utf-8").strip()
        if not line:
            raise ValueError("No position report received from device")
        return [float(field) for field in line.split(",")]
        
        

    