From 85f93551844d2d3968d79882cb3b75ddb375ac4c Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Wed, 7 Apr 2021 09:57:32 -0700 Subject: [PATCH] kern: use KScopedLightLockPair helper for page table pair-locks --- .../mesosphere/kern_k_page_table_base.hpp | 15 +- .../source/kern_k_page_table_base.cpp | 132 +++++++++--------- 2 files changed, 83 insertions(+), 64 deletions(-) diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_page_table_base.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_page_table_base.hpp index 538b98c44..1a7afe4ef 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_page_table_base.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_page_table_base.hpp @@ -301,9 +301,22 @@ namespace ams::kern { void CleanupForIpcClientOnServerSetupFailure(PageLinkedList *page_list, KProcessAddress address, size_t size, KMemoryPermission prot_perm); size_t GetSize(KMemoryState state) const; + + ALWAYS_INLINE bool GetPhysicalAddressLocked(KPhysicalAddress *out, KProcessAddress virt_addr) const { + /* Validate pre-conditions. */ + MESOSPHERE_AUDIT(this->IsLockedByCurrentThread()); + + return this->GetImpl().GetPhysicalAddress(out, virt_addr); + } public: bool GetPhysicalAddress(KPhysicalAddress *out, KProcessAddress virt_addr) const { - return this->GetImpl().GetPhysicalAddress(out, virt_addr); + /* Validate pre-conditions. */ + MESOSPHERE_AUDIT(!this->IsLockedByCurrentThread()); + + /* Acquire exclusive access to the table while doing address translation. */ + KScopedLightLock lk(m_general_lock); + + return this->GetPhysicalAddressLocked(out, virt_addr); } KBlockInfoManager *GetBlockInfoManager() const { return m_block_info_manager; } diff --git a/libraries/libmesosphere/source/kern_k_page_table_base.cpp b/libraries/libmesosphere/source/kern_k_page_table_base.cpp index 12e4d196e..ed46bbd23 100644 --- a/libraries/libmesosphere/source/kern_k_page_table_base.cpp +++ b/libraries/libmesosphere/source/kern_k_page_table_base.cpp @@ -18,6 +18,62 @@ namespace ams::kern { + namespace { + + class KScopedLightLockPair { + NON_COPYABLE(KScopedLightLockPair); + NON_MOVEABLE(KScopedLightLockPair); + private: + KLightLock *m_lower; + KLightLock *m_upper; + public: + ALWAYS_INLINE KScopedLightLockPair(KLightLock &lhs, KLightLock &rhs) { + /* Ensure our locks are in a consistent order. */ + if (std::addressof(lhs) <= std::addressof(rhs)) { + m_lower = std::addressof(lhs); + m_upper = std::addressof(rhs); + } else { + m_lower = std::addressof(rhs); + m_upper = std::addressof(lhs); + } + + /* Acquire both locks. */ + m_lower->Lock(); + if (m_lower != m_upper) { + m_upper->Lock(); + } + } + + ~KScopedLightLockPair() { + /* Unlock the upper lock. */ + if (m_upper != nullptr && m_upper != m_lower) { + m_upper->Unlock(); + } + + /* Unlock the lower lock. */ + if (m_lower != nullptr) { + m_lower->Unlock(); + } + } + public: + /* Utility. */ + ALWAYS_INLINE void TryUnlockHalf(KLightLock &lock) { + /* Only allow unlocking if the lock is half the pair. */ + if (m_lower != m_upper) { + /* We want to be sure the lock is one we own. */ + if (m_lower == std::addressof(lock)) { + lock.Unlock(); + m_lower = nullptr; + } else if (m_upper == std::addressof(lock)) { + lock.Unlock(); + m_upper = nullptr; + } + } + } + }; + + } + Result KPageTableBase::InitializeForKernel(bool is_64_bit, void *table, KVirtualAddress start, KVirtualAddress end) { /* Initialize our members. */ m_address_space_width = (is_64_bit) ? BITSIZEOF(u64) : BITSIZEOF(u32); @@ -529,7 +585,7 @@ namespace ams::kern { /* Get the physical address, if we're supposed to. */ if (out_paddr != nullptr) { - MESOSPHERE_ABORT_UNLESS(this->GetPhysicalAddress(out_paddr, addr)); + MESOSPHERE_ABORT_UNLESS(this->GetPhysicalAddressLocked(out_paddr, addr)); } /* Make the page group, if we're supposed to. */ @@ -2458,18 +2514,8 @@ namespace ams::kern { KPageTableBase &src_page_table = *this; KPageTableBase &dst_page_table = GetCurrentProcess().GetPageTable().GetBasePageTable(); - /* Get the table locks. */ - KLightLock &lock_0 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? src_page_table.m_general_lock : dst_page_table.m_general_lock; - KLightLock &lock_1 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? dst_page_table.m_general_lock : src_page_table.m_general_lock; - - /* Lock the first lock. */ - KScopedLightLock lk0(lock_0); - - /* If necessary, lock the second lock. */ - std::optional lk1; - if (std::addressof(lock_0) != std::addressof(lock_1)) { - lk1.emplace(lock_1); - } + /* Acquire the table locks. */ + KScopedLightLockPair lk(src_page_table.m_general_lock, dst_page_table.m_general_lock); /* Check that the desired range is readable io memory. */ R_TRY(this->CheckMemoryStateContiguous(address, size, KMemoryState_All, KMemoryState_Io, KMemoryPermission_UserRead, KMemoryPermission_UserRead, KMemoryAttribute_None, KMemoryAttribute_None)); @@ -2480,7 +2526,7 @@ namespace ams::kern { while (address <= last_address) { /* Get the current physical address. */ KPhysicalAddress phys_addr; - MESOSPHERE_ABORT_UNLESS(src_page_table.GetPhysicalAddress(std::addressof(phys_addr), address)); + MESOSPHERE_ABORT_UNLESS(src_page_table.GetPhysicalAddressLocked(std::addressof(phys_addr), address)); /* Determine the current read size. */ const size_t cur_size = std::min(last_address - address + 1, util::AlignDown(GetInteger(address) + PageSize, PageSize) - GetInteger(address)); @@ -2504,18 +2550,8 @@ namespace ams::kern { KPageTableBase &src_page_table = *this; KPageTableBase &dst_page_table = GetCurrentProcess().GetPageTable().GetBasePageTable(); - /* Get the table locks. */ - KLightLock &lock_0 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? src_page_table.m_general_lock : dst_page_table.m_general_lock; - KLightLock &lock_1 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? dst_page_table.m_general_lock : src_page_table.m_general_lock; - - /* Lock the first lock. */ - KScopedLightLock lk0(lock_0); - - /* If necessary, lock the second lock. */ - std::optional lk1; - if (std::addressof(lock_0) != std::addressof(lock_1)) { - lk1.emplace(lock_1); - } + /* Acquire the table locks. */ + KScopedLightLockPair lk(src_page_table.m_general_lock, dst_page_table.m_general_lock); /* Check that the desired range is writable io memory. */ R_TRY(this->CheckMemoryStateContiguous(address, size, KMemoryState_All, KMemoryState_Io, KMemoryPermission_UserReadWrite, KMemoryPermission_UserReadWrite, KMemoryAttribute_None, KMemoryAttribute_None)); @@ -2526,7 +2562,7 @@ namespace ams::kern { while (address <= last_address) { /* Get the current physical address. */ KPhysicalAddress phys_addr; - MESOSPHERE_ABORT_UNLESS(src_page_table.GetPhysicalAddress(std::addressof(phys_addr), address)); + MESOSPHERE_ABORT_UNLESS(src_page_table.GetPhysicalAddressLocked(std::addressof(phys_addr), address)); /* Determine the current read size. */ const size_t cur_size = std::min(last_address - address + 1, util::AlignDown(GetInteger(address) + PageSize, PageSize) - GetInteger(address)); @@ -3076,18 +3112,8 @@ namespace ams::kern { /* Copy the memory. */ { - /* Get the table locks. */ - KLightLock &lock_0 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? src_page_table.m_general_lock : dst_page_table.m_general_lock; - KLightLock &lock_1 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? dst_page_table.m_general_lock : src_page_table.m_general_lock; - - /* Lock the first lock. */ - KScopedLightLock lk0(lock_0); - - /* If necessary, lock the second lock. */ - std::optional lk1; - if (std::addressof(lock_0) != std::addressof(lock_1)) { - lk1.emplace(lock_1); - } + /* Acquire the table locks. */ + KScopedLightLockPair lk(src_page_table.m_general_lock, dst_page_table.m_general_lock); /* Check memory state. */ R_TRY(src_page_table.CheckMemoryStateContiguous(src_addr, size, src_state_mask, src_state, src_test_perm, src_test_perm, src_attr_mask | KMemoryAttribute_Uncached, src_attr)); @@ -3203,18 +3229,8 @@ namespace ams::kern { /* Copy the memory. */ { - /* Get the table locks. */ - KLightLock &lock_0 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? src_page_table.m_general_lock : dst_page_table.m_general_lock; - KLightLock &lock_1 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? dst_page_table.m_general_lock : src_page_table.m_general_lock; - - /* Lock the first lock. */ - KScopedLightLock lk0(lock_0); - - /* If necessary, lock the second lock. */ - std::optional lk1; - if (std::addressof(lock_0) != std::addressof(lock_1)) { - lk1.emplace(lock_1); - } + /* Acquire the table locks. */ + KScopedLightLockPair lk(src_page_table.m_general_lock, dst_page_table.m_general_lock); /* Check memory state for source. */ R_TRY(src_page_table.CheckMemoryStateContiguous(src_addr, size, src_state_mask, src_state, src_test_perm, src_test_perm, src_attr_mask | KMemoryAttribute_Uncached, src_attr)); @@ -3644,18 +3660,8 @@ namespace ams::kern { /* For convenience, alias this. */ KPageTableBase &dst_page_table = *this; - /* Get the table locks. */ - KLightLock &lock_0 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? src_page_table.m_general_lock : dst_page_table.m_general_lock; - KLightLock &lock_1 = (reinterpret_cast(std::addressof(src_page_table)) <= reinterpret_cast(std::addressof(dst_page_table))) ? dst_page_table.m_general_lock : src_page_table.m_general_lock; - - /* Lock the first lock. */ - KScopedLightLock lk0(lock_0); - - /* If necessary, lock the second lock. */ - std::optional lk1; - if (std::addressof(lock_0) != std::addressof(lock_1)) { - lk1.emplace(lock_1); - } + /* Acquire the table locks. */ + KScopedLightLockPair lk(src_page_table.m_general_lock, dst_page_table.m_general_lock); /* We're going to perform an update, so create a helper. */ KScopedPageTableUpdater updater(std::addressof(src_page_table));