#ifndef UNICODE
#define UNICODE
#endif
#ifndef _UNICODE
#define _UNICODE
#endif

#include <string>
#include <iostream>
#include <map>

extern "C" {
    #include <Windows.h>
    #include <setupapi.h>
    #include <devguid.h>
    #include <regstr.h>
    #include <tchar.h>
}

#include "mcp2221a.h"
#include "mcp2221_dll_um.h"

#pragma comment(lib, "setupapi.lib")

#define MCP2221A_VID       "04D8"
#define MCP2221A_PID       "00DD"
#define READ_TIMEOUT        1000 // in milliseconds

namespace MCP2221A {
    static std::map<Serial::BAUD, int> baud_map = {
        {Serial::B_300, CBR_300},
        {Serial::B_1200, CBR_1200},
        {Serial::B_2400, CBR_2400},
        {Serial::B_4800, CBR_4800},
        {Serial::B_9600, CBR_9600},
        {Serial::B_19200, CBR_19200},
        {Serial::B_38400, CBR_38400},
        {Serial::B_57600, CBR_57600},
        {Serial::B_115200, CBR_115200}
    };

    static void PrintCommState(DCB dcb)
    {
        //  Print some of the DCB structure values
        _tprintf( TEXT("\nBaudRate = %d, ByteSize = %d, Parity = %d, StopBits = %d\n"), 
              dcb.BaudRate, 
              dcb.ByteSize, 
              dcb.Parity,
              dcb.StopBits );
    }

    static bool check_com_port_vid_pid(const std::wstring& targetPort, const std::string& targetVid = MCP2221A_VID, const std::string& targetPid = MCP2221A_PID)
    {
        bool ret = false;

        HDEVINFO deviceInfoSet = SetupDiGetClassDevs(
            (const GUID*)&GUID_DEVCLASS_PORTS,
            NULL,
            NULL,
            DIGCF_PRESENT
        );

        if (deviceInfoSet == INVALID_HANDLE_VALUE) {
            std::cerr << "Failed to get device info set." << std::endl;
            return false;
        }

        SP_DEVINFO_DATA deviceInfoData;
        deviceInfoData.cbSize = sizeof(SP_DEVINFO_DATA);

        for (DWORD i = 0; SetupDiEnumDeviceInfo(deviceInfoSet, i, &deviceInfoData); ++i) {
            TCHAR deviceInstanceId[512];
            if (SetupDiGetDeviceInstanceId(deviceInfoSet, &deviceInfoData, deviceInstanceId, sizeof(deviceInstanceId) / sizeof(TCHAR), NULL)) {
                // Example: "USB\\VID_2341&PID_0043\\55639313337351D011E1"
                std::basic_string<TCHAR> instance(deviceInstanceId);

                // Retrieve COM port name
                HKEY hDeviceRegistryKey = SetupDiOpenDevRegKey(
                    deviceInfoSet,
                    &deviceInfoData,
                    DICS_FLAG_GLOBAL,
                    0,
                    DIREG_DEV,
                    KEY_READ
                );

                if (hDeviceRegistryKey == INVALID_HANDLE_VALUE)
                    continue;

                TCHAR portName[256];
                DWORD size = sizeof(portName);
                DWORD type = 0;

                if (RegQueryValueEx(hDeviceRegistryKey, _T("PortName"), NULL, &type, (LPBYTE)portName, &size) == ERROR_SUCCESS) {
                    std::basic_string<TCHAR> tsPortName(portName);
                    std::wstring comPort(tsPortName.begin(), tsPortName.end());

                    if (comPort == targetPort) {
                        // std::wcout << L"Found port: " << tsPortName.c_str() << std::endl;
                        // std::wcout << L"Instance ID: " << instance << std::endl;

                        // Extract VID and PID
                        size_t vidPos = instance.find(_T("VID_"));
                        size_t pidPos = instance.find(_T("PID_"));

                        if (vidPos != std::basic_string<TCHAR>::npos && pidPos != std::basic_string<TCHAR>::npos) {
                            std::basic_string<TCHAR> vid = instance.substr(vidPos + 4, 4);
                            std::basic_string<TCHAR> pid = instance.substr(pidPos + 4, 4);
                            std::string strVid(vid.begin(), vid.end());
                            std::string strPid(pid.begin(), pid.end());

                            if (strVid == targetVid && strPid == targetPid) {
                                ret = true;
                                RegCloseKey(hDeviceRegistryKey);
                                break;
                            }
                        }
                    }
                }

                RegCloseKey(hDeviceRegistryKey);
            }
        }

        SetupDiDestroyDeviceInfoList(deviceInfoSet);

        return ret;
    }

