From f7fcb54622923fcd3fd60f34d23c7a352ff67ca1 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Fri, 19 Feb 2021 13:03:48 -0800 Subject: [PATCH] htcs: implement virtual socket collection --- .../include/stratosphere/htcs.hpp | 2 + .../include/stratosphere/htcs/htcs_api.hpp | 36 + .../include/stratosphere/htcs/htcs_socket.hpp | 46 + .../include/stratosphere/htcs/htcs_types.hpp | 2 + .../htc/server/rpc/htc_rpc_task_table.hpp | 2 +- .../source/htcs/client/htcs_session.hpp | 24 + .../client/htcs_virtual_socket_collection.cpp | 834 ++++++++++++++++++ .../client/htcs_virtual_socket_collection.hpp | 81 ++ .../source/htcs/htcs_socket.cpp | 178 ++++ 9 files changed, 1204 insertions(+), 1 deletion(-) create mode 100644 libraries/libstratosphere/include/stratosphere/htcs/htcs_api.hpp create mode 100644 libraries/libstratosphere/include/stratosphere/htcs/htcs_socket.hpp create mode 100644 libraries/libstratosphere/source/htcs/client/htcs_session.hpp create mode 100644 libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.cpp create mode 100644 libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.hpp create mode 100644 libraries/libstratosphere/source/htcs/htcs_socket.cpp diff --git a/libraries/libstratosphere/include/stratosphere/htcs.hpp b/libraries/libstratosphere/include/stratosphere/htcs.hpp index cda430983..0f88fdda6 100644 --- a/libraries/libstratosphere/include/stratosphere/htcs.hpp +++ b/libraries/libstratosphere/include/stratosphere/htcs.hpp @@ -16,6 +16,8 @@ #pragma once #include +#include +#include #include #include #include diff --git a/libraries/libstratosphere/include/stratosphere/htcs/htcs_api.hpp b/libraries/libstratosphere/include/stratosphere/htcs/htcs_api.hpp new file mode 100644 index 000000000..3673fd0f2 --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/htcs/htcs_api.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include +#include + +namespace ams::htcs { + + bool IsInitialized(); + + size_t GetWorkingMemorySize(int max_socket_count); + + void Initialize(AllocateFunction allocate, DeallocateFunction deallocate, int num_sessions = SessionCountMax); + void Initialize(void *buffer, size_t buffer_size); + + void InitializeForDisableDisconnectionEmulation(AllocateFunction allocate, DeallocateFunction deallocate, int num_sessions = SessionCountMax); + void InitializeForDisableDisconnectionEmulation(void *buffer, size_t buffer_size); + + void InitializeForSystem(void *buffer, size_t buffer_size, int num_sessions); + + void Finalize(); + +} diff --git a/libraries/libstratosphere/include/stratosphere/htcs/htcs_socket.hpp b/libraries/libstratosphere/include/stratosphere/htcs/htcs_socket.hpp new file mode 100644 index 000000000..3931d1b2b --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/htcs/htcs_socket.hpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include +#include + +namespace ams::htcs { + + const HtcsPeerName GetPeerNameAny(); + const HtcsPeerName GetDefaultHostName(); + + s32 GetLastError(); + + s32 Socket(); + s32 Close(s32 desc); + s32 Connect(s32 desc, const SockAddrHtcs *address); + s32 Bind(s32 desc, const SockAddrHtcs *address); + s32 Listen(s32 desc, s32 backlog_count); + s32 Accept(s32 desc, SockAddrHtcs *address); + s32 Shutdown(s32 desc, s32 how); + s32 Fcntl(s32 desc, s32 command, s32 value); + + s32 Select(s32 count, FdSet *read, FdSet *write, FdSet *exception, TimeVal *timeout); + + ssize_t Recv(s32 desc, void *buffer, size_t buffer_size, s32 flags); + ssize_t Send(s32 desc, const void *buffer, size_t buffer_size, s32 flags); + + void FdSetZero(FdSet *set); + void FdSetSet(s32 fd, FdSet *set); + void FdSetClr(s32 fd, FdSet *set); + bool FdSetIsSet(s32 fd, const FdSet *set); + +} diff --git a/libraries/libstratosphere/include/stratosphere/htcs/htcs_types.hpp b/libraries/libstratosphere/include/stratosphere/htcs/htcs_types.hpp index 03cb7b893..b5f3708b6 100644 --- a/libraries/libstratosphere/include/stratosphere/htcs/htcs_types.hpp +++ b/libraries/libstratosphere/include/stratosphere/htcs/htcs_types.hpp @@ -24,6 +24,8 @@ namespace ams::htcs { constexpr inline int PeerNameBufferLength = 32; constexpr inline int PortNameBufferLength = 32; + constexpr inline int SessionCountMax = 0x10; + constexpr inline int SocketCountMax = 40; constexpr inline int FdSetSize = SocketCountMax; diff --git a/libraries/libstratosphere/source/htc/server/rpc/htc_rpc_task_table.hpp b/libraries/libstratosphere/source/htc/server/rpc/htc_rpc_task_table.hpp index 91f472e33..dbc006fcb 100644 --- a/libraries/libstratosphere/source/htc/server/rpc/htc_rpc_task_table.hpp +++ b/libraries/libstratosphere/source/htc/server/rpc/htc_rpc_task_table.hpp @@ -33,7 +33,7 @@ namespace ams::htc::server::rpc { private: /* htcs::ReceiveSmallTask/htcs::ReceiveSendTask are the largest tasks, containing an inline 0xE000 buffer. */ /* We allow for ~0x100 task overhead from the additional events those contain. */ - /* NOTE: Nintnedo hardcodes a maximum size of 0xE1D8, despite SendSmallTask being 0xE098 as of latest check. */ + /* NOTE: Nintendo hardcodes a maximum size of 0xE1D8, despite SendSmallTask being 0xE098 as of latest check. */ static constexpr size_t MaxTaskSize = 0xE100; using TaskStorage = typename std::aligned_storage::type; private: diff --git a/libraries/libstratosphere/source/htcs/client/htcs_session.hpp b/libraries/libstratosphere/source/htcs/client/htcs_session.hpp new file mode 100644 index 000000000..bbd088e06 --- /dev/null +++ b/libraries/libstratosphere/source/htcs/client/htcs_session.hpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include + +namespace ams::htcs::client { + + void InitializeSessionManager(tma::IHtcsManager **out_manager, tma::IHtcsManager **out_monitor); + void FinalizeSessionManager(); + +} diff --git a/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.cpp b/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.cpp new file mode 100644 index 000000000..0c6794dd7 --- /dev/null +++ b/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.cpp @@ -0,0 +1,834 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include +#include "htcs_session.hpp" +#include "htcs_virtual_socket_collection.hpp" + +namespace ams::htcs::client { + + namespace { + + constexpr inline s32 InvalidSocket = -1; + constexpr inline s32 InvalidPrimitive = -1; + + } + + /* Declare client functions. */ + sf::SharedPointer socket(s32 &last_error); + s32 close(sf::SharedPointer socket, s32 &last_error); + s32 bind(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error); + s32 listen(sf::SharedPointer socket, s32 backlog_count, s32 &last_error); + sf::SharedPointer accept(sf::SharedPointer socket, htcs::SockAddrHtcs *address, s32 &last_error); + s32 fcntl(sf::SharedPointer socket, s32 command, s32 value, s32 &last_error); + + s32 shutdown(sf::SharedPointer socket, s32 how, s32 &last_error); + + ssize_t recv(sf::SharedPointer socket, void *buffer, size_t buffer_size, s32 flags, s32 &last_error); + ssize_t send(sf::SharedPointer socket, const void *buffer, size_t buffer_size, s32 flags, s32 &last_error); + + s32 connect(sf::SharedPointer socket, const htcs::SockAddrHtcs *address, s32 &last_error); + + s32 select(s32 * const read, s32 &num_read, s32 * const write, s32 &num_write, s32 * const except, s32 &num_except, htcs::TimeVal *timeout, s32 &last_error); + + struct VirtualSocket { + s32 m_id; + s32 m_primitive; + sf::SharedPointer m_socket; + bool m_do_bind; + htcs::SockAddrHtcs m_address; + s32 m_listen_backlog_count; + s32 m_fcntl_command; + s32 m_fcntl_value; + bool m_blocking; + + + VirtualSocket() { + /* Initialize. */ + this->Init(); + } + + ~VirtualSocket() { /* ... */ } + + void Init() { + /* Setup fields. */ + m_id = InvalidSocket; + m_primitive = InvalidPrimitive; + m_socket = nullptr; + m_blocking = true; + m_do_bind = false; + + std::memset(std::addressof(m_address), 0, sizeof(m_address)); + + m_listen_backlog_count = -1; + + m_fcntl_command = -1; + m_fcntl_value = 0; + } + + s32 Bind(const htcs::SockAddrHtcs *address, s32 &last_error) { + /* Mark the bind. */ + m_do_bind = true; + + /* Set our address. */ + std::memcpy(std::addressof(m_address), address, sizeof(m_address)); + + /* Clear the error. */ + last_error = 0; + return 0; + } + + s32 Listen(s32 backlog_count, s32 &last_error) { + s32 res = -1; + if (m_do_bind) { + /* Set backlog count. */ + m_listen_backlog_count = std::max(backlog_count, 1); + + /* Clear error. */ + last_error = 0; + res = 0; + } else { + last_error = HTCS_EINVAL; + } + return res; + } + + s32 Fcntl(s32 command, s32 value, s32 &last_error) { + /* Clear error. */ + s32 res = 0; + last_error = 0; + + if (command == HTCS_F_SETFL) { + m_fcntl_command = command; + m_fcntl_value = value; + + m_blocking = (value & HTCS_O_NONBLOCK) == 0; + } else if (command == HTCS_F_GETFL) { + res = m_fcntl_value; + } else { + last_error = HTCS_EINVAL; + res = -1; + } + + return res; + } + + s32 SetSocket(sf::SharedPointer socket, s32 &last_error) { + s32 res = 0; + + if (m_socket == nullptr && socket != nullptr) { + /* Set our socket. */ + m_socket = socket; + + /* Bind, fcntl, and listen, since those may have been deferred. */ + if (m_do_bind) { + res = bind(m_socket, std::addressof(m_address), last_error); + } + + if (res == 0 && m_fcntl_command != -1) { + res = fcntl(m_socket, m_fcntl_command, m_fcntl_value, last_error); + } + + if (res == 0 && m_listen_backlog_count > 0) { + res = listen(m_socket, m_listen_backlog_count, last_error); + } + } + + return res; + } + }; + + VirtualSocketCollection::VirtualSocketCollection() + : m_socket_list(nullptr), + m_list_count(0), + m_list_size(0), + m_next_id(1), + m_mutex() + { + /* ... */ + } + + VirtualSocketCollection::~VirtualSocketCollection() { + /* Clear ourselves. */ + this->Clear(); + + /* Destroy all sockets in our list. */ + for (auto i = 0; i < m_list_size; ++i) { + std::destroy_at(m_socket_list + i); + } + + /* Clear the backing memory for our socket list. */ + std::memset(m_buffer, 0, sizeof(VirtualSocket) * m_list_size); + } + + size_t VirtualSocketCollection::GetWorkingMemorySize(int num_sockets) { + AMS_ASSERT(num_sockets < htcs::SocketCountMax); + return num_sockets * sizeof(VirtualSocket); + } + + void VirtualSocketCollection::Init(void *buffer, size_t buffer_size) { + /* Set our buffer. */ + m_buffer = buffer; + m_buffer_size = buffer_size; + + /* Configure our list. */ + m_list_size = static_cast(m_buffer_size / sizeof(VirtualSocket)); + m_list_size = std::min(m_list_size, htcs::SocketCountMax); + + /* Initialize our list. */ + m_socket_list = static_cast(m_buffer); + for (auto i = 0; i < m_list_size; ++i) { + std::construct_at(m_socket_list + i); + } + } + + void VirtualSocketCollection::Clear() { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Clear our list. */ + m_list_count = 0; + } + + s32 VirtualSocketCollection::Socket(s32 &error_code) { + /* Clear error code. */ + error_code = 0; + + /* Create the socket. */ + sf::SharedPointer socket(nullptr); + return this->CreateSocket(sf::SharedPointer{nullptr}, error_code); + } + + s32 VirtualSocketCollection::Close(s32 id, s32 &error_code) { + /* Clear error code. */ + error_code = 0; + + /* Prepare to find the socket. */ + s32 res = 0; + sf::SharedPointer socket = nullptr; + + /* Find the socket. */ + { + std::scoped_lock lk(m_mutex); + + if (auto index = this->Find(id, std::addressof(error_code)); index >= 0) { + /* Get the socket's object. */ + VirtualSocket *virt_socket = m_socket_list + index; + socket = virt_socket->m_socket; + + /* Move the list. */ + for (/* ... */; index < m_list_count - 1; ++index) { + m_socket_list[index] = m_socket_list[index + 1]; + } + + /* Clear the now unused last list entry. */ + m_socket_list[index].Init(); + + /* Decrement our list count. */ + --m_list_count; + } + } + + /* If we found the socket, close it. */ + if (socket != nullptr) { + close(socket, error_code); + + /* Clear the error code. */ + res = 0; + error_code = 0; + } + + return index >= 0 ? res : -1; + } + + s32 VirtualSocketCollection::Bind(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Get the socket. */ + sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code)); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + res = bind(socket, address, error_code); + } else if (error_code != HTCS_EBADF) { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Check if the socket is already bound. */ + if (const auto index = this->Find(id); index >= 0) { + const auto exists = this->HasAddr(address); + + if (m_socket_list[index].m_do_bind) { + error_code = HTCS_EINVAL; + res = -1; + } else if (exists) { + error_code = HTCS_EADDRINUSE; + res = -1; + } else { + res = m_socket_list[index].Bind(address, error_code); + } + } else { + error_code = HTCS_EBADF; + } + } + + return res; + } + s32 VirtualSocketCollection::Listen(s32 id, s32 backlog_count, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Get the socket. */ + sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code)); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + res = listen(socket, backlog_count, error_code); + } else if (error_code != HTCS_EBADF) { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Try to listen on the virtual socket. */ + if (const auto index = this->Find(id); index >= 0) { + res = m_socket_list[index].Listen(backlog_count, error_code); + } + } + + return res; + } + + s32 VirtualSocketCollection::Accept(s32 id, htcs::SockAddrHtcs *address, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Declare socket that we're creating. */ + sf::SharedPointer new_socket = nullptr; + + /* Get the socket. */ + sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code)); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + if (error_code != HTCS_ENONE) { + return -1; + } + + new_socket = this->DoAccept(socket, id, address, error_code); + } else if (error_code != HTCS_EBADF) { + /* Fetch the socket. */ + socket = this->FetchSocket(id, error_code); + + /* Wait for the socket. */ + while (socket == nullptr) { + /* Determine whether we should block/listen. */ + bool block_until_done = true; + bool listened = false; + + s32 index; + { + std::scoped_lock lk(m_mutex); + if (index = this->Find(id, std::addressof(error_code)); index >= 0) { + block_until_done = m_socket_list[index].m_blocking; + listened = m_socket_list[index].m_listen_backlog_count > 0; + } + } + + /* Check that the socket exists. */ + if (index < 0) { + error_code = HTCS_EINTR; + return -1; + } + + /* Check that the socket has been listened. */ + if (!listened) { + error_code = HTCS_EINVAL; + return -1; + } + + /* Check that we should block. */ + if (!block_until_done) { + error_code = HTCS_EWOULDBLOCK; + return -1; + } + + /* Wait before trying again. */ + os::SleepThread(TimeSpan::FromMilliSeconds(500)); + + /* Fetch the potentially updated socket. */ + socket = this->FetchSocket(id, error_code); + } + + /* Check that we haven't errored. */ + if (error_code != HTCS_ENONE) { + return -1; + } + + /* Do the accept. */ + new_socket = this->DoAccept(socket, id, address, error_code); + } + + /* If we have a new socket, register it. */ + if (new_socket != 0) { + res = this->CreateSocket(new_socket, error_code); + if (res < 0) { + s32 tmp_error_code; + close(new_socket, tmp_error_code); + } + } + + return res; + } + + s32 VirtualSocketCollection::Fcntl(s32 id, s32 command, s32 value, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Get the socket. */ + sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code)); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + res = fcntl(socket, command, value, error_code); + } else if (error_code != HTCS_EBADF) { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Try to listen on the virtual socket. */ + if (const auto index = this->Find(id); index >= 0) { + res = m_socket_list[index].Fcntl(command, value, error_code); + } + } + + return res; + } + + s32 VirtualSocketCollection::Shutdown(s32 id, s32 how, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Get the socket. */ + sf::SharedPointer socket = this->GetSocket(id, std::addressof(error_code)); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + res = shutdown(socket, how, error_code); + } else if (error_code != HTCS_EBADF) { + error_code = HTCS_ENOTCONN; + } + + return res; + } + + ssize_t VirtualSocketCollection::Recv(s32 id, void *buffer, size_t buffer_size, s32 flags, s32 &error_code) { + /* Setup result/error code. */ + ssize_t res = -1; + error_code = 0; + + /* Fetch the socket. */ + sf::SharedPointer socket = this->FetchSocket(id, error_code); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + if (error_code != HTCS_ENONE) { + return -1; + } + res = recv(socket, buffer, buffer_size, flags, error_code); + } else if (error_code != HTCS_EBADF) { + error_code = HTCS_ENOTCONN; + } + + return res; + } + + ssize_t VirtualSocketCollection::Send(s32 id, const void *buffer, size_t buffer_size, s32 flags, s32 &error_code) { + /* Setup result/error code. */ + ssize_t res = -1; + error_code = 0; + + /* Fetch the socket. */ + sf::SharedPointer socket = this->FetchSocket(id, error_code); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + if (error_code != HTCS_ENONE) { + return -1; + } + res = send(socket, buffer, buffer_size, flags, error_code); + } else if (error_code != HTCS_EBADF) { + error_code = HTCS_ENOTCONN; + } + + return res; + } + + s32 VirtualSocketCollection::Connect(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + error_code = 0; + + /* Fetch the socket. */ + sf::SharedPointer socket = this->FetchSocket(id, error_code); + + /* If we found the socket, bind. */ + if (socket != nullptr) { + if (error_code != HTCS_ENONE) { + return -1; + } + res = connect(socket, address, error_code); + } else if (error_code != HTCS_EBADF) { + error_code = HTCS_EADDRNOTAVAIL; + } + + return res; + } + + s32 VirtualSocketCollection::Select(htcs::FdSet *read, htcs::FdSet *write, htcs::FdSet *except, htcs::TimeVal *timeout, s32 &error_code) { + /* Setup result/error code. */ + s32 res = -1; + s32 tmp_error_code = 0; + + /* Declare buffers. */ + s32 read_primitives[SocketCountMax]; + s32 write_primitives[SocketCountMax]; + s32 except_primitives[SocketCountMax]; + + /* Get reads. */ + s32 num_read = this->GetSockets(read_primitives, read, tmp_error_code); + if (tmp_error_code != HTCS_ENONE) { + error_code = tmp_error_code; + return res; + } + + /* Get writes. */ + s32 num_write = this->GetSockets(write_primitives, write, tmp_error_code); + if (tmp_error_code != HTCS_ENONE) { + error_code = tmp_error_code; + return res; + } + + /* Get excepts. */ + s32 num_except = this->GetSockets(except_primitives, except, tmp_error_code); + if (tmp_error_code != HTCS_ENONE) { + error_code = tmp_error_code; + return res; + } + + /* Perform the select. */ + if (num_read + num_write + num_except > 0) { + res = select(read_primitives, num_read, write_primitives, num_write, except_primitives, num_except, timeout, error_code); + + /* Set the socket primitives. */ + this->SetSockets(read, read_primitives, num_read); + this->SetSockets(write, write_primitives, num_write); + this->SetSockets(except, except_primitives, num_except); + } else { + error_code = HTCS_EINVAL; + } + + return res; + } + + s32 VirtualSocketCollection::CreateId() { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Get a free id. */ + s32 res = 0; + do { + res = m_next_id++; + if (m_next_id <= 0) { + m_next_id = 1; + } + } while (this->Find(res) >= 0); + + return res; + } + + s32 VirtualSocketCollection::Add(sf::SharedPointer socket) { + /* Check that the socket isn't null. */ + if (socket == nullptr) { + return -1; + } + + /* Create the socket. */ + s32 error_code; + return this->CreateSocket(socket, error_code); + } + + void VirtualSocketCollection::Insert(s32 id, sf::SharedPointer socket) { + /* Lock ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Add the socket to the list. */ + if (m_list_count != 0) { + /* Ensure the list remains in sorder order. */ + s32 index; + for (index = m_list_count - 1; index >= 0; --index) { + if (m_socket_list[index].m_id < id) { + break; + } + m_socket_list[index + 1] = m_socket_list[index]; + } + + /* Set the socket in the list. */ + m_socket_list[index + 1].m_id = id; + m_socket_list[index + 1].m_socket = socket; + } else { + /* Set the socket in the list. */ + m_socket_list[0].m_id = id; + m_socket_list[0].m_socket = socket; + } + + /* Increment our count. */ + ++m_list_count; + } + + void VirtualSocketCollection::SetSize(s32 size) { + /* ... */ + } + + s32 VirtualSocketCollection::Find(s32 id, s32 *error_code) { + /* Perform a binary search to find the socket. */ + if (m_list_count > 0) { + s32 left = 0; + s32 right = m_list_count - 1; + while (left <= right) { + const s32 mid = (left + right) / 2; + if (m_socket_list[mid].m_id == id) { + return mid; + } else if (m_socket_list[mid].m_id > id) { + right = mid - 1; + } else /* if (m_socket_list[mid].m_id < id) */ { + left = mid + 1; + } + } + } + + /* We failed to find the socket. */ + if (error_code != nullptr) { + *error_code = HTCS_EBADF; + } + + return InvalidSocket; + } + + s32 VirtualSocketCollection::FindByPrimitive(s32 primitive) { + /* Find a socket with the desired primitive. */ + for (auto i = 0; i < m_list_size; ++i) { + if (m_socket_list[i].m_primitive == primitive) { + return i; + } + } + + return InvalidPrimitive; + } + + bool VirtualSocketCollection::HasAddr(const htcs::SockAddrHtcs *address) { + /* Try to find a matching socket. */ + for (auto i = 0; i < m_list_count; ++i) { + if (m_socket_list[i].m_address.family == address->family && + std::strcmp(m_socket_list[i].m_address.peer_name.name, address->peer_name.name) == 0 && + std::strcmp(m_socket_list[i].m_address.port_name.name, address->port_name.name) == 0) + { + return true; + } + } + return false; + } + + sf::SharedPointer VirtualSocketCollection::GetSocket(s32 id, s32 *error_code) { + sf::SharedPointer res = nullptr; + + /* Get the socket. */ + { + std::scoped_lock lk(m_mutex); + + if (const auto index = this->Find(id, error_code); index >= 0) { + res = m_socket_list[index].m_socket; + } + } + + return res; + } + + sf::SharedPointer VirtualSocketCollection::FetchSocket(s32 id, s32 &error_code) { + /* Clear the error code. */ + error_code = 0; + + /* Get the socket. */ + auto socket = this->GetSocket(id, std::addressof(error_code)); + if (socket == nullptr && error_code == HTCS_ENONE) { + socket = this->RealizeSocket(id); + } + + return socket; + } + + sf::SharedPointer VirtualSocketCollection::RealizeSocket(s32 id) { + /* Clear the error code. */ + s32 error_code = 0; + + /* Get socket. */ + sf::SharedPointer res = socket(id); + if (res != nullptr) { + /* Assign the new socket. */ + s32 index; + { + std::scoped_lock lk(m_mutex); + + index = this->Find(id, std::addressof(error_code)); + if (index >= 0) { + m_socket_list[index].SetSocket(res, error_code); + } + } + + /* If the socket was deleted, close it. */ + if (index < 0) { + s32 temp_error = 0; + close(res, temp_error); + res = nullptr; + } + } + + return res; + } + + sf::SharedPointer VirtualSocketCollection::DoAccept(sf::SharedPointer socket, s32 id, htcs::SockAddrHtcs *address, s32 &error_code) { + /* Clear the error code. */ + error_code = 0; + + /* Try to accept. */ + sf::SharedPointer new_socket = accept(socket, address, error_code); + if (error_code == HTCS_ENETDOWN) { + new_socket = accept(socket, address, error_code); + + std::scoped_lock lk(m_mutex); + if (const auto index = this->Find(id, std::addressof(error_code)); index >= 0) { + m_socket_list[index].m_socket = nullptr; + } + } + + return new_socket; + } + + s32 VirtualSocketCollection::GetSockets(s32 * const out_primitives, htcs::FdSet *set, s32 &error_code) { + /* Clear the error code. */ + error_code = 0; + s32 count = 0; + + /* Walk the fdset. */ + if (set != nullptr) { + for (auto i = 0; i < FdSetSize; ++i) { + /* If the set no longer has fds, we're done. */ + if (set->fds[i] == 0) { + break; + } + + /* Find the fd's primitive. */ + s32 primitive = InvalidPrimitive; + s32 index; + { + std::scoped_lock lk(m_mutex); + if (index = this->Find(set->fds[i], std::addressof(error_code)); index >= 0) { + /* Get the primitive, if necessary. */ + if (m_socket_list[index].m_primitive == InvalidPrimitive && m_socket_list[index].m_socket != nullptr) { + m_socket_list[index].m_socket->GetPrimitive(std::addressof(m_socket_list[index].m_primitive)); + } + + primitive = m_socket_list[index].m_primitive; + } + } + + /* Check that an error didn't occur. */ + if (error_code != HTCS_ENONE) { + return 0; + } + + /* If the primitive is invalid, try to realize the socket. */ + if (primitive == InvalidPrimitive) { + if (this->RealizeSocket(set->fds[i]) != nullptr) { + std::scoped_lock lk(m_mutex); + + /* Get the primitive. */ + if (index = this->Find(set->fds[i], std::addressof(error_code)); index >= 0) { + m_socket_list[index].m_socket->GetPrimitive(std::addressof(m_socket_list[index].m_primitive)); + + primitive = m_socket_list[index].m_primitive; + } + } + + /* Check that an error didn't occur. */ + if (error_code != HTCS_ENONE) { + return 0; + } + } + + /* Set the output primitive. */ + if (primitive != InvalidPrimitive) { + out_primitives[count++] = primitive; + } + } + } + + return count; + } + + void VirtualSocketCollection::SetSockets(htcs::FdSet *set, s32 * const primitives, s32 count) { + if (set != nullptr) { + /* Clear the set. */ + FdSetZero(set); + + /* Copy the fds. */ + for (auto i = 0; i < count; ++i) { + std::scoped_lock lk(m_mutex); + + if (const auto index = this->FindByPrimitive(primitives[i]); index >= 0) { + set->fds[i] = m_socket_list[index].m_id; + } + } + } + } + + s32 VirtualSocketCollection::CreateSocket(sf::SharedPointer socket, s32 &error_code) { + /* Clear the error code. */ + error_code = 0; + s32 id = InvalidSocket; + + /* Check that we can add to the list. */ + if (m_list_count < m_list_size) { + /* Create a new id. */ + id = this->CreateId(); + + /* Insert the socket into the list. */ + this->Insert(id, socket); + } else { + if (socket != nullptr) { + s32 tmp_error_code; + close(socket, tmp_error_code); + } + + error_code = HTCS_EMFILE; + } + + return id; + } + +} diff --git a/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.hpp b/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.hpp new file mode 100644 index 000000000..67161b422 --- /dev/null +++ b/libraries/libstratosphere/source/htcs/client/htcs_virtual_socket_collection.hpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include + +namespace ams::htcs::client { + + struct VirtualSocket; + + class VirtualSocketCollection { + private: + void *m_buffer; + size_t m_buffer_size; + VirtualSocket *m_socket_list; + s32 m_list_count; + s32 m_list_size; + s32 m_next_id; + os::SdkMutex m_mutex; + public: + static size_t GetWorkingMemorySize(int num_sockets); + public: + explicit VirtualSocketCollection(); + ~VirtualSocketCollection(); + public: + void Init(void *buffer, size_t buffer_size); + void Clear(); + + s32 Socket(s32 &error_code); + s32 Close(s32 id, s32 &error_code); + s32 Bind(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code); + s32 Listen(s32 id, s32 backlog_count, s32 &error_code); + s32 Accept(s32 id, htcs::SockAddrHtcs *address, s32 &error_code); + s32 Fcntl(s32 id, s32 command, s32 value, s32 &error_code); + + s32 Shutdown(s32 id, s32 how, s32 &error_code); + + ssize_t Recv(s32 id, void *buffer, size_t buffer_size, s32 flags, s32 &error_code); + ssize_t Send(s32 id, const void *buffer, size_t buffer_size, s32 flags, s32 &error_code); + + s32 Connect(s32 id, const htcs::SockAddrHtcs *address, s32 &error_code); + + s32 Select(htcs::FdSet *read, htcs::FdSet *write, htcs::FdSet *except, htcs::TimeVal *timeout, s32 &error_code); + private: + s32 CreateId(); + + s32 Add(sf::SharedPointer socket); + + void Insert(s32 id, sf::SharedPointer socket); + void SetSize(s32 size); + + s32 Find(s32 id, s32 *error_code = nullptr); + s32 FindByPrimitive(s32 primitive); + + bool HasAddr(const htcs::SockAddrHtcs *address); + + sf::SharedPointer GetSocket(s32 id, s32 *error_code = nullptr); + sf::SharedPointer FetchSocket(s32 id, s32 &error_code); + sf::SharedPointer RealizeSocket(s32 id); + + sf::SharedPointer DoAccept(sf::SharedPointer socket, s32 id, htcs::SockAddrHtcs *address, s32 &error_code); + + s32 GetSockets(s32 * const out_primitives, htcs::FdSet *set, s32 &error_code); + void SetSockets(htcs::FdSet *set, s32 * const primitives, s32 count); + + s32 CreateSocket(sf::SharedPointer socket, s32 &error_code); + }; + +} diff --git a/libraries/libstratosphere/source/htcs/htcs_socket.cpp b/libraries/libstratosphere/source/htcs/htcs_socket.cpp new file mode 100644 index 000000000..5884e43fb --- /dev/null +++ b/libraries/libstratosphere/source/htcs/htcs_socket.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include +#include "client/htcs_session.hpp" +#include "client/htcs_virtual_socket_collection.hpp" + +namespace ams::htcs { + + namespace { + + constinit bool g_initialized = false; + constinit bool g_enable_disconnection_emulation = false; + + constinit AllocateFunction g_allocate_function = nullptr; + constinit DeallocateFunction g_deallocate_function = nullptr; + + constinit void *g_buffer = nullptr; + constinit size_t g_buffer_size = 0; + + constinit os::TlsSlot g_tls_slot; + + constinit tma::IHtcsManager *g_manager = nullptr; + constinit tma::IHtcsManager *g_monitor = nullptr; + + constinit client::VirtualSocketCollection *g_sockets = nullptr; + + void InitializeImpl(void *buffer, size_t buffer_size, int num_sessions) { + /* Check the session count. */ + AMS_ASSERT(0 < num_sessions && num_sessions <= SessionCountMax); + + /* Initialize the manager and monitor. */ + client::InitializeSessionManager(std::addressof(g_manager), std::addressof(g_monitor)); + + /* Register the process. */ + const sf::ClientProcessId process_id{0}; + R_ABORT_UNLESS(g_manager->RegisterProcessId(process_id)); + R_ABORT_UNLESS(g_monitor->MonitorManager(process_id)); + + /* Allocate a tls slot for our last error. */ + os::SdkAllocateTlsSlot(std::addressof(g_tls_slot), nullptr); + + /* Setup the virtual socket collection. */ + AMS_ASSERT(buffer != nullptr); + g_sockets = reinterpret_cast(buffer); + std::construct_at(g_sockets); + g_sockets->Init(static_cast(buffer) + sizeof(*g_sockets), buffer_size - sizeof(*g_sockets)); + + /* Mark initialized. */ + g_initialized = true; + } + + void InitializeImpl(AllocateFunction allocate, DeallocateFunction deallocate, int num_sessions, int num_sockets) { + /* Check the session count. */ + AMS_ASSERT(0 < num_sessions && num_sessions <= SessionCountMax); + + /* Set the allocation functions. */ + g_allocate_function = allocate; + g_deallocate_function = deallocate; + + /* Allocate a buffer. */ + g_buffer_size = sizeof(client::VirtualSocketCollection) + client::VirtualSocketCollection::GetWorkingMemorySize(num_sockets); + g_buffer = g_allocate_function(g_buffer_size); + + /* Initialize. */ + InitializeImpl(g_buffer, g_buffer_size, num_sessions); + } + + } + + bool IsInitialized() { + return g_initialized; + } + + size_t GetWorkingMemorySize(int num_sockets) { + AMS_ASSERT(num_sockets <= SocketCountMax); + return sizeof(client::VirtualSocketCollection) + client::VirtualSocketCollection::GetWorkingMemorySize(num_sockets); + } + + void Initialize(AllocateFunction allocate, DeallocateFunction deallocate, int num_sessions) { + /* Check that we're not already initialized. */ + AMS_ASSERT(!IsInitialized()); + + /* Configure disconnection emulation. */ + g_enable_disconnection_emulation = true; + + /* Initialize. */ + InitializeImpl(allocate, deallocate, num_sessions, htcs::SocketCountMax); + } + + void Initialize(void *buffer, size_t buffer_size) { + /* Check that we're not already initialized. */ + AMS_ASSERT(!IsInitialized()); + + /* Configure disconnection emulation. */ + g_enable_disconnection_emulation = true; + + /* Initialize. */ + InitializeImpl(buffer, buffer_size, htcs::SessionCountMax); + } + + void InitializeForDisableDisconnectionEmulation(AllocateFunction allocate, DeallocateFunction deallocate, int num_sessions) { + /* Check that we're not already initialized. */ + AMS_ASSERT(!IsInitialized()); + + /* Configure disconnection emulation. */ + g_enable_disconnection_emulation = false; + + /* Initialize. */ + InitializeImpl(allocate, deallocate, num_sessions, htcs::SocketCountMax); + } + + void InitializeForDisableDisconnectionEmulation(void *buffer, size_t buffer_size) { + /* Check that we're not already initialized. */ + AMS_ASSERT(!IsInitialized()); + + /* Configure disconnection emulation. */ + g_enable_disconnection_emulation = false; + + /* Initialize. */ + InitializeImpl(buffer, buffer_size, htcs::SessionCountMax); + } + + void InitializeForSystem(void *buffer, size_t buffer_size, int num_sessions) { + /* Check that we're not already initialized. */ + AMS_ASSERT(!IsInitialized()); + + /* Configure disconnection emulation. */ + g_enable_disconnection_emulation = true; + + /* Initialize. */ + InitializeImpl(buffer, buffer_size, num_sessions); + } + + void Finalize() { + /* Check that we're initialized. */ + AMS_ASSERT(IsInitialized()); + + /* Set not initialized. */ + g_initialized = false; + + /* Destroy the virtual socket collection. */ + std::destroy_at(g_sockets); + g_sockets = nullptr; + + /* Free the buffer, if we have one. */ + if (g_buffer != nullptr) { + g_deallocate_function(g_buffer, g_buffer_size); + g_buffer = nullptr; + g_buffer_size = 0; + } + + /* Free the tls slot. */ + os::FreeTlsSlot(g_tls_slot); + + /* Release the manager objects. */ + sf::ReleaseSharedObject(g_manager); + sf::ReleaseSharedObject(g_monitor); + g_manager = nullptr; + g_monitor = nullptr; + + /* Finalize the bsd client sessions. */ + client::FinalizeSessionManager(); + } + +}