diff --git a/res/scripts/classes.lua b/res/scripts/classes.lua index 3c23fdc4..430f27f1 100644 --- a/res/scripts/classes.lua +++ b/res/scripts/classes.lua @@ -34,3 +34,28 @@ cameras.get = function(name) wrappers[name] = wrapper return wrapper end + + +local Socket = {__index={ + send=function(self, bytes) return network.__send(self.id, bytes) end, + recv=function(self, len, usetable) return network.__recv(self.id, len, usetable) end, + close=function(self) return network.__close(self.id) end, +}} + +network.tcp_connect = function(address, port, callback) + local socket = setmetatable({id=0}, Socket) + return setmetatable({id=network.__connect(address, port, function(id) + socket.id = id + callback(socket) + end)}, Socket) +end + +local ServerSocket = {__index={ + close=function(self) return network.__closeserver(self.id) end, +}} + +network.tcp_open = function(port, handler) + return setmetatable({id=network.__open(port, function(id) + handler(setmetatable({id=id}, Socket)) + end)}, ServerSocket) +end diff --git a/src/network/Network.cpp b/src/network/Network.cpp index c6736ada..3fae398d 100644 --- a/src/network/Network.cpp +++ b/src/network/Network.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #ifdef _WIN32 @@ -426,6 +425,7 @@ class SocketTcpSServer : public TcpServer { Network* network; SOCKET descriptor; std::vector clients; + std::mutex clientsMutex; bool open = true; std::unique_ptr thread = nullptr; public: @@ -457,7 +457,10 @@ public: clientDescriptor, nullptr, to_string(&address) ); u64id_t id = network->addConnection(socket); - clients.push_back(id); + { + std::lock_guard lock(clientsMutex); + clients.push_back(id); + } handler(id); } }); @@ -469,9 +472,13 @@ public: } logger.info() << "closing server"; open = false; - for (u64id_t clientid : clients) { - if (auto client = network->getConnection(clientid)) { - client->close(); + + { + std::lock_guard lock(clientsMutex); + for (u64id_t clientid : clients) { + if (auto client = network->getConnection(clientid)) { + client->close(); + } } } clients.clear(); @@ -536,7 +543,9 @@ void Network::get( requests->get(url, onResponse, onReject, maxSize); } -Connection* Network::getConnection(u64id_t id) const { +Connection* Network::getConnection(u64id_t id) { + std::lock_guard lock(connectionsMutex); + const auto& found = connections.find(id); if (found == connections.end()) { return nullptr; @@ -553,6 +562,8 @@ TcpServer* Network::getServer(u64id_t id) const { } u64id_t Network::connect(const std::string& address, int port, consumer callback) { + std::lock_guard lock(connectionsMutex); + u64id_t id = nextConnection++; auto socket = SocketConnection::connect(address, port, [id, callback]() { callback(id); @@ -569,29 +580,35 @@ u64id_t Network::openServer(int port, consumer handler) { } u64id_t Network::addConnection(const std::shared_ptr& socket) { + std::lock_guard lock(connectionsMutex); + u64id_t id = nextConnection++; connections[id] = std::move(socket); return id; } size_t Network::getTotalUpload() const { - size_t totalUpload = 0; - for (const auto& [_, socket] : connections) { - totalUpload += socket->getTotalUpload(); - } return requests->getTotalUpload() + totalUpload; } size_t Network::getTotalDownload() const { - size_t totalDownload = 0; - for (const auto& [_, socket] : connections) { - totalDownload += socket->getTotalDownload(); - } return requests->getTotalDownload() + totalDownload; } void Network::update() { requests->update(); + + totalDownload = 0; + totalUpload = 0; + { + std::lock_guard lock(connectionsMutex); + for (const auto& [_, socket] : connections) { + totalDownload += socket->getTotalDownload(); + } + for (const auto& [_, socket] : connections) { + totalUpload += socket->getTotalUpload(); + } + } } std::unique_ptr Network::create(const NetworkSettings& settings) { diff --git a/src/network/Network.hpp b/src/network/Network.hpp index 911966db..a6ccb864 100644 --- a/src/network/Network.hpp +++ b/src/network/Network.hpp @@ -2,6 +2,7 @@ #include #include +#include #include "typedefs.hpp" #include "settings.hpp" @@ -58,11 +59,16 @@ namespace network { class Network { std::unique_ptr requests; + std::unordered_map> connections; + std::mutex connectionsMutex {}; u64id_t nextConnection = 1; std::unordered_map> servers; u64id_t nextServer = 1; + + size_t totalDownload = 0; + size_t totalUpload = 0; public: Network(std::unique_ptr requests); ~Network(); @@ -74,7 +80,7 @@ namespace network { long maxSize=0 ); - [[nodiscard]] Connection* getConnection(u64id_t id) const; + [[nodiscard]] Connection* getConnection(u64id_t id); [[nodiscard]] TcpServer* getServer(u64id_t id) const; u64id_t connect(const std::string& address, int port, consumer callback);