    Serial::Serial(std::wstring port, BAUD baudrate): handle_(NULL), port_(port), baudrate_(baudrate), read_timeout_(READ_TIMEOUT)
    {
        DCB dcb;
        bool ret;

        TCHAR *port_name = new TCHAR[port.length() + 8]; // +8 for "\\\\.\\"
        if (!port_name) {
            throw SerialException("Failed to allocate memory for port name");
        }

        if (_stprintf(port_name, TEXT("\\\\.\\%s"), port.c_str()) < 0) {
            delete[] port_name;
            throw SerialException("Failed to format port name");
        }

        if (!check_com_port_vid_pid(port_)) {
            delete[] port_name;
            std::string narrow_port(port_.begin(), port_.end());
            throw SerialException(MCP2221_NAME " not found or not connected on port " + narrow_port);
        }

        handle_ = CreateFile(port_name,
                             GENERIC_READ | GENERIC_WRITE,
                             0,
                             NULL,
                             OPEN_EXISTING,
                             FILE_ATTRIBUTE_NORMAL,
                             NULL);

        if (handle_ == INVALID_HANDLE_VALUE) {
            delete[] port_name;
            std::string narrow_port(port_.begin(), port_.end());
            throw SerialException("Failed to open serial port " + narrow_port);
        }

        SecureZeroMemory(&dcb, sizeof(DCB));
        dcb.DCBlength = sizeof(DCB);

        ret = GetCommState(handle_, &dcb);
        if (!ret) {
            CloseHandle(handle_);
            delete[] port_name;
            throw SerialException("Failed to get serial port state");
        }
        
        dcb.BaudRate = baud_map[baudrate_];
        dcb.ByteSize = 8;
        dcb.Parity   = NOPARITY;
        dcb.StopBits = ONESTOPBIT;
        ret = SetCommState(handle_, &dcb);
        if (!ret) {
            CloseHandle(handle_);
            delete[] port_name;
            throw SerialException("Failed to set serial port state");
        }

        set_timeout(read_timeout_);
    }   

    // Implementations of pure virtual destructor
    Serial::~Serial() 
    {
        delete[] port_.c_str();
        if (handle_ != INVALID_HANDLE_VALUE) {
            CloseHandle(handle_);
        }
    }

    void Serial::send(uint8_t *data, size_t length) 
    {
        DWORD bytes_written = 0;
        bool ret;

        ret = WriteFile(handle_, data, length, &bytes_written, NULL);
        if (!ret) {
            throw SerialException("Failed to send data");
        }
    }

    void Serial::receive(uint8_t *data, size_t length) 
    {
        DWORD bytes_read = 0;
        bool ret;

        ret = ReadFile(handle_, data, length, &bytes_read, NULL);
        if (!ret) {
            throw SerialException("Failed to receive data");
        }

        if (bytes_read < length) {
            throw TimeoutException("Timeout expired and " + std::to_string(bytes_read) + "/" + std::to_string(length) + " bytes were received");
        }
    }

    void Serial::set_baudrate(BAUD baudrate)
    {
        DCB dcb;
        bool ret;

        if (baudrate_ == baudrate)
            return;
        
        SecureZeroMemory(&dcb, sizeof(DCB));
        dcb.DCBlength = sizeof(DCB);

        ret = GetCommState(handle_, &dcb);
        if (!ret) {
            throw SerialException("Failed to get serial port state");
        }

        dcb.BaudRate = baud_map[baudrate];

        ret = SetCommState(handle_, &dcb);
        if (!ret) {
            throw SerialException("Failed to set serial port state");
        }

        baudrate_ = baudrate;
    }

    void Serial::set_timeout(DWORD timeout_ms)
    {
        if (read_timeout_ == timeout_ms)
            return;

        COMMTIMEOUTS timeouts;
        SecureZeroMemory(&timeouts, sizeof(COMMTIMEOUTS));
        
        if (!GetCommTimeouts(handle_, &timeouts)) {
            throw SerialException("Failed to get serial port timeouts");
        }

        timeouts.ReadTotalTimeoutConstant = timeout_ms;

        if (!SetCommTimeouts(handle_, &timeouts)) {
            throw SerialException("Failed to set serial port timeouts");
        }

        read_timeout_ = timeout_ms;
    }
}
