diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_auto_object.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_auto_object.hpp index 431bd32f7..2618b04d2 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_auto_object.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_auto_object.hpp @@ -248,10 +248,16 @@ namespace ams::kern { public: constexpr ALWAYS_INLINE KAutoObjectWithList(util::ConstantInitializeTag) : KAutoObjectWithListBase(util::ConstantInitialize), m_list_node(util::ConstantInitialize) { /* ... */ } ALWAYS_INLINE explicit KAutoObjectWithList() { /* ... */ } + public: + using RedBlackKeyType = u64; - static ALWAYS_INLINE int Compare(const KAutoObjectWithList &lhs, const KAutoObjectWithList &rhs) { - const u64 lid = lhs.GetId(); - const u64 rid = rhs.GetId(); + static constexpr ALWAYS_INLINE RedBlackKeyType GetRedBlackKey(const RedBlackKeyType &v) { return v; } + static constexpr ALWAYS_INLINE RedBlackKeyType GetRedBlackKey(const KAutoObjectWithList &v) { return v.GetId(); } + + template requires (std::same_as || std::same_as) + static ALWAYS_INLINE int Compare(const T &lhs, const KAutoObjectWithList &rhs) { + const u64 lid = GetRedBlackKey(lhs); + const u64 rid = GetRedBlackKey(rhs); if (lid < rid) { return -1; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_auto_object_container.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_auto_object_container.hpp index 1687b702a..6efd657d0 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_auto_object_container.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_auto_object_container.hpp @@ -33,17 +33,21 @@ namespace ams::kern { explicit ListAccessor(KAutoObjectWithListContainer *container) : KScopedLightLock(container->m_lock), m_list(container->m_object_list) { /* ... */ } explicit ListAccessor(KAutoObjectWithListContainer &container) : KScopedLightLock(container.m_lock), m_list(container.m_object_list) { /* ... */ } - typename ListType::iterator begin() const { + ALWAYS_INLINE typename ListType::iterator begin() const { return m_list.begin(); } - typename ListType::iterator end() const { + ALWAYS_INLINE typename ListType::iterator end() const { return m_list.end(); } - typename ListType::iterator find(typename ListType::const_reference ref) const { + ALWAYS_INLINE typename ListType::iterator find(typename ListType::const_reference ref) const { return m_list.find(ref); } + + ALWAYS_INLINE typename ListType::iterator find_key(typename ListType::const_key_reference ref) const { + return m_list.find_key(ref); + } }; friend class ListAccessor; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_thread_local_page.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_thread_local_page.hpp index ec7c43640..280b65904 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_thread_local_page.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_thread_local_page.hpp @@ -42,10 +42,16 @@ namespace ams::kern { explicit KThreadLocalPage() : KThreadLocalPage(Null) { /* ... */ } constexpr ALWAYS_INLINE KProcessAddress GetAddress() const { return m_virt_addr; } + public: + using RedBlackKeyType = KProcessAddress; - static constexpr ALWAYS_INLINE int Compare(const KThreadLocalPage &lhs, const KThreadLocalPage &rhs) { - const KProcessAddress lval = lhs.GetAddress(); - const KProcessAddress rval = rhs.GetAddress(); + static constexpr ALWAYS_INLINE RedBlackKeyType GetRedBlackKey(const RedBlackKeyType &v) { return v; } + static constexpr ALWAYS_INLINE RedBlackKeyType GetRedBlackKey(const KThreadLocalPage &v) { return v.GetAddress(); } + + template requires (std::same_as || std::same_as) + static constexpr ALWAYS_INLINE int Compare(const T &lhs, const KThreadLocalPage &rhs) { + const KProcessAddress lval = GetRedBlackKey(lhs); + const KProcessAddress rval = GetRedBlackKey(rhs); if (lval < rval) { return -1; diff --git a/libraries/libmesosphere/source/kern_k_process.cpp b/libraries/libmesosphere/source/kern_k_process.cpp index 9ddf09835..89ec422c1 100644 --- a/libraries/libmesosphere/source/kern_k_process.cpp +++ b/libraries/libmesosphere/source/kern_k_process.cpp @@ -705,10 +705,10 @@ namespace ams::kern { KScopedSchedulerLock sl; /* Try to find the page in the partially used list. */ - auto it = m_partially_used_tlp_tree.find(KThreadLocalPage(util::AlignDown(GetInteger(addr), PageSize))); + auto it = m_partially_used_tlp_tree.find_key(util::AlignDown(GetInteger(addr), PageSize)); if (it == m_partially_used_tlp_tree.end()) { /* If we don't find it, it has to be in the fully used list. */ - it = m_fully_used_tlp_tree.find(KThreadLocalPage(util::AlignDown(GetInteger(addr), PageSize))); + it = m_fully_used_tlp_tree.find_key(util::AlignDown(GetInteger(addr), PageSize)); R_UNLESS(it != m_fully_used_tlp_tree.end(), svc::ResultInvalidAddress()); /* Release the region. */ @@ -749,9 +749,9 @@ namespace ams::kern { KThreadLocalPage *tlp = nullptr; { KScopedSchedulerLock sl; - if (auto it = m_partially_used_tlp_tree.find(KThreadLocalPage(util::AlignDown(GetInteger(addr), PageSize))); it != m_partially_used_tlp_tree.end()) { + if (auto it = m_partially_used_tlp_tree.find_key(util::AlignDown(GetInteger(addr), PageSize)); it != m_partially_used_tlp_tree.end()) { tlp = std::addressof(*it); - } else if (auto it = m_fully_used_tlp_tree.find(KThreadLocalPage(util::AlignDown(GetInteger(addr), PageSize))); it != m_fully_used_tlp_tree.end()) { + } else if (auto it = m_fully_used_tlp_tree.find_key(util::AlignDown(GetInteger(addr), PageSize)); it != m_fully_used_tlp_tree.end()) { tlp = std::addressof(*it); } else { return nullptr; diff --git a/libraries/libmesosphere/source/kern_k_thread.cpp b/libraries/libmesosphere/source/kern_k_thread.cpp index 303b2268a..548eb1c41 100644 --- a/libraries/libmesosphere/source/kern_k_thread.cpp +++ b/libraries/libmesosphere/source/kern_k_thread.cpp @@ -1324,33 +1324,16 @@ namespace ams::kern { KThread::ListAccessor accessor; const auto end = accessor.end(); - /* Define helper object to find the thread. */ - class IdObjectHelper : public KAutoObjectWithListContainer::ListType::value_type { - private: - u64 m_id; - public: - explicit IdObjectHelper(u64 id) : m_id(id) { /* ... */ } - virtual u64 GetId() const override { return m_id; } - }; - /* Find the object with the right id. */ - const auto it = accessor.find(IdObjectHelper(thread_id)); - - /* Check to make sure we found the thread. */ - if (it == end) { - return nullptr; + if (const auto it = accessor.find_key(thread_id); it != end) { + /* Try to open the thread. */ + if (KThread *thread = static_cast(std::addressof(*it)); AMS_LIKELY(thread->Open())) { + MESOSPHERE_ASSERT(thread->GetId() == thread_id); + return thread; + } } - /* Get the thread. */ - KThread *thread = static_cast(std::addressof(*it)); - - /* Open the thread. */ - if (AMS_LIKELY(thread->Open())) { - MESOSPHERE_ASSERT(thread->GetId() == thread_id); - return thread; - } - - /* We failed to find the thread. */ + /* We failed to find or couldn't open the thread. */ return nullptr; }