From 25e1d3401776f095cf6bc0ad9fbb48a6b6843cb4 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Tue, 1 Dec 2020 16:19:39 -0800 Subject: [PATCH] KConditionVariable/KAddressArbiter: no need for global compare thread --- .../mesosphere/kern_k_condition_variable.hpp | 2 - .../include/mesosphere/kern_k_thread.hpp | 18 +++++- .../source/kern_k_address_arbiter.cpp | 12 +--- .../source/kern_k_condition_variable.cpp | 5 +- .../libvapours/include/freebsd/sys/tree.h | 63 +++++++++++++++++-- .../util/util_intrusive_red_black_tree.hpp | 47 +++++++++++++- 6 files changed, 124 insertions(+), 23 deletions(-) diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp index 57aca7def..3c8928157 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_condition_variable.hpp @@ -20,8 +20,6 @@ namespace ams::kern { - extern KThread g_cv_arbiter_compare_thread; - class KConditionVariable { public: using ThreadTree = typename KThread::ConditionVariableThreadTreeType; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp index 71a01dba7..67bd30209 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp @@ -124,7 +124,21 @@ namespace ams::kern { static_assert(sizeof(SyncObjectBuffer::sync_objects) == sizeof(SyncObjectBuffer::handles)); struct ConditionVariableComparator { - static constexpr ALWAYS_INLINE int Compare(const KThread &lhs, const KThread &rhs) { + struct LightCompareType { + uintptr_t cv_key; + s32 priority; + + constexpr ALWAYS_INLINE uintptr_t GetConditionVariableKey() const { + return this->cv_key; + } + + constexpr ALWAYS_INLINE s32 GetPriority() const { + return this->priority; + } + }; + + template requires (std::same_as || std::same_as) + static constexpr ALWAYS_INLINE int Compare(const T &lhs, const KThread &rhs) { const uintptr_t l_key = lhs.GetConditionVariableKey(); const uintptr_t r_key = rhs.GetConditionVariableKey(); @@ -139,6 +153,8 @@ namespace ams::kern { } } }; + static_assert(ams::util::HasLightCompareType); + static_assert(std::same_as, ConditionVariableComparator::LightCompareType>); private: static inline std::atomic s_next_thread_id = 0; private: diff --git a/libraries/libmesosphere/source/kern_k_address_arbiter.cpp b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp index 171fd6d33..c223ff95f 100644 --- a/libraries/libmesosphere/source/kern_k_address_arbiter.cpp +++ b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp @@ -50,9 +50,8 @@ namespace ams::kern { s32 num_waiters = 0; { KScopedSchedulerLock sl; - g_cv_arbiter_compare_thread.SetupForAddressArbiterCompare(addr, -1); - auto it = this->tree.nfind(g_cv_arbiter_compare_thread); + auto it = this->tree.nfind_light({ addr, -1 }); while ((it != this->tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { KThread *target_thread = std::addressof(*it); target_thread->SetSyncedObject(nullptr, ResultSuccess()); @@ -79,10 +78,7 @@ namespace ams::kern { R_UNLESS(UpdateIfEqual(std::addressof(user_value), addr, value, value + 1), svc::ResultInvalidCurrentMemory()); R_UNLESS(user_value == value, svc::ResultInvalidState()); - g_cv_arbiter_compare_thread.SetupForAddressArbiterCompare(addr, -1); - - auto it = this->tree.nfind(g_cv_arbiter_compare_thread); - + auto it = this->tree.nfind_light({ addr, -1 }); while ((it != this->tree.end()) && (count <= 0 || num_waiters < count) && (it->GetAddressArbiterKey() == addr)) { KThread *target_thread = std::addressof(*it); target_thread->SetSyncedObject(nullptr, ResultSuccess()); @@ -103,10 +99,8 @@ namespace ams::kern { s32 num_waiters = 0; { KScopedSchedulerLock sl; - g_cv_arbiter_compare_thread.SetupForAddressArbiterCompare(addr, -1); - - auto it = this->tree.nfind(g_cv_arbiter_compare_thread); + auto it = this->tree.nfind_light({ addr, -1 }); /* Determine the updated value. */ s32 new_value; if (GetTargetFirmware() >= TargetFirmware_7_0_0) { diff --git a/libraries/libmesosphere/source/kern_k_condition_variable.cpp b/libraries/libmesosphere/source/kern_k_condition_variable.cpp index dc768e43c..56ffaf491 100644 --- a/libraries/libmesosphere/source/kern_k_condition_variable.cpp +++ b/libraries/libmesosphere/source/kern_k_condition_variable.cpp @@ -17,8 +17,6 @@ namespace ams::kern { - constinit KThread g_cv_arbiter_compare_thread; - namespace { ALWAYS_INLINE bool ReadFromUser(u32 *out, KProcessAddress address) { @@ -179,9 +177,8 @@ namespace ams::kern { int num_waiters = 0; { KScopedSchedulerLock sl; - g_cv_arbiter_compare_thread.SetupForConditionVariableCompare(cv_key, -1); - auto it = this->tree.nfind(g_cv_arbiter_compare_thread); + auto it = this->tree.nfind_light({ cv_key, -1 }); while ((it != this->tree.end()) && (count <= 0 || num_waiters < count) && (it->GetConditionVariableKey() == cv_key)) { KThread *target_thread = std::addressof(*it); diff --git a/libraries/libvapours/include/freebsd/sys/tree.h b/libraries/libvapours/include/freebsd/sys/tree.h index d18ee1821..7cfaf1330 100644 --- a/libraries/libvapours/include/freebsd/sys/tree.h +++ b/libraries/libvapours/include/freebsd/sys/tree.h @@ -400,6 +400,8 @@ struct { \ RB_PROTOTYPE_REMOVE(name, type, attr); \ RB_PROTOTYPE_FIND(name, type, attr); \ RB_PROTOTYPE_NFIND(name, type, attr); \ + RB_PROTOTYPE_FIND_LIGHT(name, type, attr); \ + RB_PROTOTYPE_NFIND_LIGHT(name, type, attr); \ RB_PROTOTYPE_NEXT(name, type, attr); \ RB_PROTOTYPE_PREV(name, type, attr); \ RB_PROTOTYPE_MINMAX(name, type, attr); @@ -415,6 +417,10 @@ struct { \ attr struct type *name##_RB_FIND(struct name *, struct type *) #define RB_PROTOTYPE_NFIND(name, type, attr) \ attr struct type *name##_RB_NFIND(struct name *, struct type *) +#define RB_PROTOTYPE_FIND_LIGHT(name, type, attr) \ + attr struct type *name##_RB_FIND_LIGHT(struct name *, const void *) +#define RB_PROTOTYPE_NFIND_LIGHT(name, type, attr) \ + attr struct type *name##_RB_NFIND_LIGHT(struct name *, const void *) #define RB_PROTOTYPE_NEXT(name, type, attr) \ attr struct type *name##_RB_NEXT(struct type *) #define RB_PROTOTYPE_PREV(name, type, attr) \ @@ -436,15 +442,17 @@ struct { \ RB_GENERATE_PREV(name, type, field, attr) \ RB_GENERATE_MINMAX(name, type, field, attr) -#define RB_GENERATE_WITH_COMPARE(name, type, field, cmp) \ - RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp,) -#define RB_GENERATE_WITH_COMPARE_STATIC(name, type, field, cmp) \ - RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp, __unused static) -#define RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp, attr) \ +#define RB_GENERATE_WITH_COMPARE(name, type, field, cmp, lcmp) \ + RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp, lcmp,) +#define RB_GENERATE_WITH_COMPARE_STATIC(name, type, field, cmp, lcmp) \ + RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp, lcmp, __unused static) +#define RB_GENERATE_WITH_COMPARE_INTERNAL(name, type, field, cmp, lcmp, attr) \ RB_GENERATE_INSERT_COLOR(name, type, field, attr) \ RB_GENERATE_INSERT(name, type, field, cmp, attr) \ RB_GENERATE_FIND(name, type, field, cmp, attr) \ - RB_GENERATE_NFIND(name, type, field, cmp, attr) + RB_GENERATE_NFIND(name, type, field, cmp, attr) \ + RB_GENERATE_FIND_LIGHT(name, type, field, lcmp, attr) \ + RB_GENERATE_NFIND_LIGHT(name, type, field, lcmp, attr) #define RB_GENERATE_ALL(name, type, field, cmp) \ RB_GENERATE_ALL_INTERNAL(name, type, field, cmp,) @@ -719,6 +727,47 @@ name##_RB_NFIND(struct name *head, struct type *elm) \ return (res); \ } +#define RB_GENERATE_FIND_LIGHT(name, type, field, lcmp, attr) \ +/* Finds the node with the same key as elm */ \ +attr struct type * \ +name##_RB_FIND_LIGHT(struct name *head, const void *lelm) \ +{ \ + struct type *tmp = RB_ROOT(head); \ + int comp; \ + while (tmp) { \ + comp = lcmp(lelm, tmp); \ + if (comp < 0) \ + tmp = RB_LEFT(tmp, field); \ + else if (comp > 0) \ + tmp = RB_RIGHT(tmp, field); \ + else \ + return (tmp); \ + } \ + return (NULL); \ +} + +#define RB_GENERATE_NFIND_LIGHT(name, type, field, lcmp, attr) \ +/* Finds the first node greater than or equal to the search key */ \ +attr struct type * \ +name##_RB_NFIND_LIGHT(struct name *head, const void *lelm) \ +{ \ + struct type *tmp = RB_ROOT(head); \ + struct type *res = NULL; \ + int comp; \ + while (tmp) { \ + comp = lcmp(lelm, tmp); \ + if (comp < 0) { \ + res = tmp; \ + tmp = RB_LEFT(tmp, field); \ + } \ + else if (comp > 0) \ + tmp = RB_RIGHT(tmp, field); \ + else \ + return (tmp); \ + } \ + return (res); \ +} + #define RB_GENERATE_NEXT(name, type, field, attr) \ /* ARGSUSED */ \ attr struct type * \ @@ -788,6 +837,8 @@ name##_RB_MINMAX(struct name *head, int val) \ #define RB_REMOVE(name, x, y) name##_RB_REMOVE(x, y) #define RB_FIND(name, x, y) name##_RB_FIND(x, y) #define RB_NFIND(name, x, y) name##_RB_NFIND(x, y) +#define RB_FIND_LIGHT(name, x, y) name##_RB_FIND_LIGHT(x, y) +#define RB_NFIND_LIGHT(name, x, y) name##_RB_NFIND_LIGHT(x, y) #define RB_NEXT(name, x, y) name##_RB_NEXT(y) #define RB_PREV(name, x, y) name##_RB_PREV(y) #define RB_MIN(name, x) name##_RB_MINMAX(x, RB_NEGINF) diff --git a/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp b/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp index 84c05216d..3a6d61459 100644 --- a/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp +++ b/libraries/libvapours/include/vapours/util/util_intrusive_red_black_tree.hpp @@ -235,6 +235,27 @@ namespace ams::util { } + template + concept HasLightCompareType = requires { + { std::is_same::value } -> std::convertible_to; + }; + + namespace impl { + + template + consteval auto *GetLightCompareType() { + if constexpr (HasLightCompareType) { + return static_cast(nullptr); + } else { + return static_cast(nullptr); + } + } + + } + + template + using LightCompareType = typename std::remove_pointer())>::type; + template class IntrusiveRedBlackTree { NON_COPYABLE(IntrusiveRedBlackTree); @@ -258,6 +279,10 @@ namespace ams::util { using iterator = Iterator; using const_iterator = Iterator; + using light_value_type = LightCompareType; + using const_light_pointer = const light_value_type *; + using const_light_reference = const light_value_type &; + template class Iterator { public: @@ -325,12 +350,16 @@ namespace ams::util { }; private: /* Generate static implementations for comparison operations for IntrusiveRedBlackTreeRoot. */ - RB_GENERATE_WITH_COMPARE_STATIC(IntrusiveRedBlackTreeRootWithCompare, IntrusiveRedBlackTreeNode, entry, CompareImpl); + RB_GENERATE_WITH_COMPARE_STATIC(IntrusiveRedBlackTreeRootWithCompare, IntrusiveRedBlackTreeNode, entry, CompareImpl, LightCompareImpl); private: static int CompareImpl(const IntrusiveRedBlackTreeNode *lhs, const IntrusiveRedBlackTreeNode *rhs) { return Comparator::Compare(*Traits::GetParent(lhs), *Traits::GetParent(rhs)); } + static int LightCompareImpl(const void *elm, const IntrusiveRedBlackTreeNode *rhs) { + return Comparator::Compare(*static_cast(elm), *Traits::GetParent(rhs)); + } + /* Define accessors using RB_* functions. */ IntrusiveRedBlackTreeNode *InsertImpl(IntrusiveRedBlackTreeNode *node) { return RB_INSERT(IntrusiveRedBlackTreeRootWithCompare, static_cast(&this->impl.root), node); @@ -343,6 +372,14 @@ namespace ams::util { IntrusiveRedBlackTreeNode *NFindImpl(IntrusiveRedBlackTreeNode const *node) const { return RB_NFIND(IntrusiveRedBlackTreeRootWithCompare, const_cast(static_cast(&this->impl.root)), const_cast(node)); } + + IntrusiveRedBlackTreeNode *FindLightImpl(const_light_pointer lelm) const { + return RB_FIND_LIGHT(IntrusiveRedBlackTreeRootWithCompare, const_cast(static_cast(&this->impl.root)), static_cast(lelm)); + } + + IntrusiveRedBlackTreeNode *NFindLightImpl(const_light_pointer lelm) const { + return RB_NFIND_LIGHT(IntrusiveRedBlackTreeRootWithCompare, const_cast(static_cast(&this->impl.root)), static_cast(lelm)); + } public: constexpr ALWAYS_INLINE IntrusiveRedBlackTree() : impl() { /* ... */ } @@ -417,6 +454,14 @@ namespace ams::util { iterator nfind(const_reference ref) const { return iterator(this->NFindImpl(Traits::GetNode(std::addressof(ref)))); } + + iterator find_light(const_light_reference ref) const { + return iterator(this->FindLightImpl(std::addressof(ref))); + } + + iterator nfind_light(const_light_reference ref) const { + return iterator(this->NFindLightImpl(std::addressof(ref))); + } }; template>