From a72d36f53c87fd51be36668fb184cd570768ca0f Mon Sep 17 00:00:00 2001 From: MihailRis Date: Wed, 27 Nov 2024 15:38:57 +0300 Subject: [PATCH] add simple inefficient server socket implementation (WIP) --- src/logic/scripting/lua/libs/libnetwork.cpp | 17 ++- src/network/Network.cpp | 141 ++++++++++++++++++-- src/network/Network.hpp | 17 +++ 3 files changed, 164 insertions(+), 11 deletions(-) diff --git a/src/logic/scripting/lua/libs/libnetwork.cpp b/src/logic/scripting/lua/libs/libnetwork.cpp index df60b741..1b3757b1 100644 --- a/src/logic/scripting/lua/libs/libnetwork.cpp +++ b/src/logic/scripting/lua/libs/libnetwork.cpp @@ -42,7 +42,9 @@ static int l_connect(lua::State* L) { lua::pushvalue(L, 3); auto callback = lua::create_lambda(L); u64id_t id = engine->getNetwork().connect(address, port, [callback](u64id_t id) { - callback({id}); + engine->postRunnable([=]() { + callback({id}); + }); }); return lua::pushinteger(L, id); } @@ -109,9 +111,22 @@ static int l_recv(lua::State* L) { return 1; } +static int l_open(lua::State* L) { + int port = lua::tointeger(L, 1); + lua::pushvalue(L, 2); + auto callback = lua::create_lambda(L); + u64id_t id = engine->getNetwork().openServer(port, [callback](u64id_t id) { + engine->postRunnable([=]() { + callback({id}); + }); + }); + return lua::pushinteger(L, id); +} + const luaL_Reg networklib[] = { {"get", lua::wrap}, {"get_binary", lua::wrap}, + {"__open", lua::wrap}, {"__connect", lua::wrap}, {"__close", lua::wrap}, {"__send", lua::wrap}, diff --git a/src/network/Network.cpp b/src/network/Network.cpp index fa46b682..a433860c 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -244,18 +244,26 @@ static inline int sendsocket( return send(descriptor, buf, len, flags); } +static std::string to_string(const sockaddr_in* addr) { + char ip[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &(addr->sin_addr), ip, INET_ADDRSTRLEN)) { + return std::string(ip)+":"+std::to_string(addr->sin_port); + } + return ""; +} + static std::string to_string(const addrinfo* addr) { if (addr->ai_family == AF_INET) { auto psai = reinterpret_cast(addr->ai_addr); char ip[INET_ADDRSTRLEN]; if (inet_ntop(addr->ai_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN)) { - return std::string(ip); + return std::string(ip)+":"+std::to_string(psai->sin_port); } } else if (addr->ai_family == AF_INET6) { auto psai = reinterpret_cast(addr->ai_addr); char ip[INET6_ADDRSTRLEN]; if (inet_ntop(addr->ai_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN)) { - return std::string(ip); + return std::string(ip)+":"+std::to_string(psai->sin6_port); } } return ""; @@ -265,6 +273,7 @@ class SocketConnection : public Connection { SOCKET descriptor; bool open = true; addrinfo* addr; + std::string addrString; size_t totalUpload = 0; size_t totalDownload = 0; ConnectionState state = ConnectionState::INITIAL; @@ -288,8 +297,8 @@ class SocketConnection : public Connection { state = ConnectionState::CONNECTED; } public: - SocketConnection(SOCKET descriptor, addrinfo* addr) - : descriptor(descriptor), addr(addr), buffer(16'384) {} + SocketConnection(SOCKET descriptor, addrinfo* addr, const std::string& addrString) + : descriptor(descriptor), addr(addr), addrString(addrString), buffer(16'384) {} ~SocketConnection() { if (state != ConnectionState::CLOSED) { @@ -299,7 +308,9 @@ public: if (thread) { thread->join(); } - freeaddrinfo(addr); + if (addr) { + freeaddrinfo(addr); + } } void connect(runnable callback) override { @@ -352,8 +363,8 @@ public: int err = errno; close(); throw std::runtime_error( - "Send failed [errno=" + std::to_string(err) + - "]: " + std::string(strerror(err)) + "Send failed [errno=" + std::to_string(err) + "]: " + + std::string(strerror(err)) ); } totalUpload += len; @@ -370,8 +381,10 @@ public: shutdown(descriptor, 2); closesocket(descriptor); } - thread->join(); - thread = nullptr; + if (thread) { + thread->join(); + thread = nullptr; + } } size_t getTotalUpload() const override { @@ -403,7 +416,7 @@ public: freeaddrinfo(addrinfo); throw std::runtime_error("Could not create socket"); } - auto socket = std::make_shared(descriptor, addrinfo); + auto socket = std::make_shared(descriptor, addrinfo, to_string(addrinfo)); socket->connect(std::move(callback)); return socket; } @@ -413,6 +426,101 @@ public: } }; +class SocketTcpSServer : public TcpServer { + Network* network; + SOCKET descriptor; + std::vector clients; + bool open = true; + std::unique_ptr thread = nullptr; +public: + SocketTcpSServer(Network* network, SOCKET descriptor) + : network(network), descriptor(descriptor) {} + + ~SocketTcpSServer() { + closeSocket(); + } + + void startListen(consumer handler) override { + thread = std::make_unique([this, handler]() { + while (open) { + logger.info() << "listening for connections"; + if (listen(descriptor, 2) < 0) { + close(); + break; + } + socklen_t addrlen = sizeof(sockaddr_in); + SOCKET clientDescriptor; + sockaddr_in address; + logger.info() << "accepting clients"; + if ((clientDescriptor = accept(descriptor, (sockaddr*)&address, &addrlen)) < 0) { + close(); + break; + } + logger.info() << "client connected: " << to_string(&address); + auto socket = std::make_shared( + clientDescriptor, nullptr, to_string(&address) + ); + u64id_t id = network->addConnection(socket); + clients.push_back(id); + handler(id); + } + }); + } + + void closeSocket() { + if (!open) { + return; + } + logger.info() << "closing server"; + open = false; + for (u64id_t clientid : clients) { + if (auto client = network->getConnection(clientid)) { + client->close(); + } + } + clients.clear(); + + shutdown(descriptor, 2); + closesocket(descriptor); + thread->join(); + } + + void close() override { + closeSocket(); + } + + bool isOpen() override { + return open; + } + static std::shared_ptr openServer( + Network* network, int port, consumer handler + ) { + SOCKET descriptor = socket( + AF_INET, SOCK_STREAM, 0 + ); + if (descriptor == -1) { + throw std::runtime_error("Could not create server socket"); + } + int opt = 1; + if (setsockopt(descriptor, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt))) { + closesocket(descriptor); + throw std::runtime_error("setsockopt"); + } + sockaddr_in address; + address.sin_family = AF_INET; + address.sin_addr.s_addr = INADDR_ANY; + address.sin_port = htons(port); + if (bind(descriptor, (sockaddr*)&address, sizeof(address)) < 0) { + closesocket(descriptor); + throw std::runtime_error("could not bind port "+std::to_string(port)); + } + logger.info() << "opened server at port " << port; + auto server = std::make_shared(network, descriptor); + server->startListen(std::move(handler)); + return server; + } +}; + Network::Network(std::unique_ptr requests) : requests(std::move(requests)) { } @@ -445,6 +553,19 @@ u64id_t Network::connect(const std::string& address, int port, consumer return id; } +u64id_t Network::openServer(int port, consumer handler) { + u64id_t id = nextServer++; + auto server = SocketTcpSServer::openServer(this, port, handler); + servers[id] = std::move(server); + return id; +} + +u64id_t Network::addConnection(const std::shared_ptr& socket) { + u64id_t id = nextConnection++; + connections[id] = std::move(socket); + return id; +} + size_t Network::getTotalUpload() const { size_t totalUpload = 0; for (const auto& [_, socket] : connections) { diff --git a/src/network/Network.hpp b/src/network/Network.hpp index f4b6ceef..741b3ed4 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -34,6 +34,8 @@ namespace network { class Connection { public: + virtual ~Connection() {} + virtual void connect(runnable callback) = 0; virtual int recv(char* buffer, size_t length) = 0; virtual int send(const char* buffer, size_t length) = 0; @@ -46,10 +48,21 @@ namespace network { virtual ConnectionState getState() const = 0; }; + class TcpServer { + public: + virtual ~TcpServer() {} + virtual void startListen(consumer handler) = 0; + virtual void close() = 0; + virtual bool isOpen() = 0; + }; + class Network { std::unique_ptr requests; std::unordered_map> connections; u64id_t nextConnection = 1; + + std::unordered_map> servers; + u64id_t nextServer = 1; public: Network(std::unique_ptr requests); ~Network(); @@ -65,6 +78,10 @@ namespace network { u64id_t connect(const std::string& address, int port, consumer callback); + u64id_t openServer(int port, consumer handler); + + u64id_t addConnection(const std::shared_ptr& connection); + size_t getTotalUpload() const; size_t getTotalDownload() const;