diff --git a/libraries/libstratosphere/include/stratosphere/tipc/tipc_server_manager.hpp b/libraries/libstratosphere/include/stratosphere/tipc/tipc_server_manager.hpp index 6454ab129..ce218e2b5 100644 --- a/libraries/libstratosphere/include/stratosphere/tipc/tipc_server_manager.hpp +++ b/libraries/libstratosphere/include/stratosphere/tipc/tipc_server_manager.hpp @@ -82,7 +82,7 @@ namespace ams::tipc { public: class PortManagerBase : public PortManagerInterface { public: - enum MessageType { + enum MessageType : u8 { MessageType_AddSession = 0, MessageType_TriggerResume = 1, }; @@ -96,17 +96,33 @@ namespace ams::tipc { os::WaitableHolderType m_message_queue_holder; uintptr_t m_message_queue_storage[MaxSessions]; ObjectManagerBase *m_object_manager; + ServerManagerImpl *m_server_manager; public: - PortManagerBase() : m_id(), m_num_sessions(), m_port_number(), m_waitable_manager(), m_deferral_manager(), m_message_queue(), m_message_queue_holder(), m_message_queue_storage(), m_object_manager() { + PortManagerBase() : m_id(), m_num_sessions(), m_port_number(), m_waitable_manager(), m_deferral_manager(), m_message_queue(), m_message_queue_holder(), m_message_queue_storage(), m_object_manager(), m_server_manager() { /* Setup our message queue. */ os::InitializeMessageQueue(std::addressof(m_message_queue), m_message_queue_storage, util::size(m_message_queue_storage)); os::InitializeWaitableHolder(std::addressof(m_message_queue_holder), std::addressof(m_message_queue), os::MessageQueueWaitType::ForNotEmpty); } - void InitializeBase(s32 id, ObjectManagerBase *manager) { + constexpr s32 GetPortIndex() const { + return m_port_number; + } + + s32 GetSessionCount() const { + return m_num_sessions; + } + + ObjectManagerBase *GetObjectManager() const { + return m_object_manager; + } + + void InitializeBase(s32 id, ServerManagerImpl *sm, ObjectManagerBase *manager) { /* Set our id. */ m_id = id; + /* Set our server manager. */ + m_server_manager = sm; + /* Reset our session count. */ m_num_sessions = 0; @@ -149,22 +165,128 @@ namespace ams::tipc { return m_object_manager->ReplyAndReceive(out_holder, out_object, reply_target, std::addressof(m_waitable_manager)); } + void ProcessMessages() { + /* While we have messages in our queue, receive and handle them. */ + uintptr_t message_type, message_data; + while (os::TryReceiveMessageQueue(std::addressof(message_type), std::addressof(m_message_queue))) { + /* Receive the message's data. */ + os::ReceiveMessageQueue(std::addressof(message_data), std::addressof(m_message_queue)); + + /* Handle the specific message. */ + switch (static_cast(static_cast::type>(message_type))) { + case MessageType_AddSession: + { + /* Get the handle from where it's packed into the message type. */ + const svc::Handle session_handle = static_cast(message_type >> BITSIZEOF(u32)); + + /* Allocate a service object for the port. */ + auto *service_object = m_server_manager->AllocateObject(static_cast(message_data)); + + /* Create a waitable object for the session. */ + tipc::WaitableObject object; + + /* Setup the object. */ + object.InitializeAsSession(session_handle, true, service_object); + + /* Register the object. */ + m_object_manager->AddObject(object); + } + break; + case MessageType_TriggerResume: + if constexpr (IsDeferralSupported) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + + /* Perform the resume. */ + const auto resume_key = ConvertMessageToKey(message_data); + m_deferral_manager.Resume(resume_key, this); + } + break; + AMS_UNREACHABLE_DEFAULT_CASE(); + } + } + } + + void CloseSession(WaitableObject &object) { + /* Get the object's handle. */ + const auto handle = object.GetHandle(); + + /* Close the object with our manager. */ + m_object_manager->CloseObject(handle); + + /* Close the handle itself. */ + R_ABORT_UNLESS(svc::CloseHandle(handle)); + + /* Decrement our session count. */ + --m_num_sessions; + } + + void CloseSessionIfNecessary(WaitableObject &object, bool necessary) { + if (necessary) { + /* Get the object's handle. */ + const auto handle = object.GetHandle(); + + /* Close the object with our manager. */ + m_object_manager->CloseObject(handle); + + /* Close the handle itself. */ + R_ABORT_UNLESS(svc::CloseHandle(handle)); + } + + /* Decrement our session count. */ + --m_num_sessions; + } + void StartRegisterRetry(ResumeKey key) { - /* Begin the retry. */ - m_deferral_manager.StartRegisterRetry(key); + if constexpr (IsDeferralSupported) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + + /* Begin the retry. */ + m_deferral_manager.StartRegisterRetry(key); + } + } + + void ProcessRegisterRetry(WaitableObject &object) { + if constexpr (IsDeferralSupported) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + + /* Process the retry. */ + m_deferral_manager.ProcessRegisterRetry(object); + } } bool TestResume(ResumeKey key) { - /* Check to see if the key corresponds to some deferred message. */ - return m_deferral_manager.TestResume(key); + if constexpr (IsDeferralSupported) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + + /* Check to see if the key corresponds to some deferred message. */ + return m_deferral_manager.TestResume(key); + } else { + return false; + } } void TriggerResume(ResumeKey key) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + /* Send the key as a message. */ os::SendMessageQueue(std::addressof(m_message_queue), static_cast(MessageType_TriggerResume)); os::SendMessageQueue(std::addressof(m_message_queue), ConvertKeyToMessage(key)); } - private: + + void TriggerAddSession(svc::Handle session_handle, size_t port_index) { + /* Acquire exclusive server manager access. */ + std::scoped_lock lk(m_server_manager->GetMutex()); + + /* Send information about the session as a message. */ + os::SendMessageQueue(std::addressof(m_message_queue), static_cast(MessageType_AddSession) | (static_cast(session_handle) << BITSIZEOF(u32))); + os::SendMessageQueue(std::addressof(m_message_queue), static_cast(port_index)); + } + public: static bool IsRequestDeferred() { if constexpr (IsDeferralSupported) { /* Get the message buffer. */ @@ -199,9 +321,9 @@ namespace ams::tipc { /* ... */ } - void Initialize(s32 id) { + void Initialize(s32 id, ServerManagerImpl *sm) { /* Initialize our base. */ - this->InitializeBase(id, std::addressof(m_object_manager_impl)); + this->InitializeBase(id, sm, std::addressof(m_object_manager_impl)); /* Initialize our object manager. */ m_object_manager_impl->Initialize(std::addressof(this->m_waitable_manager)); @@ -217,7 +339,7 @@ namespace ams::tipc { using PortAllocatorTuple = std::tuple; private: - os::SdkMutex m_mutex; + os::Mutex m_mutex; os::TlsSlot m_tls_slot; PortManagerTuple m_port_managers; PortAllocatorTuple m_port_allocators; @@ -253,10 +375,12 @@ namespace ams::tipc { os::StartThread(m_port_threads + Ix); } public: - ServerManagerImpl() : m_mutex(), m_tls_slot(), m_port_managers(), m_port_allocators() { /* ... */ } + ServerManagerImpl() : m_mutex(true), m_tls_slot(), m_port_managers(), m_port_allocators() { /* ... */ } os::TlsSlot GetTlsSlot() const { return m_tls_slot; } + os::Mutex &GetMutex() { return m_mutex; } + void Initialize() { /* Initialize our tls slot. */ if constexpr (IsDeferralSupported) { @@ -265,7 +389,7 @@ namespace ams::tipc { /* Initialize our port managers. */ [this](std::index_sequence) ALWAYS_INLINE_LAMBDA { - (this->GetPortManager().Initialize(static_cast(Ix)), ...); + (this->GetPortManager().Initialize(static_cast(Ix), this), ...); }(std::make_index_sequence(NumPorts)); } @@ -288,7 +412,41 @@ namespace ams::tipc { /* Process for the last port. */ this->LoopAutoForPort(); } + + tipc::ServiceObjectBase *AllocateObject(size_t port_index) { + /* Check that the port index is valid. */ + AMS_ABORT_UNLESS(port_index < NumPorts); + + /* Try to allocate from each port, in turn. */ + tipc::ServiceObjectBase *allocated = nullptr; + return [this, port_index, &allocated](std::index_sequence) ALWAYS_INLINE_LAMBDA { + (this->TryAllocateObject(port_index, allocated), ...); + }(std::make_index_sequence()); + + /* Return the allocated object. */ + AMS_ABORT_UNLESS(allocated != nullptr); + return allocated; + } private: + template requires (Ix < NumPorts) + void TryAllocateObject(size_t port_index, tipc::ServiceObjectBase *&allocated) { + /* Check that the port index matches. */ + if (port_index == Ix) { + /* Get the allocator. */ + auto &allocator = std::get(m_port_allocators); + + /* Allocate the object. */ + AMS_ABORT_UNLESS(allocated == nullptr); + allocated = allocator.Allocate(); + AMS_ABORT_UNLESS(allocated != nullptr); + + /* If we should, set the object's deleter. */ + if constexpr (IsServiceObjectDeleter::type>) { + allocated->SetDeleter(std::addressof(allocator)); + } + } + } + Result LoopProcess(PortManagerBase &port_manager) { /* Set our tls slot's value to be the port manager we're processing for. */ if constexpr (IsDeferralSupported) { @@ -302,7 +460,113 @@ namespace ams::tipc { /* Process requests forever. */ svc::Handle reply_target = svc::InvalidHandle; while (true) { - /* TODO */ + /* Reply to our pending request, and receive a new one. */ + os::WaitableHolderType *signaled_holder = nullptr; + tipc::WaitableObject signaled_object{}; + R_TRY_CATCH(port_manager.ReplyAndReceive(std::addressof(signaled_holder), std::addressof(signaled_object), reply_target)) { + R_CATCH(os::ResultSessionClosedForReceive, os::ResultReceiveListBroken) { + /* Close the object and continue. */ + port_manager.CloseObject(signaled_object); + + /* We have nothing to reply to. */ + reply_target = svc::InvalidHandle; + continue; + } + } R_END_TRY_CATCH; + + if (signaled_holder == nullptr) { + /* A session was signaled, accessible via signaled_object. */ + switch (signaled_object.GetType()) { + case WaitableObject::ObjectType_Port: + { + /* Try to accept a new session */ + svc::Handle session_handle; + if (R_SUCCEEDED(svc::AcceptSession(std::addressof(session_handle), signaled_object.GetHandle()))) { + this->TriggerAddSession(session_handle, static_cast(port_manager.GetPortIndex())); + } + + /* We have nothing to reply to. */ + reply_target = svc::InvalidHandle; + } + break; + case WaitableObject::ObjectType_Session: + { + /* Process the request */ + const Result process_result = port_manager.GetObjectManager()->ProcessRequest(signaled_object); + if (R_SUCCEEDED(process_result)) { + if constexpr (IsDeferralSupported) { + /* Check if the request is deferred. */ + if (PortManagerBase::IsRequestDeferred()) { + /* Process the retry that we began. */ + port_manager.ProcessRegisterRetry(signaled_object); + + /* We have nothing to reply to. */ + reply_target = svc::InvalidHandle; + } else { + /* We're done processing, so we should reply. */ + reply_target = signaled_object.GetHandle(); + } + } else { + /* We're done processing, so we should reply. */ + reply_target = signaled_object.GetHandle(); + } + } else { + /* We failed to process, so note the session as closed (or close it). */ + port_manager.CloseSessionIfNecessary(signaled_object, !tipc::ResultSessionClosed::Includes(process_result)); + + /* We have nothing to reply to. */ + reply_target = svc::InvalidHandle; + } + } + break; + AMS_UNREACHABLE_DEFAULT_CASE(); + } + } else { + /* Our message queue was signaled. */ + port_manager.ProcessMessages(this); + + /* We have nothing to reply to. */ + reply_target = svc::InvalidHandle; + } + } + } + + void TriggerAddSession(svc::Handle session_handle, size_t port_index) { + /* Acquire exclusive access to ourselves. */ + std::scoped_lock lk(m_mutex); + + /* Select the best port manager. */ + PortManagerBase *best_manager = nullptr; + s32 best_sessions = -1; + const auto session_counts = [this, &best_manager, &best_sessions](std::index_sequence) ALWAYS_INLINE_LAMBDA { + (this->TrySelectBetterPort(best_manager, best_sessions), ...); + }(std::make_index_sequence()); + + /* Trigger the session add on the least-burdened manager. */ + best_manager->TriggerAddSession(session_handle, port_index); + } + + template requires (Ix < NumPorts) + void TrySelectBetterPort(PortManagerBase *&best_manager, s32 &best_sessions) { + if constexpr (Ix == 0) { + best_manager = std::addressof(this->GetPortManager()); + best_sessions = std::min(best_manager->GetSessionCount(), static_cast(SessionsPerPortManager)); + } else if constexpr (Ix < NumPorts - 1) { + auto &cur_manager = this->GetPortManager(); + const auto cur_sessions = std::min(cur_manager.GetSessionCount(), static_cast(SessionsPerPortManager)); + + if (cur_sessions < best_sessions) { + best_manager = std::addressof(cur_manager); + best_sessions = cur_sessions; + } + } else { + auto &cur_manager = this->GetPortManager(); + const auto cur_sessions = cur_manager.GetSessionCount(); + + if (cur_sessions < best_sessions) { + best_manager = std::addressof(cur_manager); + best_sessions = cur_sessions; + } } } };