From aed9d3f5353984f9f1eb1794dac288ff53e0d499 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Wed, 20 Oct 2021 11:02:17 -0700 Subject: [PATCH] util: better match true std::atomic semantics --- .../source/secmon_exception_handler.cpp | 2 +- .../program/source/smc/secmon_smc_se_lock.cpp | 2 +- .../mesosphere/kern_k_dynamic_slab_heap.hpp | 8 +-- .../include/mesosphere/kern_k_process.hpp | 2 +- .../include/mesosphere/kern_k_thread.hpp | 8 +-- .../source/arch/arm64/kern_cpu.cpp | 2 +- .../source/kern_k_client_port.cpp | 2 +- .../libmesosphere/source/kern_k_process.cpp | 12 ++-- .../libmesosphere/source/kern_k_thread.cpp | 4 +- libraries/libmesosphere/source/kern_panic.cpp | 6 +- .../vapours/util/arch/arm64/util_atomic.hpp | 60 +++++++++++++------ .../vapours/util/arch/generic/util_atomic.hpp | 47 +++++++++++---- 12 files changed, 102 insertions(+), 53 deletions(-) diff --git a/exosphere/program/source/secmon_exception_handler.cpp b/exosphere/program/source/secmon_exception_handler.cpp index 1f033315d..e1af21055 100644 --- a/exosphere/program/source/secmon_exception_handler.cpp +++ b/exosphere/program/source/secmon_exception_handler.cpp @@ -22,7 +22,7 @@ namespace ams::secmon { constexpr inline uintptr_t PMC = MemoryRegionVirtualDevicePmc.GetAddress(); - constinit util::Atomic g_is_locked{false}; + constinit util::Atomic g_is_locked = false; } diff --git a/exosphere/program/source/smc/secmon_smc_se_lock.cpp b/exosphere/program/source/smc/secmon_smc_se_lock.cpp index b96eba8eb..1dc5866a2 100644 --- a/exosphere/program/source/smc/secmon_smc_se_lock.cpp +++ b/exosphere/program/source/smc/secmon_smc_se_lock.cpp @@ -21,7 +21,7 @@ namespace ams::secmon::smc { namespace { - constinit util::Atomic g_is_locked{false}; + constinit util::Atomic g_is_locked = false; ALWAYS_INLINE bool TryLockSecurityEngineImpl() { bool value = false; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_dynamic_slab_heap.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_dynamic_slab_heap.hpp index 2de558c96..a4f687ce5 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_dynamic_slab_heap.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_dynamic_slab_heap.hpp @@ -66,7 +66,7 @@ namespace ams::kern { KSlabHeapImpl::Free(allocated + i); } - m_count.FetchAdd(sizeof(PageBuffer) / sizeof(T)); + m_count += sizeof(PageBuffer) / sizeof(T); } } @@ -89,7 +89,7 @@ namespace ams::kern { for (size_t i = 1; i < sizeof(PageBuffer) / sizeof(T); i++) { KSlabHeapImpl::Free(allocated + i); } - m_count.FetchAdd(sizeof(PageBuffer) / sizeof(T)); + m_count += sizeof(PageBuffer) / sizeof(T); } } } @@ -99,7 +99,7 @@ namespace ams::kern { std::construct_at(allocated); /* Update our tracking. */ - size_t used = m_used.FetchAdd(1) + 1; + const size_t used = ++m_used; size_t peak = m_peak.Load(); while (peak < used) { if (m_peak.CompareExchangeWeak(peak, used)) { @@ -113,7 +113,7 @@ namespace ams::kern { ALWAYS_INLINE void Free(T *t) { KSlabHeapImpl::Free(t); - m_used.FetchSub(1); + --m_used; } }; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp index e1a9ed684..311db71ed 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp @@ -288,7 +288,7 @@ namespace ams::kern { KThread *GetExceptionThread() const { return m_exception_thread; } - void AddCpuTime(s64 diff) { m_cpu_time.FetchAdd(diff); } + void AddCpuTime(s64 diff) { m_cpu_time += diff; } s64 GetCpuTime() { return m_cpu_time.Load(); } constexpr s64 GetScheduledCount() const { return m_schedule_count; } diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp index 23fe71058..c726e3fa2 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_thread.hpp @@ -176,8 +176,6 @@ namespace ams::kern { }; static_assert(ams::util::HasRedBlackKeyType); static_assert(std::same_as, ConditionVariableComparator::RedBlackKeyType>); - private: - static constinit inline util::Atomic s_next_thread_id{0}; private: util::IntrusiveListNode m_process_list_node{}; util::IntrusiveRedBlackTreeNode m_condvar_arbiter_tree_node{}; @@ -348,11 +346,11 @@ namespace ams::kern { #endif ALWAYS_INLINE void RegisterDpc(DpcFlag flag) { - this->GetStackParameters().dpc_flags.FetchOr(flag); + this->GetStackParameters().dpc_flags |= flag; } ALWAYS_INLINE void ClearDpc(DpcFlag flag) { - this->GetStackParameters().dpc_flags.FetchAnd(~flag); + this->GetStackParameters().dpc_flags &= ~flag; } ALWAYS_INLINE u8 GetDpc() const { @@ -544,7 +542,7 @@ namespace ams::kern { constexpr bool IsAttachedToDebugger() const { return m_debug_attached; } void AddCpuTime(s32 core_id, s64 amount) { - m_cpu_time.FetchAdd(amount); + m_cpu_time += amount; /* TODO: Debug kernels track per-core tick counts. Should we? */ MESOSPHERE_UNUSED(core_id); } diff --git a/libraries/libmesosphere/source/arch/arm64/kern_cpu.cpp b/libraries/libmesosphere/source/arch/arm64/kern_cpu.cpp index bfcad7a6b..3f3e83125 100644 --- a/libraries/libmesosphere/source/arch/arm64/kern_cpu.cpp +++ b/libraries/libmesosphere/source/arch/arm64/kern_cpu.cpp @@ -287,7 +287,7 @@ namespace ams::kern::arch::arm64::cpu { break; } - m_target_cores.FetchAnd(~(1ul << GetCurrentCoreId())); + m_target_cores &= (~(1ul << GetCurrentCoreId())); } ALWAYS_INLINE void SetEventLocally() { diff --git a/libraries/libmesosphere/source/kern_k_client_port.cpp b/libraries/libmesosphere/source/kern_k_client_port.cpp index 9b523a2b4..b3e7410da 100644 --- a/libraries/libmesosphere/source/kern_k_client_port.cpp +++ b/libraries/libmesosphere/source/kern_k_client_port.cpp @@ -28,7 +28,7 @@ namespace ams::kern { void KClientPort::OnSessionFinalized() { KScopedSchedulerLock sl; - if (m_num_sessions.FetchSub(1) == m_max_sessions) { + if (const auto prev = m_num_sessions--; prev == m_max_sessions) { this->NotifyAvailable(); } } diff --git a/libraries/libmesosphere/source/kern_k_process.cpp b/libraries/libmesosphere/source/kern_k_process.cpp index f5253e986..9e2771b5d 100644 --- a/libraries/libmesosphere/source/kern_k_process.cpp +++ b/libraries/libmesosphere/source/kern_k_process.cpp @@ -25,8 +25,8 @@ namespace ams::kern { constexpr u64 ProcessIdMin = InitialProcessIdMax + 1; constexpr u64 ProcessIdMax = std::numeric_limits::max(); - constinit util::Atomic g_initial_process_id{InitialProcessIdMin}; - constinit util::Atomic g_process_id{ProcessIdMin}; + constinit util::Atomic g_initial_process_id = InitialProcessIdMin; + constinit util::Atomic g_process_id = ProcessIdMin; Result TerminateChildren(KProcess *process, const KThread *thread_to_not_terminate) { /* Request that all children threads terminate. */ @@ -299,7 +299,7 @@ namespace ams::kern { R_TRY(m_capabilities.Initialize(caps, num_caps, std::addressof(m_page_table))); /* Initialize the process id. */ - m_process_id = g_initial_process_id.FetchAdd(1); + m_process_id = g_initial_process_id++; MESOSPHERE_ABORT_UNLESS(InitialProcessIdMin <= m_process_id); MESOSPHERE_ABORT_UNLESS(m_process_id <= InitialProcessIdMax); @@ -409,7 +409,7 @@ namespace ams::kern { R_TRY(m_capabilities.Initialize(user_caps, num_caps, std::addressof(m_page_table))); /* Initialize the process id. */ - m_process_id = g_process_id.FetchAdd(1); + m_process_id = g_process_id++; MESOSPHERE_ABORT_UNLESS(ProcessIdMin <= m_process_id); MESOSPHERE_ABORT_UNLESS(m_process_id <= ProcessIdMax); @@ -791,13 +791,13 @@ namespace ams::kern { void KProcess::IncrementRunningThreadCount() { MESOSPHERE_ASSERT(m_num_running_threads.Load() >= 0); - m_num_running_threads.FetchAdd(1); + ++m_num_running_threads; } void KProcess::DecrementRunningThreadCount() { MESOSPHERE_ASSERT(m_num_running_threads.Load() > 0); - if (m_num_running_threads.FetchSub(1) == 1) { + if (const auto prev = m_num_running_threads--; prev == 1) { this->Terminate(); } } diff --git a/libraries/libmesosphere/source/kern_k_thread.cpp b/libraries/libmesosphere/source/kern_k_thread.cpp index e24ba6d53..bce7238f1 100644 --- a/libraries/libmesosphere/source/kern_k_thread.cpp +++ b/libraries/libmesosphere/source/kern_k_thread.cpp @@ -21,6 +21,8 @@ namespace ams::kern { constexpr inline s32 TerminatingThreadPriority = ams::svc::SystemThreadPriorityHighest - 1; + constinit util::Atomic g_thread_id = 0; + constexpr ALWAYS_INLINE bool IsKernelAddressKey(KProcessAddress key) { const uintptr_t key_uptr = GetInteger(key); return KernelVirtualAddressSpaceBase <= key_uptr && key_uptr <= KernelVirtualAddressSpaceLast && (key_uptr & 1) == 0; @@ -219,7 +221,7 @@ namespace ams::kern { this->SetInExceptionHandler(); /* Set thread ID. */ - m_thread_id = s_next_thread_id.FetchAdd(1); + m_thread_id = g_thread_id++; /* We initialized! */ m_initialized = true; diff --git a/libraries/libmesosphere/source/kern_panic.cpp b/libraries/libmesosphere/source/kern_panic.cpp index f15afdfd8..a42f6f961 100644 --- a/libraries/libmesosphere/source/kern_panic.cpp +++ b/libraries/libmesosphere/source/kern_panic.cpp @@ -42,15 +42,15 @@ namespace ams::kern { return arr; }(); - constinit util::Atomic g_next_ticket{0}; - constinit util::Atomic g_current_ticket{0}; + constinit util::Atomic g_next_ticket = 0; + constinit util::Atomic g_current_ticket = 0; constinit std::array g_core_tickets = NegativeArray; s32 GetCoreTicket() { const s32 core_id = GetCurrentCoreId(); if (g_core_tickets[core_id] == -1) { - g_core_tickets[core_id] = 2 * g_next_ticket.FetchAdd(1); + g_core_tickets[core_id] = 2 * (g_next_ticket++); } return g_core_tickets[core_id]; } diff --git a/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp b/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp index 3d8d412d8..2e025d313 100644 --- a/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp +++ b/libraries/libvapours/include/vapours/util/arch/arm64/util_atomic.hpp @@ -110,6 +110,11 @@ namespace ams::util { using StorageType = impl::AtomicStorage; static constexpr bool IsIntegral = std::integral; + static constexpr bool IsPointer = std::is_pointer::value; + + static constexpr bool HasArithmeticFunctions = IsIntegral || IsPointer; + + using DifferenceType = typename std::conditional::type>::type; static constexpr ALWAYS_INLINE T ConvertToType(StorageType s) { if constexpr (std::integral) { @@ -140,8 +145,8 @@ namespace ams::util { ALWAYS_INLINE volatile StorageType *GetStoragePointer() { return reinterpret_cast< volatile StorageType *>(std::addressof(m_v)); } ALWAYS_INLINE const volatile StorageType *GetStoragePointer() const { return reinterpret_cast(std::addressof(m_v)); } public: - ALWAYS_INLINE explicit Atomic() { /* ... */ } - constexpr ALWAYS_INLINE explicit Atomic(T v) : m_v(ConvertToStorage(v)) { /* ... */ } + ALWAYS_INLINE Atomic() { /* ... */ } + constexpr ALWAYS_INLINE Atomic(T v) : m_v(ConvertToStorage(v)) { /* ... */ } constexpr ALWAYS_INLINE T operator=(T desired) { if (std::is_constant_evaluated()) { @@ -296,27 +301,44 @@ namespace ams::util { return true; } - #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATOR_) \ - template::type> \ - ALWAYS_INLINE T Fetch ## _OPERATION_(T arg) { \ - static_assert(Enable); \ - volatile StorageType * const p = this->GetStoragePointer(); \ - const StorageType s = ConvertToStorage(arg); \ - \ - StorageType current; \ - do { \ - current = impl::LoadAcquireExclusiveForAtomic(p); \ - } while (AMS_UNLIKELY(!impl::StoreReleaseExclusiveForAtomic(p, current _OPERATOR_ s))); \ - return static_cast(current); \ + #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATOR_, _POINTER_ALLOWED_) \ + template::type> \ + ALWAYS_INLINE T Fetch ## _OPERATION_(DifferenceType arg) { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + volatile StorageType * const p = this->GetStoragePointer(); \ + \ + StorageType current; \ + do { \ + current = impl::LoadAcquireExclusiveForAtomic(p); \ + } while (AMS_UNLIKELY(!impl::StoreReleaseExclusiveForAtomic(p, ConvertToStorage(ConvertToType(current) _OPERATOR_ arg)))); \ + return ConvertToType(current); \ + } \ + \ + template::type> \ + ALWAYS_INLINE T operator _OPERATOR_##=(DifferenceType arg) { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return this->Fetch ## _OPERATION_(arg) _OPERATOR_ arg; \ } - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, +) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, -) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, &) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, |) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, ^) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, +, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, -, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, &, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, |, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, ^, false) #undef AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION + + template::type> + ALWAYS_INLINE T operator++() { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1) + 1; } + + template::type> + ALWAYS_INLINE T operator++(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1); } + + template::type> + ALWAYS_INLINE T operator--() { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1) - 1; } + + template::type> + ALWAYS_INLINE T operator--(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } }; diff --git a/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp b/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp index 5c57f0a45..cb458a136 100644 --- a/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp +++ b/libraries/libvapours/include/vapours/util/arch/generic/util_atomic.hpp @@ -55,13 +55,20 @@ namespace ams::util { class Atomic { NON_COPYABLE(Atomic); NON_MOVEABLE(Atomic); + private: + static constexpr bool IsIntegral = std::integral; + static constexpr bool IsPointer = std::is_pointer::value; + + static constexpr bool HasArithmeticFunctions = IsIntegral || IsPointer; + + using DifferenceType = typename std::conditional::type>::type; private: static_assert(std::atomic::is_always_lock_free); private: std::atomic m_v; public: - ALWAYS_INLINE explicit Atomic() { /* ... */ } - constexpr ALWAYS_INLINE explicit Atomic(T v) : m_v(v) { /* ... */ } + ALWAYS_INLINE Atomic() { /* ... */ } + constexpr ALWAYS_INLINE Atomic(T v) : m_v(v) { /* ... */ } ALWAYS_INLINE T operator=(T desired) { return (m_v = desired); @@ -93,18 +100,38 @@ namespace ams::util { } - #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATION_LOWER_) \ - ALWAYS_INLINE T Fetch ## _OPERATION_(T arg) { \ - return m_v.fetch_##_OPERATION_LOWER_(arg); \ + #define AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(_OPERATION_, _OPERATION_LOWER_, _OPERATOR_, _POINTER_ALLOWED_) \ + template::type> \ + ALWAYS_INLINE T Fetch ## _OPERATION_(DifferenceType arg) { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return m_v.fetch_##_OPERATION_LOWER_(arg); \ + } \ + \ + template::type> \ + ALWAYS_INLINE T operator _OPERATOR_##=(DifferenceType arg) { \ + static_assert(Enable == (IsIntegral || (_POINTER_ALLOWED_ && IsPointer))); \ + return this->Fetch##_OPERATION_(arg) _OPERATOR_ arg; \ } - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, add) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, sub) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, and) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, or) - AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, xor) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Add, add, +, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Sub, sub, -, true) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(And, and, &, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Or, or, |, false) + AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION(Xor, xor, ^, false) #undef AMS_UTIL_IMPL_DEFINE_ATOMIC_FETCH_OPERATE_FUNCTION + + template::type> + ALWAYS_INLINE T operator++() { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1) + 1; } + + template::type> + ALWAYS_INLINE T operator++(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchAdd(1); } + + template::type> + ALWAYS_INLINE T operator--() { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1) - 1; } + + template::type> + ALWAYS_INLINE T operator--(int) { static_assert(Enable == HasArithmeticFunctions); return this->FetchSub(1); } };