diff --git a/doc/ru/scripting/builtins/libnetwork.md b/doc/ru/scripting/builtins/libnetwork.md index 0d33acf1..c9cc890c 100644 --- a/doc/ru/scripting/builtins/libnetwork.md +++ b/doc/ru/scripting/builtins/libnetwork.md @@ -2,7 +2,7 @@ Библиотека для работы с сетью. -## HTTP-запросы +## HTTP-Запросы ```lua -- Выполняет GET запрос к указанному URL. @@ -98,6 +98,65 @@ server:is_open() --> bool server:get_port() --> int ``` +## UDP-Датаграммы + +```lua +network.udp_connect( + address: str, + port: int, + -- Функция, вызываемая при получении датаграммы с указанного при открытии сокета адреса и порта + datagramHandler: function(Bytearray), + -- Функция, вызываемая после открытия сокета + -- Опциональна, так как в UDP нет handshake + [опционально] openCallback: function(WriteableSocket), +) --> WriteableSocket +``` + +Открывает UDP-сокет с привязкой к удалённому адресу и порту + +Класс WriteableSocket имеет следующие методы: + +```lua +-- Отправляет датаграмму на адрес и порт, заданные при открытии сокета +socket:send(table|Bytearray|str) + +-- Закрывает сокет +socket:close() + +-- Проверяет открыт ли сокет +socket:is_open() --> bool + +-- Возвращает адрес и порт, на которые привязан сокет +socket:get_address() --> str, int +``` + +```lua +network.udp_open( + port: int, + -- Функция, вызываемая при получении датаграмы + -- В параметры передаётся адрес и порт отправителя, а также сами данные + datagramHandler: function(address: str, port: int, data: Bytearray, server: DatagramServerSocket) +) --> DatagramServerSocket +``` + +Открывает UDP-сервер на указанном порту + +Класс DatagramServerSocket имеет следующие методы: + +```lua +-- Отправляет датаграмму на переданный адрес и порт +server:send(address: str, port: int, data: table|Bytearray|str) + +-- Завершает принятие датаграмм +server:stop() + +-- Проверяет возможность принятия датаграмм +server:is_open() --> bool + +-- Возвращает порт, который слушает сервер +server:get_port() --> int +``` + ## Аналитика ```lua diff --git a/res/scripts/classes.lua b/res/scripts/classes.lua index 41406892..e63d5a5a 100644 --- a/res/scripts/classes.lua +++ b/res/scripts/classes.lua @@ -46,18 +46,35 @@ local Socket = {__index={ get_address=function(self) return network.__get_address(self.id) end, }} +local WriteableSocket = {__index={ + send=function(self, ...) return network.__send(self.id, ...) end, + close=function(self) return network.__close(self.id) end, + is_open=function(self) return network.__is_alive(self.id) end, + get_address=function(self) return network.__get_address(self.id) end, +}} + local ServerSocket = {__index={ close=function(self) return network.__closeserver(self.id) end, is_open=function(self) return network.__is_serveropen(self.id) end, get_port=function(self) return network.__get_serverport(self.id) end, }} +local DatagramServerSocket = {__index={ + close=function(self) return network.__closeserver(self.id) end, + is_open=function(self) return network.__is_serveropen(self.id) end, + get_port=function(self) return network.__get_serverport(self.id) end, + send=function(self, ...) return network.__udp_server_send_to(self.id, ...) end +}} local _tcp_server_callbacks = {} local _tcp_client_callbacks = {} +local _udp_server_callbacks = {} +local _udp_client_datagram_callbacks = {} +local _udp_client_open_callbacks = {} + network.tcp_open = function (port, handler) - local socket = setmetatable({id=network.__open(port)}, ServerSocket) + local socket = setmetatable({id=network.__open_tcp(port)}, ServerSocket) _tcp_server_callbacks[socket.id] = function(id) handler(setmetatable({id=id}, Socket)) @@ -67,19 +84,63 @@ end network.tcp_connect = function(address, port, callback) local socket = setmetatable({id=0}, Socket) - socket.id = network.__connect(address, port) + socket.id = network.__connect_tcp(address, port) _tcp_client_callbacks[socket.id] = function() callback(socket) end return socket end +network.udp_open = function (port, datagramHandler) + if type(datagramHandler) ~= 'function' then + error "udp server cannot be opened without datagram handler" + end + + local socket = setmetatable({id=network.__open_udp(port)}, DatagramServerSocket) + + _udp_server_callbacks[socket.id] = function(address, port, data) + datagramHandler(address, port, data, socket) + end + + return socket +end + +network.udp_connect = function (address, port, datagramHandler, openCallback) + if type(datagramHandler) ~= 'function' then + error "udp client socket cannot be opened without datagram handler" + end + + local socket = setmetatable({id=0}, WriteableSocket) + socket.id = network.__connect_udp(address, port) + + _udp_client_datagram_callbacks[socket.id] = datagramHandler + _udp_client_open_callbacks[socket.id] = openCallback + + return socket +end + +local function clean(iterable, checkFun, ...) + local tables = { ... } + + for id, _ in pairs(iterable) do + if not checkFun(id) then + for i = 1, #tables do + tables[i][id] = nil + end + end + end +end + network.__process_events = function() local CLIENT_CONNECTED = 1 local CONNECTED_TO_SERVER = 2 + local DATAGRAM = 3 + + local ON_SERVER = 1 + local ON_CLIENT = 2 local cleaned = false local events = network.__pull_events() for i, event in ipairs(events) do - local etype, sid, cid = unpack(event) + local etype, sid, cid, addr, port, side, data = unpack(event) if etype == CLIENT_CONNECTED then local callback = _tcp_server_callbacks[sid] @@ -87,24 +148,26 @@ network.__process_events = function() callback(cid) end elseif etype == CONNECTED_TO_SERVER then - local callback = _tcp_client_callbacks[cid] + local callback = _tcp_client_callbacks[cid] or _udp_client_open_callbacks[cid] if callback then callback() end + elseif etype == DATAGRAM then + if side == ON_CLIENT then + _udp_client_datagram_callbacks[cid](data) + elseif side == ON_SERVER then + _udp_server_callbacks[sid](addr, port, data) + end end -- remove dead servers if not cleaned then - for sid, _ in pairs(_tcp_server_callbacks) do - if not network.__is_serveropen(sid) then - _tcp_server_callbacks[sid] = nil - end - end - for cid, _ in pairs(_tcp_client_callbacks) do - if not network.__is_alive(cid) then - _tcp_client_callbacks[cid] = nil - end - end + clean(_tcp_server_callbacks, network.__is_serveropen, _tcp_server_callbacks) + clean(_tcp_client_callbacks, network.__is_alive, _tcp_client_callbacks) + + clean(_udp_server_callbacks, network.__is_serveropen, _udp_server_callbacks) + clean(_udp_client_datagram_callbacks, network.__is_alive, _udp_client_open_callbacks, _udp_client_datagram_callbacks) + cleaned = true end end diff --git a/src/logic/scripting/lua/libs/libnetwork.cpp b/src/logic/scripting/lua/libs/libnetwork.cpp index ed6799bd..94390c93 100644 --- a/src/logic/scripting/lua/libs/libnetwork.cpp +++ b/src/logic/scripting/lua/libs/libnetwork.cpp @@ -1,8 +1,7 @@ #include "api_lua.hpp" - +#include "coders/json.hpp" #include "engine/Engine.hpp" #include "network/Network.hpp" -#include "coders/json.hpp" using namespace scripting; @@ -134,17 +133,58 @@ static int l_send(lua::State* L, network::Network& network) { return 0; } +static int l_udp_server_send_to(lua::State* L, network::Network& network) { + u64id_t id = lua::tointeger(L, 1); + + if (auto server = network.getServer(id)) { + if (server->getTransportType() != network::TransportType::UDP) + throw std::runtime_error("the server must work on UDP transport"); + + const std::string& addr = lua::tostring(L, 2); + const int& port = lua::tointeger(L, 3); + + auto udpServer = dynamic_cast(server); + + if (lua::istable(L, 4)) { + lua::pushvalue(L, 4); + size_t size = lua::objlen(L, 4); + util::Buffer buffer(size); + for (size_t i = 0; i < size; i++) { + lua::rawgeti(L, i + 1); + buffer[i] = lua::tointeger(L, -1); + lua::pop(L); + } + lua::pop(L); + udpServer->sendTo(addr, port, buffer.data(), size); + } else if (lua::isstring(L, 4)) { + auto string = lua::tolstring(L, 4); + udpServer->sendTo(addr, port, string.data(), string.length()); + } else { + auto string = lua::bytearray_as_string(L, 4); + udpServer->sendTo(addr, port, string.data(), string.length()); + lua::pop(L); + } + } + + return 0; +} + static int l_recv(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); int length = lua::tointeger(L, 2); + auto connection = engine->getNetwork().getConnection(id); - if (connection == nullptr) { + + if (connection == nullptr || connection->getTransportType() != network::TransportType::TCP) { return 0; } - length = glm::min(length, connection->available()); + + auto tcpConnection = dynamic_cast(connection); + + length = glm::min(length, tcpConnection->available()); util::Buffer buffer(length); - int size = connection->recv(buffer.data(), length); + int size = tcpConnection->recv(buffer.data(), length); if (size == -1) { return 0; } @@ -162,38 +202,123 @@ static int l_recv(lua::State* L, network::Network& network) { static int l_available(lua::State* L, network::Network& network) { u64id_t id = lua::tointeger(L, 1); + if (auto connection = network.getConnection(id)) { - return lua::pushinteger(L, connection->available()); + return lua::pushinteger(L, dynamic_cast(connection)->available()); } + return 0; } enum NetworkEventType { CLIENT_CONNECTED = 1, - CONNECTED_TO_SERVER + CONNECTED_TO_SERVER, + DATAGRAM }; struct NetworkEvent { NetworkEventType type; u64id_t server; u64id_t client; + + NetworkEvent( + NetworkEventType type, + u64id_t server, + u64id_t client + ) { + this->type = type; + this->server = server; + this->client = client; + } + + virtual ~NetworkEvent() = default; }; -static std::vector events_queue {}; +enum NetworkDatagramSide { + ON_SERVER = 1, + ON_CLIENT +}; -static int l_connect(lua::State* L, network::Network& network) { +struct NetworkDatagramEvent : NetworkEvent { + NetworkDatagramSide side; + std::string addr; + int port; + const char* buffer; + size_t length; + + NetworkDatagramEvent( + NetworkEventType datagram, + u64id_t sid, + u64id_t cid, + NetworkDatagramSide side, + const std::string& addr, + int port, + const char* data, + size_t length + ) : NetworkEvent(DATAGRAM, sid, cid) { + this->side = side; + this->addr = addr; + this->port = port; + + buffer = data; + + this->length = length; + } +}; + +static std::vector> events_queue {}; + +static int l_connect_tcp(lua::State* L, network::Network& network) { std::string address = lua::require_string(L, 1); int port = lua::tointeger(L, 2); - u64id_t id = network.connect(address, port, [](u64id_t cid) { - events_queue.push_back({CONNECTED_TO_SERVER, 0, cid}); + u64id_t id = network.connectTcp(address, port, [](u64id_t cid) { + events_queue.push_back(std::make_unique(CONNECTED_TO_SERVER, 0, cid)); }); return lua::pushinteger(L, id); } -static int l_open(lua::State* L, network::Network& network) { +static int l_open_tcp(lua::State* L, network::Network& network) { int port = lua::tointeger(L, 1); - u64id_t id = network.openServer(port, [](u64id_t sid, u64id_t id) { - events_queue.push_back({CLIENT_CONNECTED, sid, id}); + u64id_t id = network.openTcpServer(port, [](u64id_t sid, u64id_t id) { + events_queue.push_back(std::make_unique(CLIENT_CONNECTED, sid, id)); + }); + return lua::pushinteger(L, id); +} + +static int l_connect_udp(lua::State* L, network::Network& network) { + std::string address = lua::require_string(L, 1); + int port = lua::tointeger(L, 2); + u64id_t id = network.connectUdp(address, port, [](u64id_t cid) { + events_queue.push_back(std::make_unique(CONNECTED_TO_SERVER, 0, cid)); + }, [address, port]( + u64id_t cid, + const char* buffer, + size_t length + ) { + events_queue.push_back( + std::make_unique( + DATAGRAM, 0, cid, ON_CLIENT, + address, port, buffer, length + ) + ); + }); + return lua::pushinteger(L, id); +} + +static int l_open_udp(lua::State* L, network::Network& network) { + int port = lua::tointeger(L, 1); + u64id_t id = network.openUdpServer(port, []( + u64id_t sid, + const std::string& addr, + int port, + const char* buffer, + size_t length) { + events_queue.push_back( + std::make_unique( + DATAGRAM, sid, 0, ON_SERVER, + addr, port, buffer, length + ) + ); }); return lua::pushinteger(L, id); } @@ -204,7 +329,10 @@ static int l_is_alive(lua::State* L, network::Network& network) { return lua::pushboolean( L, connection->getState() != network::ConnectionState::CLOSED || - connection->available() > 0 + ( + connection->getTransportType() == network::TransportType::TCP && + dynamic_cast(connection)->available() > 0 + ) ); } return lua::pushboolean(L, false); @@ -256,17 +384,34 @@ static int l_get_total_download(lua::State* L, network::Network& network) { static int l_pull_events(lua::State* L, network::Network& network) { lua::createtable(L, events_queue.size(), 0); - for (size_t i = 0; i < events_queue.size(); i++) { - lua::createtable(L, 3, 0); - lua::pushinteger(L, events_queue[i].type); + for (size_t i = 0; i < events_queue.size(); i++) { + const auto* datagramEvent = dynamic_cast(events_queue[i].get()); + + lua::createtable(L, datagramEvent ? 7 : 3, 0); + + lua::pushinteger(L, events_queue[i]->type); lua::rawseti(L, 1); - lua::pushinteger(L, events_queue[i].server); + lua::pushinteger(L, events_queue[i]->server); lua::rawseti(L, 2); - lua::pushinteger(L, events_queue[i].client); + lua::pushinteger(L, events_queue[i]->client); lua::rawseti(L, 3); + + if (datagramEvent) { + lua::pushstring(L, datagramEvent->addr); + lua::rawseti(L, 4); + + lua::pushinteger(L, datagramEvent->port); + lua::rawseti(L, 5); + + lua::pushinteger(L, datagramEvent->side); + lua::rawseti(L, 6); + + lua::create_bytearray(L, datagramEvent->buffer, datagramEvent->length); + lua::rawseti(L, 7); + } lua::rawseti(L, i + 1); } @@ -298,9 +443,12 @@ const luaL_Reg networklib[] = { {"get_total_upload", wrap}, {"get_total_download", wrap}, {"__pull_events", wrap}, - {"__open", wrap}, + {"__open_tcp", wrap}, + {"__open_udp", wrap}, {"__closeserver", wrap}, - {"__connect", wrap}, + {"__udp_server_send_to", wrap}, + {"__connect_tcp", wrap}, + {"__connect_udp", wrap}, {"__close", wrap}, {"__send", wrap}, {"__recv", wrap}, diff --git a/src/network/Network.cpp b/src/network/Network.cpp index 8936adc6..234197e4 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -291,7 +291,7 @@ static std::string to_string(const sockaddr_in& addr, bool port=true) { return ""; } -class SocketConnection : public Connection { +class SocketTcpConnection : public TcpConnection { SOCKET descriptor; sockaddr_in addr; size_t totalUpload = 0; @@ -317,10 +317,10 @@ class SocketConnection : public Connection { state = ConnectionState::CONNECTED; } public: - SocketConnection(SOCKET descriptor, sockaddr_in addr) + SocketTcpConnection(SOCKET descriptor, sockaddr_in addr) : descriptor(descriptor), addr(std::move(addr)), buffer(16'384) {} - ~SocketConnection() { + ~SocketTcpConnection() { if (state != ConnectionState::CLOSED) { shutdown(descriptor, 2); } @@ -442,7 +442,7 @@ public: return to_string(addr, false); } - static std::shared_ptr connect( + static std::shared_ptr connect( const std::string& address, int port, runnable callback ) { addrinfo hints {}; @@ -466,7 +466,7 @@ public: if (descriptor == -1) { throw std::runtime_error("Could not create socket"); } - auto socket = std::make_shared(descriptor, std::move(serverAddress)); + auto socket = std::make_shared(descriptor, std::move(serverAddress)); socket->connect(std::move(callback)); return socket; } @@ -476,7 +476,7 @@ public: } }; -class SocketTcpSServer : public TcpServer { +class SocketTcpServer : public TcpServer { u64id_t id; Network* network; SOCKET descriptor; @@ -486,10 +486,10 @@ class SocketTcpSServer : public TcpServer { std::unique_ptr thread = nullptr; int port; public: - SocketTcpSServer(u64id_t id, Network* network, SOCKET descriptor, int port) + SocketTcpServer(u64id_t id, Network* network, SOCKET descriptor, int port) : id(id), network(network), descriptor(descriptor), port(port) {} - ~SocketTcpSServer() { + ~SocketTcpServer() { closeSocket(); } @@ -510,7 +510,7 @@ public: break; } logger.info() << "client connected: " << to_string(address); - auto socket = std::make_shared( + auto socket = std::make_shared( clientDescriptor, address ); socket->startClient(); @@ -558,7 +558,7 @@ public: return port; } - static std::shared_ptr openServer( + static std::shared_ptr openServer( u64id_t id, Network* network, int port, ConnectCallback handler ) { SOCKET descriptor = socket( @@ -586,7 +586,222 @@ public: } logger.info() << "opened server at port " << port; auto server = - std::make_shared(id, network, descriptor, port); + std::make_shared(id, network, descriptor, port); + server->startListen(std::move(handler)); + return server; + } +}; + +class SocketUdpConnection : public UdpConnection { + u64id_t id; + SOCKET descriptor; + sockaddr_in addr{}; + bool open = true; + std::unique_ptr thread; + ClientDatagramCallback callback; + + size_t totalUpload = 0; + size_t totalDownload = 0; + ConnectionState state = ConnectionState::INITIAL; + +public: + SocketUdpConnection(u64id_t id, SOCKET descriptor, sockaddr_in addr) + : id(id), descriptor(descriptor), addr(std::move(addr)) {} + + ~SocketUdpConnection() override { + SocketUdpConnection::close(); + } + + static std::shared_ptr connect( + u64id_t id, + const std::string& address, + int port, + ClientDatagramCallback handler, + runnable callback + ) { + SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); + if (descriptor == -1) { + throw std::runtime_error("could not create UDP socket"); + } + + sockaddr_in serverAddr{}; + serverAddr.sin_family = AF_INET; + if (inet_pton(AF_INET, address.c_str(), &serverAddr.sin_addr) <= 0) { + closesocket(descriptor); + throw std::runtime_error("invalid UDP address: " + address); + } + serverAddr.sin_port = htons(port); + + if (::connect(descriptor, (sockaddr*)&serverAddr, sizeof(serverAddr)) < 0) { + auto err = handle_socket_error("UDP connect failed"); + closesocket(descriptor); + throw err; + } + + auto socket = std::make_shared(id, descriptor, serverAddr); + socket->connect(std::move(handler)); + + callback(); + + return socket; + } + + void connect(ClientDatagramCallback handler) override { + callback = std::move(handler); + state = ConnectionState::CONNECTED; + + thread = std::make_unique([this]() { + util::Buffer buffer(16'384); + while (open) { + int size = recv(descriptor, buffer.data(), buffer.size(), 0); + if (size <= 0) { + if (!open) break; + closesocket(descriptor); + state = ConnectionState::CLOSED; + break; + } + totalDownload += size; + if (callback) { + callback(id, buffer.data(), size); + } + } + }); + } + + int send(const char* buffer, size_t length) override { + int len = sendto(descriptor, buffer, length, 0, + (sockaddr*)&addr, sizeof(addr)); + if (len < 0) { + closesocket(descriptor); + state = ConnectionState::CLOSED; + } else totalUpload += len; + + return len; + } + + void close(bool discardAll=false) override { + if (!open) return; + open = false; + + if (state != ConnectionState::CLOSED) { + shutdown(descriptor, 2); + closesocket(descriptor); + } + + if (thread) { + thread->join(); + thread.reset(); + } + state = ConnectionState::CLOSED; + } + + size_t pullUpload() override { + size_t s = totalUpload; + totalUpload = 0; + return s; + } + + size_t pullDownload() override { + size_t s = totalDownload; + totalDownload = 0; + return s; + } + + [[nodiscard]] int getPort() const override { + return ntohs(addr.sin_port); + } + + [[nodiscard]] std::string getAddress() const override { + return to_string(addr, false); + } + + [[nodiscard]] ConnectionState getState() const override { + return state; + } +}; + +class SocketUdpServer : public UdpServer { + u64id_t id; + Network* network; + SOCKET descriptor; + bool open = true; + std::unique_ptr thread = nullptr; + int port; + ServerDatagramCallback callback; + +public: + SocketUdpServer(u64id_t id, Network* network, SOCKET descriptor, int port) + : id(id), network(network), descriptor(descriptor), port(port) {} + + ~SocketUdpServer() override { + SocketUdpServer::close(); + } + + void startListen(ServerDatagramCallback handler) { + callback = std::move(handler); + + thread = std::make_unique([this]() { + util::Buffer buffer(16384); + sockaddr_in clientAddr{}; + socklen_t addrlen = sizeof(clientAddr); + + while (open) { + int size = recvfrom(descriptor, buffer.data(), buffer.size(), 0, + reinterpret_cast(&clientAddr), &addrlen); + if (size <= 0) { + if (!open) break; + continue; + } + + std::string addrStr = to_string(clientAddr, false); + int port = ntohs(clientAddr.sin_port); + + callback(id, addrStr, port, buffer.data(), size); + } + }); + } + + void sendTo(const std::string& addr, int port, const char* buffer, size_t length) override { + sockaddr_in client{}; + client.sin_family = AF_INET; + inet_pton(AF_INET, addr.c_str(), &client.sin_addr); + client.sin_port = htons(port); + + sendto(descriptor, buffer, length, 0, + reinterpret_cast(&client), sizeof(client)); + } + + void close() override { + if (!open) return; + open = false; + shutdown(descriptor, 2); + closesocket(descriptor); + if (thread) { + thread->join(); + thread = nullptr; + } + } + + bool isOpen() override { return open; } + int getPort() const override { return port; } + + static std::shared_ptr openServer( + u64id_t id, Network* network, int port, const ServerDatagramCallback& handler + ) { + SOCKET descriptor = socket(AF_INET, SOCK_DGRAM, 0); + if (descriptor == -1) throw std::runtime_error("Could not create UDP socket"); + + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + + if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { + closesocket(descriptor); + throw std::runtime_error("Could not bind UDP port " + std::to_string(port)); + } + + auto server = std::make_shared(id, network, descriptor, port); server->startListen(std::move(handler)); return server; } @@ -627,7 +842,7 @@ Connection* Network::getConnection(u64id_t id) { return found->second.get(); } -TcpServer* Network::getServer(u64id_t id) const { +Server* Network::getServer(u64id_t id) const { const auto& found = servers.find(id); if (found == servers.end()) { return nullptr; @@ -635,20 +850,38 @@ TcpServer* Network::getServer(u64id_t id) const { return found->second.get(); } -u64id_t Network::connect(const std::string& address, int port, consumer callback) { +u64id_t Network::connectTcp(const std::string& address, int port, consumer callback) { std::lock_guard lock(connectionsMutex); u64id_t id = nextConnection++; - auto socket = SocketConnection::connect(address, port, [id, callback]() { + auto socket = SocketTcpConnection::connect(address, port, [id, callback]() { callback(id); }); connections[id] = std::move(socket); return id; } -u64id_t Network::openServer(int port, ConnectCallback handler) { +u64id_t Network::openTcpServer(int port, ConnectCallback handler) { u64id_t id = nextServer++; - auto server = SocketTcpSServer::openServer(id, this, port, handler); + auto server = SocketTcpServer::openServer(id, this, port, handler); + servers[id] = std::move(server); + return id; +} + +u64id_t Network::connectUdp(const std::string& address, int port, const consumer& callback, ClientDatagramCallback handler) { + std::lock_guard lock(connectionsMutex); + + u64id_t id = nextConnection++; + auto socket = SocketUdpConnection::connect(id, address, port, std::move(handler), [id, callback]() { + callback(id); + }); + connections[id] = std::move(socket); + return id; +} + +u64id_t Network::openUdpServer(int port, const ServerDatagramCallback& handler) { + u64id_t id = nextServer++; + auto server = SocketUdpServer::openServer(id, this, port, handler); servers[id] = std::move(server); return id; } @@ -679,7 +912,10 @@ void Network::update() { auto socket = socketiter->second.get(); totalDownload += socket->pullDownload(); totalUpload += socket->pullUpload(); - if (socket->available() == 0 && + if ( + ( socket->getTransportType() == TransportType::UDP || + dynamic_cast(socket)->available() == 0 + ) && socket->getState() == ConnectionState::CLOSED) { socketiter = connections.erase(socketiter); continue; diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 934642fa..ddde70b5 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -13,6 +13,9 @@ namespace network { using OnResponse = std::function)>; using OnReject = std::function; using ConnectCallback = std::function; + using ServerDatagramCallback = std::function; + using ClientDatagramCallback = std::function; + class Requests { public: @@ -33,8 +36,8 @@ namespace network { long maxSize=0 ) = 0; - virtual size_t getTotalUpload() const = 0; - virtual size_t getTotalDownload() const = 0; + [[nodiscard]] virtual size_t getTotalUpload() const = 0; + [[nodiscard]] virtual size_t getTotalDownload() const = 0; virtual void update() = 0; }; @@ -43,32 +46,82 @@ namespace network { INITIAL, CONNECTING, CONNECTED, CLOSED }; + enum class TransportType { + TCP, UDP + }; + class Connection { public: - virtual ~Connection() {} + virtual ~Connection() = default; - virtual void connect(runnable callback) = 0; - virtual int recv(char* buffer, size_t length) = 0; - virtual int send(const char* buffer, size_t length) = 0; virtual void close(bool discardAll=false) = 0; - virtual int available() = 0; + + virtual int send(const char* buffer, size_t length) = 0; virtual size_t pullUpload() = 0; virtual size_t pullDownload() = 0; - virtual int getPort() const = 0; - virtual std::string getAddress() const = 0; + [[nodiscard]] virtual int getPort() const = 0; + [[nodiscard]] virtual std::string getAddress() const = 0; - virtual ConnectionState getState() const = 0; + [[nodiscard]] virtual ConnectionState getState() const = 0; + + [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; }; - class TcpServer { + class TcpConnection : public Connection { public: - virtual ~TcpServer() {} - virtual void startListen(ConnectCallback handler) = 0; + ~TcpConnection() override = default; + + virtual void connect(runnable callback) = 0; + virtual int recv(char* buffer, size_t length) = 0; + virtual int available() = 0; + + [[nodiscard]] TransportType getTransportType() const noexcept override { + return TransportType::TCP; + } + }; + + class UdpConnection : public Connection { + public: + ~UdpConnection() override = default; + + virtual void connect(ClientDatagramCallback handler) = 0; + + [[nodiscard]] TransportType getTransportType() const noexcept override { + return TransportType::UDP; + } + }; + + class Server { + public: + virtual ~Server() = default; virtual void close() = 0; virtual bool isOpen() = 0; - virtual int getPort() const = 0; + [[nodiscard]] virtual TransportType getTransportType() const noexcept = 0; + [[nodiscard]] virtual int getPort() const = 0; + }; + + class TcpServer : public Server { + public: + ~TcpServer() override {} + virtual void startListen(ConnectCallback handler) = 0; + + [[nodiscard]] TransportType getTransportType() const noexcept override { + return TransportType::TCP; + } + }; + + class UdpServer : public Server { + public: + ~UdpServer() override {} + virtual void startListen(ServerDatagramCallback handler) = 0; + + virtual void sendTo(const std::string& addr, int port, const char* buffer, size_t length) = 0; + + [[nodiscard]] TransportType getTransportType() const noexcept override { + return TransportType::UDP; + } }; class Network { @@ -78,7 +131,7 @@ namespace network { std::mutex connectionsMutex {}; u64id_t nextConnection = 1; - std::unordered_map> servers; + std::unordered_map> servers; u64id_t nextServer = 1; size_t totalDownload = 0; @@ -103,16 +156,18 @@ namespace network { ); [[nodiscard]] Connection* getConnection(u64id_t id); - [[nodiscard]] TcpServer* getServer(u64id_t id) const; + [[nodiscard]] Server* getServer(u64id_t id) const; - u64id_t connect(const std::string& address, int port, consumer callback); + u64id_t connectTcp(const std::string& address, int port, consumer callback); + u64id_t connectUdp(const std::string& address, int port, const consumer& callback, ClientDatagramCallback handler); - u64id_t openServer(int port, ConnectCallback handler); + u64id_t openTcpServer(int port, ConnectCallback handler); + u64id_t openUdpServer(int port, const ServerDatagramCallback& handler); u64id_t addConnection(const std::shared_ptr& connection); - size_t getTotalUpload() const; - size_t getTotalDownload() const; + [[nodiscard]] size_t getTotalUpload() const; + [[nodiscard]] size_t getTotalDownload() const; void update();