From c06a4c696d90963256d844dd2e46a0b0391a8045 Mon Sep 17 00:00:00 2001
From: Michael Scire
Date: Thu, 10 Oct 2024 19:14:07 -0700
Subject: [PATCH] kern: Perform page table validity pass during
 KPageTableImpl::InitializeForKernel

---
 .../arm64/init/kern_k_init_page_table.hpp     | 104 +++++++++---------
 .../arch/arm64/kern_k_page_table_entry.hpp    |  22 ++--
 .../source/arch/arm64/kern_k_page_table.cpp   |  16 +--
 .../arch/arm64/kern_k_page_table_impl.cpp     |  54 +++++++++
 4 files changed, 127 insertions(+), 69 deletions(-)

diff --git a/libraries/libmesosphere/include/mesosphere/arch/arm64/init/kern_k_init_page_table.hpp b/libraries/libmesosphere/include/mesosphere/arch/arm64/init/kern_k_init_page_table.hpp
index 8b9a6953c..805b967db 100644
--- a/libraries/libmesosphere/include/mesosphere/arch/arm64/init/kern_k_init_page_table.hpp
+++ b/libraries/libmesosphere/include/mesosphere/arch/arm64/init/kern_k_init_page_table.hpp
@@ -110,47 +110,47 @@ namespace ams::kern::arch::arm64::init {
                     L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
                     /* If an L1 block is mapped or we're empty, advance by L1BlockSize. */
-                    if (l1_entry->IsBlock() || l1_entry->IsEmpty()) {
+                    if (l1_entry->IsMappedBlock() || l1_entry->IsEmpty()) {
                         MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), L1BlockSize));
                         MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= L1BlockSize);
                         virt_addr += L1BlockSize;
-                        if (l1_entry->IsBlock() && block_size == L1BlockSize) {
+                        if (l1_entry->IsMappedBlock() && block_size == L1BlockSize) {
                             count++;
                         }
                         continue;
                     }
 
                     /* Non empty and non-block must be table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                     /* Table, so check if we're mapped in L2. */
                     L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
 
-                    if (l2_entry->IsBlock() || l2_entry->IsEmpty()) {
-                        const size_t advance_size = (l2_entry->IsBlock() && l2_entry->IsContiguous()) ? L2ContiguousBlockSize : L2BlockSize;
+                    if (l2_entry->IsMappedBlock() || l2_entry->IsEmpty()) {
+                        const size_t advance_size = (l2_entry->IsMappedBlock() && l2_entry->IsContiguous()) ? L2ContiguousBlockSize : L2BlockSize;
                         MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), advance_size));
                         MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= advance_size);
                         virt_addr += advance_size;
-                        if (l2_entry->IsBlock() && block_size == advance_size) {
+                        if (l2_entry->IsMappedBlock() && block_size == advance_size) {
                             count++;
                         }
                         continue;
                     }
 
                     /* Non empty and non-block must be table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                     /* Table, so check if we're mapped in L3. */
                     L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
 
                     /* L3 must be block or empty. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock() || l3_entry->IsEmpty());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock() || l3_entry->IsEmpty());
 
-                    const size_t advance_size = (l3_entry->IsBlock() && l3_entry->IsContiguous()) ? L3ContiguousBlockSize : L3BlockSize;
+                    const size_t advance_size = (l3_entry->IsMappedBlock() && l3_entry->IsContiguous()) ? L3ContiguousBlockSize : L3BlockSize;
                     MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), advance_size));
                     MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= advance_size);
                     virt_addr += advance_size;
-                    if (l3_entry->IsBlock() && block_size == advance_size) {
+                    if (l3_entry->IsMappedBlock() && block_size == advance_size) {
                         count++;
                     }
                 }
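The IsMappedBlock()/IsMappedTable() predicates this walker now uses differ from the IsBlock()/IsTable() they replace: as the header changes later in the patch show, the new tests read only the low descriptor type bits that the hardware walker itself decodes (GetBits(0, 2)), while at least the L3 IsBlock() also consults the software ExtensionFlag_TestTableMask. A minimal standalone sketch of the bit tests, with illustrative descriptor values:

    #include <cstdint>

    /* Assumed encoding, matching the GetBits(0, 2) checks in this patch: at
       L1/L2 a hardware-valid block descriptor has bits [1:0] == 0b01 and a
       table has 0b11; at L3 a valid page descriptor also has 0b11. */
    constexpr bool IsMappedBlockL1L2(uint64_t pte) { return (pte & 0b11u) == 0b01u; }
    constexpr bool IsMappedTable(uint64_t pte)     { return (pte & 0b11u) == 0b11u; }

    static_assert( IsMappedBlockL1L2(0x40000601));  /* block descriptor */
    static_assert(!IsMappedBlockL1L2(0x40001003));  /* table descriptor */
    static_assert( IsMappedTable(0x40001003));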
@@ -164,10 +164,10 @@ namespace ams::kern::arch::arm64::init {
                     L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
                     /* If an L1 block is mapped or we're empty, advance by L1BlockSize. */
-                    if (l1_entry->IsBlock() || l1_entry->IsEmpty()) {
+                    if (l1_entry->IsMappedBlock() || l1_entry->IsEmpty()) {
                         MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), L1BlockSize));
                         MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= L1BlockSize);
-                        if (l1_entry->IsBlock() && block_size == L1BlockSize) {
+                        if (l1_entry->IsMappedBlock() && block_size == L1BlockSize) {
                             if ((count++) == index) {
                                 return virt_addr;
                             }
@@ -177,16 +177,16 @@ namespace ams::kern::arch::arm64::init {
                     }
 
                     /* Non empty and non-block must be table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                     /* Table, so check if we're mapped in L2. */
                     L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
 
-                    if (l2_entry->IsBlock() || l2_entry->IsEmpty()) {
-                        const size_t advance_size = (l2_entry->IsBlock() && l2_entry->IsContiguous()) ? L2ContiguousBlockSize : L2BlockSize;
+                    if (l2_entry->IsMappedBlock() || l2_entry->IsEmpty()) {
+                        const size_t advance_size = (l2_entry->IsMappedBlock() && l2_entry->IsContiguous()) ? L2ContiguousBlockSize : L2BlockSize;
                         MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), advance_size));
                         MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= advance_size);
-                        if (l2_entry->IsBlock() && block_size == advance_size) {
+                        if (l2_entry->IsMappedBlock() && block_size == advance_size) {
                             if ((count++) == index) {
                                 return virt_addr;
                             }
@@ -196,18 +196,18 @@ namespace ams::kern::arch::arm64::init {
                     }
 
                     /* Non empty and non-block must be table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                     /* Table, so check if we're mapped in L3. */
                     L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
 
                     /* L3 must be block or empty. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock() || l3_entry->IsEmpty());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock() || l3_entry->IsEmpty());
 
-                    const size_t advance_size = (l3_entry->IsBlock() && l3_entry->IsContiguous()) ? L3ContiguousBlockSize : L3BlockSize;
+                    const size_t advance_size = (l3_entry->IsMappedBlock() && l3_entry->IsContiguous()) ? L3ContiguousBlockSize : L3BlockSize;
                     MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), advance_size));
                     MESOSPHERE_INIT_ABORT_UNLESS(static_cast<size_t>(end_virt_addr - virt_addr) >= advance_size);
-                    if (l3_entry->IsBlock() && block_size == advance_size) {
+                    if (l3_entry->IsMappedBlock() && block_size == advance_size) {
                         if ((count++) == index) {
                             return virt_addr;
                         }
@@ -220,29 +220,29 @@ namespace ams::kern::arch::arm64::init {
             PageTableEntry *GetMappingEntry(KVirtualAddress virt_addr, size_t block_size) {
                 L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
-                if (l1_entry->IsBlock()) {
+                if (l1_entry->IsMappedBlock()) {
                     MESOSPHERE_INIT_ABORT_UNLESS(block_size == L1BlockSize);
                     return l1_entry;
                 }
 
-                MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                 /* Table, so check if we're mapped in L2. */
                 L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
 
-                if (l2_entry->IsBlock()) {
+                if (l2_entry->IsMappedBlock()) {
                     const size_t real_size = (l2_entry->IsContiguous()) ? L2ContiguousBlockSize : L2BlockSize;
                     MESOSPHERE_INIT_ABORT_UNLESS(real_size == block_size);
                     return l2_entry;
                 }
 
-                MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                 /* Table, so check if we're mapped in L3. */
                 L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
 
                 /* L3 must be block. */
-                MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock());
+                MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock());
 
                 const size_t real_size = (l3_entry->IsContiguous()) ? L3ContiguousBlockSize : L3BlockSize;
                 MESOSPHERE_INIT_ABORT_UNLESS(real_size == block_size);
@@ -340,7 +340,7 @@ namespace ams::kern::arch::arm64::init {
                 }
 
                 /* If we don't already have an L2 table, we need to make a new one. */
-                if (!l1_entry->IsTable()) {
+                if (!l1_entry->IsMappedTable()) {
                     KPhysicalAddress new_table = AllocateNewPageTable(allocator, phys_to_virt_offset);
                     cpu::DataSynchronizationBarrierInnerShareable();
                     *l1_entry = L1PageTableEntry(PageTableEntry::TableTag{}, new_table, attr.IsPrivilegedExecuteNever());
@@ -371,7 +371,7 @@ namespace ams::kern::arch::arm64::init {
                 }
 
                 /* If we don't already have an L3 table, we need to make a new one. */
-                if (!l2_entry->IsTable()) {
+                if (!l2_entry->IsMappedTable()) {
                     KPhysicalAddress new_table = AllocateNewPageTable(allocator, phys_to_virt_offset);
                     cpu::DataSynchronizationBarrierInnerShareable();
                     *l2_entry = L2PageTableEntry(PageTableEntry::TableTag{}, new_table, attr.IsPrivilegedExecuteNever());
@@ -416,12 +416,12 @@ namespace ams::kern::arch::arm64::init {
                 for (size_t l1_index = 0; l1_index < m_num_entries[0]; l1_index++) {
                     /* Get L1 entry. */
                     L1PageTableEntry * const l1_entry = l1_table + l1_index;
-                    if (l1_entry->IsBlock()) {
+                    if (l1_entry->IsMappedBlock()) {
                         /* Unmap the L1 entry, if we should. */
                         if (ShouldUnmap(l1_entry)) {
                             *static_cast<PageTableEntry *>(l1_entry) = InvalidPageTableEntry;
                         }
-                    } else if (l1_entry->IsTable()) {
+                    } else if (l1_entry->IsMappedTable()) {
                         /* Get the L2 table. */
                         L2PageTableEntry * const l2_table = reinterpret_cast<L2PageTableEntry *>(GetInteger(l1_entry->GetTable()) + phys_to_virt_offset);
 
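Two of the hunks above insert a cpu::DataSynchronizationBarrierInnerShareable() between allocating a new table and storing the entry that links it. The point is publication ordering: the new table's initialized contents must be observable before the table becomes reachable by a walker. A very loose user-space analogue using C++ atomics (illustrative only, not the kernel's primitives):

    #include <atomic>
    #include <cstdint>

    /* Toy analogue of the allocate-barrier-link sequence: initialize a page's
       worth of entries, then publish a pointer to it. The release fence plays
       the role of the DSB ISH above. */
    alignas(4096) static uint64_t g_new_table[512];
    static std::atomic<uint64_t *> g_parent_entry{nullptr};

    void LinkNewTable() {
        for (auto &e : g_new_table) {
            e = 0;  /* InvalidPageTableEntry */
        }
        std::atomic_thread_fence(std::memory_order_release);
        g_parent_entry.store(g_new_table, std::memory_order_relaxed);
    }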
@@ -430,7 +430,7 @@ namespace ams::kern::arch::arm64::init {
                         for (size_t l2_index = 0; l2_index < MaxPageTableEntries; ++l2_index) {
                             /* Get L2 entry. */
                             L2PageTableEntry * const l2_entry = l2_table + l2_index;
-                            if (l2_entry->IsBlock()) {
+                            if (l2_entry->IsMappedBlock()) {
                                 const size_t num_to_clear = (l2_entry->IsContiguous() ? L2ContiguousBlockSize : L2BlockSize) / L2BlockSize;
 
                                 if (ShouldUnmap(l2_entry)) {
@@ -442,7 +442,7 @@ namespace ams::kern::arch::arm64::init {
                                 }
 
                                 l2_index = l2_index + num_to_clear - 1;
-                            } else if (l2_entry->IsTable()) {
+                            } else if (l2_entry->IsMappedTable()) {
                                 /* Get the L3 table. */
                                 L3PageTableEntry * const l3_table = reinterpret_cast<L3PageTableEntry *>(GetInteger(l2_entry->GetTable()) + phys_to_virt_offset);
 
@@ -450,7 +450,7 @@ namespace ams::kern::arch::arm64::init {
                                 size_t remaining_l3_entries = 0;
                                 for (size_t l3_index = 0; l3_index < MaxPageTableEntries; ++l3_index) {
                                     /* Get L3 entry. */
-                                    if (L3PageTableEntry * const l3_entry = l3_table + l3_index; l3_entry->IsBlock()) {
+                                    if (L3PageTableEntry * const l3_entry = l3_table + l3_index; l3_entry->IsMappedBlock()) {
                                         const size_t num_to_clear = (l3_entry->IsContiguous() ? L3ContiguousBlockSize : L3BlockSize) / L3BlockSize;
 
                                         if (ShouldUnmap(l3_entry)) {
@@ -498,25 +498,25 @@ namespace ams::kern::arch::arm64::init {
                 /* Get the L1 entry. */
                 const L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
-                if (l1_entry->IsBlock()) {
+                if (l1_entry->IsMappedBlock()) {
                     return l1_entry->GetBlock() + (GetInteger(virt_addr) & (L1BlockSize - 1));
                 }
 
-                MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                 /* Get the L2 entry. */
                 const L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
 
-                if (l2_entry->IsBlock()) {
+                if (l2_entry->IsMappedBlock()) {
                     return l2_entry->GetBlock() + (GetInteger(virt_addr) & (L2BlockSize - 1));
                 }
 
-                MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                 /* Get the L3 entry. */
                 const L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
 
-                MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock());
+                MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock());
 
                 return l3_entry->GetBlock() + (GetInteger(virt_addr) & (L3BlockSize - 1));
             }
@@ -561,26 +561,26 @@ namespace ams::kern::arch::arm64::init {
                     L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
                     /* If an L1 block is mapped, update. */
-                    if (l1_entry->IsBlock()) {
+                    if (l1_entry->IsMappedBlock()) {
                         UpdateExtents(l1_entry->GetBlock(), L1BlockSize);
                         continue;
                     }
 
                     /* Not a block, so we must have a table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                     L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
-                    if (l2_entry->IsBlock()) {
+                    if (l2_entry->IsMappedBlock()) {
                         UpdateExtents(l2_entry->GetBlock(), l2_entry->IsContiguous() ? L2ContiguousBlockSize : L2BlockSize);
                         continue;
                     }
 
                     /* Not a block, so we must have a table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                     /* We must have a mapped l3 entry to inspect. */
                     L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
-                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock());
 
                     UpdateExtents(l3_entry->GetBlock(), l3_entry->IsContiguous() ? L3ContiguousBlockSize : L3BlockSize);
                 }
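The num_to_clear computations above handle the ARM contiguous hint, where 16 adjacent entries form one larger translation that must be mapped and unmapped as a unit, so a contiguous L2 entry advances the clear loop by 16 slots instead of 1. A self-contained check, assuming the 4KB-granule sizes mesosphere uses (2MB L2 blocks, 32MB contiguous L2 blocks):

    #include <cstddef>

    constexpr size_t L2BlockSize           = size_t(2) << 20;  /* 2MB */
    constexpr size_t L2ContiguousBlockSize = 16 * L2BlockSize; /* 32MB */

    constexpr size_t NumToClear(bool contiguous) {
        return (contiguous ? L2ContiguousBlockSize : L2BlockSize) / L2BlockSize;
    }

    static_assert(NumToClear(false) == 1);
    static_assert(NumToClear(true)  == 16);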
@@ -602,11 +602,11 @@ namespace ams::kern::arch::arm64::init {
                     L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
                     /* If an L1 block is mapped, the address isn't free. */
-                    if (l1_entry->IsBlock()) {
+                    if (l1_entry->IsMappedBlock()) {
                         return false;
                     }
 
-                    if (!l1_entry->IsTable()) {
+                    if (!l1_entry->IsMappedTable()) {
                         /* Not a table, so just move to check the next region. */
                         virt_addr = util::AlignDown(GetInteger(virt_addr) + L1BlockSize, L1BlockSize);
                         continue;
@@ -615,11 +615,11 @@ namespace ams::kern::arch::arm64::init {
                     /* Table, so check if we're mapped in L2. */
                     L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
 
-                    if (l2_entry->IsBlock()) {
+                    if (l2_entry->IsMappedBlock()) {
                         return false;
                     }
 
-                    if (!l2_entry->IsTable()) {
+                    if (!l2_entry->IsMappedTable()) {
                         /* Not a table, so just move to check the next region. */
                         virt_addr = util::AlignDown(GetInteger(virt_addr) + L2BlockSize, L2BlockSize);
                         continue;
@@ -628,7 +628,7 @@ namespace ams::kern::arch::arm64::init {
                     /* Table, so check if we're mapped in L3. */
                     L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
 
-                    if (l3_entry->IsBlock()) {
+                    if (l3_entry->IsMappedBlock()) {
                         return false;
                     }
 
@@ -648,7 +648,7 @@ namespace ams::kern::arch::arm64::init {
                     L1PageTableEntry *l1_entry = this->GetL1Entry(virt_addr);
 
                     /* Check if an L1 block is present. */
-                    if (l1_entry->IsBlock()) {
+                    if (l1_entry->IsMappedBlock()) {
                         /* Ensure that we are allowed to have an L1 block here. */
                         const KPhysicalAddress block = l1_entry->GetBlock();
                         MESOSPHERE_INIT_ABORT_UNLESS(util::IsAligned(GetInteger(virt_addr), L1BlockSize));
@@ -669,10 +669,10 @@ namespace ams::kern::arch::arm64::init {
                     }
 
                     /* Not a block, so we must be a table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l1_entry->IsMappedTable());
 
                     L2PageTableEntry *l2_entry = GetL2Entry(l1_entry, virt_addr);
-                    if (l2_entry->IsBlock()) {
+                    if (l2_entry->IsMappedBlock()) {
                         const KPhysicalAddress block = l2_entry->GetBlock();
 
                         if (l2_entry->IsContiguous()) {
@@ -720,11 +720,11 @@ namespace ams::kern::arch::arm64::init {
                     }
 
                     /* Not a block, so we must be a table. */
-                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsTable());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l2_entry->IsMappedTable());
 
                     /* We must have a mapped l3 entry to reprotect. */
                     L3PageTableEntry *l3_entry = GetL3Entry(l2_entry, virt_addr);
-                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsBlock());
+                    MESOSPHERE_INIT_ABORT_UNLESS(l3_entry->IsMappedBlock());
 
                     const KPhysicalAddress block = l3_entry->GetBlock();
                     if (l3_entry->IsContiguous()) {
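The IsFree() walk above can dismiss a whole region as soon as an entry is neither a block nor a table: nothing below that entry can be mapped, so it steps one block forward and aligns down to the next boundary. For example, with 1GB L1 blocks (the 4KB-granule size):

    #include <cstdint>

    constexpr uint64_t L1BlockSize = uint64_t(1) << 30;  /* 1GB */

    /* Mirrors util::AlignDown(GetInteger(virt_addr) + L1BlockSize, L1BlockSize). */
    constexpr uint64_t NextL1Region(uint64_t virt_addr) {
        return (virt_addr + L1BlockSize) & ~(L1BlockSize - 1);
    }

    /* From anywhere inside the 0x40000000 region, skip to 0x80000000. */
    static_assert(NextL1Region(0x40200000) == 0x80000000);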
diff --git a/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_k_page_table_entry.hpp b/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_k_page_table_entry.hpp
index d3ac79089..6fce69f81 100644
--- a/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_k_page_table_entry.hpp
+++ b/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_k_page_table_entry.hpp
@@ -128,8 +128,8 @@ namespace ams::kern::arch::arm64 {
             }
 
             /* Construct a table. */
-            constexpr explicit ALWAYS_INLINE PageTableEntry(TableTag, KPhysicalAddress phys_addr, bool is_kernel, bool pxn, size_t num_blocks)
-                : PageTableEntry(((is_kernel ? 0x3ul : 0) << 60) | (static_cast<u64>(pxn) << 59) | GetInteger(phys_addr) | (num_blocks << 2) | 0x3)
+            constexpr explicit ALWAYS_INLINE PageTableEntry(TableTag, KPhysicalAddress phys_addr, bool is_kernel, bool pxn, size_t ref_count)
+                : PageTableEntry(((is_kernel ? 0x3ul : 0) << 60) | (static_cast<u64>(pxn) << 59) | GetInteger(phys_addr) | (ref_count << 2) | 0x3)
             {
                 /* ... */
             }
@@ -203,6 +203,7 @@ namespace ams::kern::arch::arm64 {
 
             constexpr ALWAYS_INLINE KPhysicalAddress GetTable() const { return this->SelectBits(12, 36); }
 
+            constexpr ALWAYS_INLINE bool IsMappedBlock() const { return this->GetBits(0, 2) == 1; }
             constexpr ALWAYS_INLINE bool IsMappedTable() const { return this->GetBits(0, 2) == 3; }
 
             constexpr ALWAYS_INLINE bool IsMapped() const { return this->GetBits(0, 1) != 0; }
@@ -217,11 +218,13 @@ namespace ams::kern::arch::arm64 {
             constexpr ALWAYS_INLINE decltype(auto) SetPageAttribute(PageAttribute a) { this->SetBitsDirect(2, 3, a); return *this; }
             constexpr ALWAYS_INLINE decltype(auto) SetMapped(bool m) { static_assert(static_cast<bool>(MappingFlag_Mapped == (1 << 0))); this->SetBit(0, m); return *this; }
 
-            constexpr ALWAYS_INLINE size_t GetTableNumEntries() const { return this->GetBits(2, 10); }
-            constexpr ALWAYS_INLINE decltype(auto) SetTableNumEntries(size_t num) { this->SetBits(2, 10, num); }
+            constexpr ALWAYS_INLINE size_t GetTableReferenceCount() const { return this->GetBits(2, 10); }
+            constexpr ALWAYS_INLINE decltype(auto) SetTableReferenceCount(size_t num) { this->SetBits(2, 10, num); return *this; }
 
-            constexpr ALWAYS_INLINE decltype(auto) AddTableEntries(size_t num) { return this->SetTableNumEntries(this->GetTableNumEntries() + num); }
-            constexpr ALWAYS_INLINE decltype(auto) RemoveTableEntries(size_t num) { return this->SetTableNumEntries(this->GetTableNumEntries() - num); }
+            constexpr ALWAYS_INLINE decltype(auto) OpenTableReferences(size_t num) { MESOSPHERE_ASSERT(this->GetTableReferenceCount() + num <= BlocksPerTable + 1); return this->SetTableReferenceCount(this->GetTableReferenceCount() + num); }
+            constexpr ALWAYS_INLINE decltype(auto) CloseTableReferences(size_t num) { MESOSPHERE_ASSERT(this->GetTableReferenceCount() >= num); return this->SetTableReferenceCount(this->GetTableReferenceCount() - num); }
+
+            constexpr ALWAYS_INLINE decltype(auto) SetValid() { MESOSPHERE_ASSERT((m_attributes & ExtensionFlag_Valid) == 0); m_attributes |= ExtensionFlag_Valid; return *this; }
 
             constexpr ALWAYS_INLINE u64 GetEntryTemplateForMerge() const {
                 constexpr u64 BaseMask = (0xFFFF000000000FFFul & ~static_cast<u64>((0x1ul << 52) | ExtensionFlag_TestTableMask | ExtensionFlag_DisableMergeHead | ExtensionFlag_DisableMergeHeadAndBody | ExtensionFlag_DisableMergeTail));
@@ -301,7 +304,7 @@ namespace ams::kern::arch::arm64 {
             }
 
             constexpr explicit ALWAYS_INLINE L1PageTableEntry(BlockTag, KPhysicalAddress phys_addr, const PageTableEntry &attr, u8 sw_reserved_bits, bool contig)
-                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | PageTableEntry::ExtensionFlag_Valid)
+                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | 0x1)
             {
                 /* ... */
             }
@@ -363,7 +366,7 @@ namespace ams::kern::arch::arm64 {
             }
 
             constexpr explicit ALWAYS_INLINE L2PageTableEntry(BlockTag, KPhysicalAddress phys_addr, const PageTableEntry &attr, u8 sw_reserved_bits, bool contig)
-                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | PageTableEntry::ExtensionFlag_Valid)
+                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | 0x1)
             {
                 /* ... */
             }
@@ -428,12 +431,13 @@ namespace ams::kern::arch::arm64 {
             constexpr explicit ALWAYS_INLINE L3PageTableEntry(InvalidTag) : PageTableEntry(InvalidTag{}) { /* ... */ }
 
             constexpr explicit ALWAYS_INLINE L3PageTableEntry(BlockTag, KPhysicalAddress phys_addr, const PageTableEntry &attr, u8 sw_reserved_bits, bool contig)
-                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | static_cast<u64>(ExtensionFlag_TestTableMask))
+                : PageTableEntry(attr, (static_cast<u64>(sw_reserved_bits) << 55) | (static_cast<u64>(contig) << 52) | GetInteger(phys_addr) | 0x3)
             {
                 /* ... */
             }
 
             constexpr ALWAYS_INLINE bool IsBlock() const { return (GetRawAttributes() & ExtensionFlag_TestTableMask) == ExtensionFlag_TestTableMask; }
+            constexpr ALWAYS_INLINE bool IsMappedBlock() const { return this->GetBits(0, 2) == 3; }
 
             constexpr ALWAYS_INLINE KPhysicalAddress GetBlock() const {
                 return this->SelectBits(12, 36);
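The reference count this header now exposes lives in bits [11:2] of a table descriptor (the GetBits(2, 10)/SetBits(2, 10, num) accessors above), bits the hardware ignores when walking tables. A toy model using a plain integer, with the same field layout but everything else simplified:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    struct TablePte {
        uint64_t raw;

        /* Reference count in the 10 ignored bits [11:2], as in the diff. */
        constexpr size_t GetTableReferenceCount() const { return (raw >> 2) & 0x3FF; }
        constexpr TablePte &SetTableReferenceCount(size_t n) {
            raw = (raw & ~(uint64_t(0x3FF) << 2)) | (uint64_t(n) << 2);
            return *this;
        }
        constexpr TablePte &OpenTableReferences(size_t n)  { return SetTableReferenceCount(GetTableReferenceCount() + n); }
        constexpr TablePte &CloseTableReferences(size_t n) { return SetTableReferenceCount(GetTableReferenceCount() - n); }
    };

    int main() {
        TablePte pte{0x3};  /* table descriptor, reference count 0 */
        pte.OpenTableReferences(2).CloseTableReferences(1);
        assert(pte.GetTableReferenceCount() == 1);
    }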
diff --git a/libraries/libmesosphere/source/arch/arm64/kern_k_page_table.cpp b/libraries/libmesosphere/source/arch/arm64/kern_k_page_table.cpp
index 39efc0b5f..d13f81979 100644
--- a/libraries/libmesosphere/source/arch/arm64/kern_k_page_table.cpp
+++ b/libraries/libmesosphere/source/arch/arm64/kern_k_page_table.cpp
@@ -220,7 +220,7 @@ namespace ams::kern::arch::arm64 {
 
                 /* Remove the entries from the previous table. */
                 if (context.level != KPageTableImpl::EntryLevel_L1) {
-                    context.level_entries[context.level + 1]->RemoveTableEntries(num_to_clear);
+                    context.level_entries[context.level + 1]->CloseTableReferences(num_to_clear);
                 }
 
                 /* If we cleared a table, we need to note that we updated and free the table. */
@@ -238,7 +238,7 @@ namespace ams::kern::arch::arm64 {
                 context.level_entries[context.level] = pte + num_to_clear - 1;
 
                 /* We may have removed the last entries in a table, in which case we can free and unmap the tables. */
-                if (context.level >= KPageTableImpl::EntryLevel_L1 || context.level_entries[context.level + 1]->GetTableNumEntries() != 0) {
+                if (context.level >= KPageTableImpl::EntryLevel_L1 || context.level_entries[context.level + 1]->GetTableReferenceCount() != 0) {
                     break;
                 }
@@ -395,7 +395,7 @@ namespace ams::kern::arch::arm64 {
 
                 /* Remove the entries from the previous table. */
                 if (context.level != KPageTableImpl::EntryLevel_L1) {
-                    context.level_entries[context.level + 1]->RemoveTableEntries(num_to_clear);
+                    context.level_entries[context.level + 1]->CloseTableReferences(num_to_clear);
                 }
 
                 /* If we cleared a table, we need to note that we updated and free the table. */
@@ -415,7 +415,7 @@ namespace ams::kern::arch::arm64 {
                 context.level_entries[context.level] = pte + num_to_clear - 1;
 
                 /* We may have removed the last entries in a table, in which case we can free and unmap the tables. */
-                if (context.level >= KPageTableImpl::EntryLevel_L1 || context.level_entries[context.level + 1]->GetTableNumEntries() != 0) {
+                if (context.level >= KPageTableImpl::EntryLevel_L1 || context.level_entries[context.level + 1]->GetTableReferenceCount() != 0) {
                     break;
                 }
@@ -485,7 +485,7 @@ namespace ams::kern::arch::arm64 {
                 /* Remove entries for and free any tables. */
                 while (context.level < KPageTableImpl::EntryLevel_L1) {
                     /* If the higher-level table has entries, we don't need to do a free. */
-                    if (context.level_entries[context.level + 1]->GetTableNumEntries() != 0) {
+                    if (context.level_entries[context.level + 1]->GetTableReferenceCount() != 0) {
                         break;
                     }
@@ -500,7 +500,7 @@ namespace ams::kern::arch::arm64 {
 
                     /* Remove the entry for the table one level higher. */
                     if (context.level + 1 < KPageTableImpl::EntryLevel_L1) {
-                        context.level_entries[context.level + 2]->RemoveTableEntries(1);
+                        context.level_entries[context.level + 2]->CloseTableReferences(1);
                     }
 
                     /* Advance our level. */
@@ -527,7 +527,7 @@ namespace ams::kern::arch::arm64 {
 
                     /* Add the entry to the table containing this one. */
                     if (context.level != KPageTableImpl::EntryLevel_L1) {
-                        context.level_entries[context.level + 1]->AddTableEntries(1);
+                        context.level_entries[context.level + 1]->OpenTableReferences(1);
                    }
 
                     /* Decrease our level. */
@@ -559,7 +559,7 @@ namespace ams::kern::arch::arm64 {
 
                 /* Add the entries to the table containing this one. */
                 if (context.level != KPageTableImpl::EntryLevel_L1) {
-                    context.level_entries[context.level + 1]->AddTableEntries(num_ptes);
+                    context.level_entries[context.level + 1]->OpenTableReferences(num_ptes);
                 }
 
                 /* Update our context. */
diff --git a/libraries/libmesosphere/source/arch/arm64/kern_k_page_table_impl.cpp b/libraries/libmesosphere/source/arch/arm64/kern_k_page_table_impl.cpp
index bddf5323f..32dfb00e5 100644
--- a/libraries/libmesosphere/source/arch/arm64/kern_k_page_table_impl.cpp
+++ b/libraries/libmesosphere/source/arch/arm64/kern_k_page_table_impl.cpp
@@ -27,6 +27,60 @@ namespace ams::kern::arch::arm64 {
         m_table       = static_cast<L1PageTableEntry *>(tb);
         m_is_kernel   = false;
         m_num_entries = util::AlignUp(end - start, L1BlockSize) / L1BlockSize;
+
+        /* Page table entries created by KInitialPageTable need to be iterated and modified to ensure KPageTable invariants. */
+        PageTableEntry *level_entries[EntryLevel_Count] = { nullptr, nullptr, m_table };
+        u32 level = EntryLevel_L1;
+        while (level != EntryLevel_L1 || (level_entries[EntryLevel_L1] - static_cast<PageTableEntry *>(m_table)) < m_num_entries) {
+            /* Get the pte; it must never have the validity-extension flag set. */
+            auto *pte = level_entries[level];
+            MESOSPHERE_ASSERT((pte->GetSoftwareReservedBits() & PageTableEntry::SoftwareReservedBit_Valid) == 0);
+
+            /* While we're a table, recurse, fixing up the reference counts. */
+            while (level > EntryLevel_L3 && pte->IsMappedTable()) {
+                /* Count how many references are in the table. */
+                auto *table = GetPointer<PageTableEntry>(GetPageTableVirtualAddress(pte->GetTable()));
+
+                size_t ref_count = 0;
+                for (size_t i = 0; i < BlocksPerTable; ++i) {
+                    if (table[i].IsMapped()) {
+                        ++ref_count;
+                    }
+                }
+
+                /* Set the reference count for our new page, adding one additional uncloseable reference; kernel pages must never be unreferenced. */
+                pte->SetTableReferenceCount(ref_count + 1).SetValid();
+
+                /* Iterate downwards. */
+                level -= 1;
+                level_entries[level] = table;
+                pte = level_entries[level];
+
+                /* Check that the entry isn't unexpected. */
+                MESOSPHERE_ASSERT((pte->GetSoftwareReservedBits() & PageTableEntry::SoftwareReservedBit_Valid) == 0);
+            }
+
+            /* We're dealing with some block. If it's mapped, set it valid. */
+            if (pte->IsMapped()) {
+                pte->SetValid();
+            }
+
+            /* Advance. */
+            while (true) {
+                /* Advance to the next entry at the current level. */
+                ++level_entries[level];
+                if (!util::IsAligned(reinterpret_cast<uintptr_t>(level_entries[level]), PageSize)) {
+                    break;
+                }
+
+                /* If we're at the end of a level, advance upwards. */
+                level_entries[level++] = nullptr;
+
+                if (level > EntryLevel_L1) {
+                    return;
+                }
+            }
+        }
     }
 
     L1PageTableEntry *KPageTableImpl::Finalize() {
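Taken together, the new pass converts tables built by KInitialPageTable into the form KPageTable expects: every table entry carries a reference count of its mapped children plus one uncloseable reference, and every mapped entry gains the valid extension flag. The kernel walk above is iterative; the recursive equivalent below may be easier to follow (a simplified model, with is_table/is_mapped standing in for the descriptor bit tests, and without the valid-flag updates):

    #include <cstddef>

    constexpr size_t EntriesPerTable = 512;

    struct Pte {
        bool   is_table;
        bool   is_mapped;
        size_t ref_count;
        Pte   *children;  /* EntriesPerTable entries, non-null when is_table */
    };

    void Fixup(Pte &entry) {
        if (!entry.is_table) {
            return;
        }
        size_t mapped = 0;
        for (size_t i = 0; i < EntriesPerTable; ++i) {
            if (entry.children[i].is_mapped) {
                ++mapped;
            }
            Fixup(entry.children[i]);
        }
        /* +1: the uncloseable reference, so the table is never freed. */
        entry.ref_count = mapped + 1;
    }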