From f8be3e74f19df631d22456a0e9ed1048839c25f9 Mon Sep 17 00:00:00 2001 From: MihailRis Date: Sat, 15 Nov 2025 12:46:36 +0300 Subject: [PATCH] add error callback to network.tcp_connect --- res/scripts/classes.lua | 12 ++++++- src/logic/scripting/lua/libs/libnetwork.cpp | 13 ++++++- src/network/Network.cpp | 14 ++++++-- src/network/Network.hpp | 16 +++++++-- src/network/Sockets.cpp | 40 +++++++++++++++------ src/network/commons.hpp | 1 + 6 files changed, 79 insertions(+), 17 deletions(-) diff --git a/res/scripts/classes.lua b/res/scripts/classes.lua index 27a73c97..950544bf 100644 --- a/res/scripts/classes.lua +++ b/res/scripts/classes.lua @@ -70,6 +70,7 @@ local DatagramServerSocket = {__index={ local _tcp_server_callbacks = {} local _tcp_client_callbacks = {} +local _tcp_client_error_callbacks = {} local _udp_server_callbacks = {} local _udp_client_datagram_callbacks = {} @@ -116,10 +117,13 @@ network.tcp_open = function (port, handler) return socket end -network.tcp_connect = function(address, port, callback) +network.tcp_connect = function(address, port, callback, errorCallback) local socket = setmetatable({id=0}, Socket) socket.id = network.__connect_tcp(address, port) _tcp_client_callbacks[socket.id] = function() callback(socket) end + if errorCallback then + _tcp_client_error_callbacks[socket.id] = function(message) errorCallback(socket, message) end + end return socket end @@ -239,6 +243,7 @@ network.__process_events = function() local CONNECTED_TO_SERVER = 2 local DATAGRAM = 3 local RESPONSE = 4 + local CONNECTION_ERROR = 5 local ON_SERVER = 1 local ON_CLIENT = 2 @@ -258,6 +263,11 @@ network.__process_events = function() if callback then callback() end + elseif etype == CONNECTION_ERROR then + local callback = _tcp_client_error_callbacks[cid] + if callback then + callback(addr) + end elseif etype == DATAGRAM then if side == ON_CLIENT then local callback = _udp_client_datagram_callbacks[cid] diff --git a/src/logic/scripting/lua/libs/libnetwork.cpp b/src/logic/scripting/lua/libs/libnetwork.cpp index 78abda63..77fc902e 100644 --- a/src/logic/scripting/lua/libs/libnetwork.cpp +++ b/src/logic/scripting/lua/libs/libnetwork.cpp @@ -13,11 +13,13 @@ enum NetworkEventType { CONNECTED_TO_SERVER, DATAGRAM, RESPONSE, + CONNECTION_ERROR, }; struct ConnectionEventDto { u64id_t server; u64id_t client; + std::string comment {}; }; struct ResponseEventDto { @@ -283,6 +285,11 @@ static int l_connect_tcp(lua::State* L, network::Network& network) { CONNECTED_TO_SERVER, ConnectionEventDto {0, cid} )); + }, [](u64id_t cid, std::string errorMessage) { + push_event(NetworkEvent( + CONNECTION_ERROR, + ConnectionEventDto {0, cid, std::move(errorMessage)} + )); }); return lua::pushinteger(L, id); } @@ -439,7 +446,8 @@ static int l_pull_events(lua::State* L, network::Network& network) { const auto& event = local_queue[i]; switch (event.type) { case CLIENT_CONNECTED: - case CONNECTED_TO_SERVER: { + case CONNECTED_TO_SERVER: + case CONNECTION_ERROR: { const auto& dto = std::get(event.payload); lua::pushinteger(L, event.type); lua::rawseti(L, 1); @@ -449,6 +457,9 @@ static int l_pull_events(lua::State* L, network::Network& network) { lua::pushinteger(L, dto.client); lua::rawseti(L, 3); + + lua::pushlstring(L, dto.comment); + lua::rawseti(L, 4); break; } case DATAGRAM: { diff --git a/src/network/Network.cpp b/src/network/Network.cpp index be3f8d33..ae3a2ea5 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -16,7 +16,10 @@ namespace network { std::unique_ptr create_curl_requests(); std::shared_ptr connect_tcp( - const std::string& address, int port, runnable callback + const std::string& address, + int port, + runnable callback, + stringconsumer errorCallback ); std::shared_ptr open_tcp_server( @@ -87,12 +90,19 @@ Server* Network::getServer(u64id_t id, bool includePrivate) const { return found->second.get(); } -u64id_t Network::connectTcp(const std::string& address, int port, consumer callback) { +u64id_t Network::connectTcp( + const std::string& address, + int port, + consumer callback, + ConnectErrorCallback errorCallback +) { std::lock_guard lock(connectionsMutex); u64id_t id = nextConnection++; auto socket = connect_tcp(address, port, [id, callback]() { callback(id); + }, [id, errorCallback](auto errorMessage) { + errorCallback(id, errorMessage); }); connections[id] = std::move(socket); return id; diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 1110c750..89e27e59 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -7,7 +7,7 @@ namespace network { public: ~TcpConnection() override = default; - virtual void connect(runnable callback) = 0; + virtual void connect(runnable callback, stringconsumer errorCallback) = 0; virtual void setNoDelay(bool noDelay) = 0; [[nodiscard]] virtual bool isNoDelay() const = 0; @@ -88,8 +88,18 @@ namespace network { [[nodiscard]] Connection* getConnection(u64id_t id, bool includePrivate); [[nodiscard]] Server* getServer(u64id_t id, bool includePrivate) const; - u64id_t connectTcp(const std::string& address, int port, consumer callback); - u64id_t connectUdp(const std::string& address, int port, const consumer& callback, ClientDatagramCallback handler); + u64id_t connectTcp( + const std::string& address, + int port, + consumer callback, + ConnectErrorCallback errorCallback + ); + u64id_t connectUdp( + const std::string& address, + int port, + const consumer& callback, + ClientDatagramCallback handler + ); u64id_t openTcpServer(int port, ConnectCallback handler); u64id_t openUdpServer(int port, const ServerDatagramCallback& handler); diff --git a/src/network/Sockets.cpp b/src/network/Sockets.cpp index c1b6285f..7b6f425f 100644 --- a/src/network/Sockets.cpp +++ b/src/network/Sockets.cpp @@ -106,6 +106,7 @@ class SocketTcpConnection : public TcpConnection { std::vector readBatch; util::Buffer buffer; std::mutex mutex; + std::string errorMessage; void connectSocket() { state = ConnectionState::CONNECTING; @@ -115,7 +116,8 @@ class SocketTcpConnection : public TcpConnection { auto error = handle_socket_error("Connect failed"); closesocket(descriptor); state = ConnectionState::CLOSED; - logger.error() << error.what(); + errorMessage = error.what(); + logger.error() << errorMessage; return; } logger.info() << "connected to " << to_string(addr); @@ -182,13 +184,15 @@ public: thread = std::make_unique([this]() { startListen();}); } - void connect(runnable callback) override { - thread = std::make_unique([this, callback]() { + void connect(runnable callback, stringconsumer errorCallback) override { + thread = std::make_unique([this, callback, errorCallback]() { connectSocket(); if (state == ConnectionState::CONNECTED) { callback(); + startListen(); + } else { + errorCallback(errorMessage); } - startListen(); }); } @@ -263,7 +267,10 @@ public: } static std::shared_ptr connect( - const std::string& address, int port, runnable callback + const std::string& address, + int port, + runnable callback, + stringconsumer errorCallback ) { addrinfo hints {}; @@ -274,7 +281,11 @@ public: if (int res = getaddrinfo( address.c_str(), nullptr, &hints, &addrinfo )) { - throw std::runtime_error(gai_strerror(res)); + std::string errorMessage = gai_strerror(res); + if (errorCallback) { + errorCallback(errorMessage); + } + throw std::runtime_error(errorMessage); } sockaddr_in serverAddress; @@ -284,10 +295,14 @@ public: SOCKET descriptor = socket(AF_INET, SOCK_STREAM, 0); if (descriptor == -1) { - throw std::runtime_error("Could not create socket"); + std::string errorMessage = "could not create socket"; + if (errorCallback) { + errorCallback(errorMessage); + } + throw std::runtime_error(errorMessage); } auto socket = std::make_shared(descriptor, std::move(serverAddress)); - socket->connect(std::move(callback)); + socket->connect(std::move(callback), std::move(errorCallback)); return socket; } @@ -670,9 +685,14 @@ public: namespace network { std::shared_ptr connect_tcp( - const std::string& address, int port, runnable callback + const std::string& address, + int port, + runnable callback, + stringconsumer errorCallback ) { - return SocketTcpConnection::connect(address, port, std::move(callback)); + return SocketTcpConnection::connect( + address, port, std::move(callback), std::move(errorCallback) + ); } std::shared_ptr open_tcp_server( diff --git a/src/network/commons.hpp b/src/network/commons.hpp index 150bde25..37fc1a01 100644 --- a/src/network/commons.hpp +++ b/src/network/commons.hpp @@ -13,6 +13,7 @@ namespace network { using OnResponse = std::function)>; using OnReject = std::function)>; using ConnectCallback = std::function; + using ConnectErrorCallback = std::function; using ServerDatagramCallback = std::function; using ClientDatagramCallback = std::function;