/*
* Copyright (c) 2018-2020 Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include
#include "impl/mem_impl_platform.hpp"
#include "impl/heap/mem_impl_heap_tls_heap_static.hpp"
#include "impl/heap/mem_impl_heap_tls_heap_cache.hpp"
#include "impl/heap/mem_impl_heap_tls_heap_central.hpp"
namespace ams::mem {
/* Alignment used when the caller does not supply one: the strictest fundamental alignment. */
inline constexpr size_t DefaultAlignment = alignof(std::max_align_t);
/* Smallest page-aligned region that Initialize() accepts for a caller-supplied buffer. */
inline constexpr size_t MinimumAllocatorSize = 16_KB;
namespace {
void ThreadDestroy(uintptr_t arg) {
if (arg) {
reinterpret_cast(arg)->Finalize();
}
}
/* Views the allocator's opaque storage as the CentralHeap object constructed inside it. */
/* Constness is stripped on purpose: const member functions of the allocator still need a usable heap pointer. */
ALWAYS_INLINE impl::heap::CentralHeap *GetCentral(const impl::InternalCentralHeapStorage *storage) {
    auto *mutable_storage = const_cast<impl::InternalCentralHeapStorage *>(storage);
    return reinterpret_cast<impl::heap::CentralHeap *>(mutable_storage);
}
/* Reference convenience overload: forwards the storage's address to the pointer overload. */
ALWAYS_INLINE impl::heap::CentralHeap *GetCentral(const impl::InternalCentralHeapStorage &storage) {
    const impl::InternalCentralHeapStorage *storage_address = std::addressof(storage);
    return GetCentral(storage_address);
}
/* Asks the central heap to create a thread cache and, on success, stores it in the given TLS slot. */
/* When MakeCache() reports failure the slot is left untouched, so callers re-read the slot and must cope with 0. */
ALWAYS_INLINE void GetCache(impl::heap::CentralHeap *central, os::TlsSlot slot) {
    impl::heap::CachedHeap holder;
    if (!central->MakeCache(std::addressof(holder))) {
        return;
    }

    /* Release() yields the raw cache pointer (presumably giving up the holder's ownership -- confirm against CachedHeap). */
    impl::heap::TlsHeapCache *cache = holder.Release();
    os::SetTlsValue(slot, reinterpret_cast<uintptr_t>(cache));
}
/* Accumulator handed as user_data to InternalHashCallback while allocated blocks are walked. */
/* The counters are zero-initialized so a default-constructed instance never holds indeterminate values. */
struct InternalHash {
    size_t allocated_count = 0;  /* Number of blocks visited so far. */
    size_t allocated_size = 0;   /* Sum of the sizes of the visited blocks. */
    crypto::Sha1Generator sha1;  /* Running SHA-1 over each block's (pointer value, size) pair. */
};
int InternalHashCallback(void *ptr, size_t size, void *user_data) {
InternalHash *hash = reinterpret_cast(user_data);
hash->sha1.Update(reinterpret_cast(std::addressof(ptr)), sizeof(ptr));
hash->sha1.Update(reinterpret_cast(std::addressof(size)), sizeof(size));
hash->allocated_count++;
hash->allocated_size += size;
return 1;
}
}
/* Builds an allocator in the "not initialized" state and constructs the CentralHeap inside the embedded storage. */
/* Initialize() must be called before the allocator can hand out memory. */
StandardAllocator::StandardAllocator() : initialized(false), enable_thread_cache(false), unused(0) {
    /* The embedded storage has to be big enough to host the CentralHeap object. */
    static_assert(sizeof(this->central_heap_storage) >= sizeof(impl::heap::CentralHeap));

    impl::heap::CentralHeap *central = GetCentral(this->central_heap_storage);
    std::construct_at(central);
}
/* Constructs the allocator and immediately initializes it over the given region via Initialize(mem, size). */
StandardAllocator::StandardAllocator(void *mem, size_t size) : StandardAllocator() {
    /* The delegated default constructor has already built the central heap object. */
    Initialize(mem, size);
}
/* Constructs the allocator and immediately initializes it over the given region, */
/* with the per-thread cache switched on or off according to enable_cache. */
StandardAllocator::StandardAllocator(void *mem, size_t size, bool enable_cache) : StandardAllocator() {
    /* The delegated default constructor has already built the central heap object. */
    Initialize(mem, size, enable_cache);
}
/* Initializes the allocator over the given region with the per-thread cache disabled. */
void StandardAllocator::Initialize(void *mem, size_t size) {
    constexpr bool EnableThreadCache = false;
    this->Initialize(mem, size, EnableThreadCache);
}
void StandardAllocator::Initialize(void *mem, size_t size, bool enable_cache) {
AMS_ABORT_UNLESS(!this->initialized);
const uintptr_t aligned_start = util::AlignUp(reinterpret_cast(mem), impl::heap::TlsHeapStatic::PageSize);
const uintptr_t aligned_end = util::AlignDown(reinterpret_cast(mem) + size, impl::heap::TlsHeapStatic::PageSize);
const size_t aligned_size = aligned_end - aligned_start;
if (mem == nullptr) {
AMS_ABORT_UNLESS(os::IsVirtualAddressMemoryEnabled());
AMS_ABORT_UNLESS(GetCentral(this->central_heap_storage)->Initialize(nullptr, size, 0) == 0);
} else {
AMS_ABORT_UNLESS(aligned_start < aligned_end);
AMS_ABORT_UNLESS(aligned_size >= MinimumAllocatorSize);
AMS_ABORT_UNLESS(GetCentral(this->central_heap_storage)->Initialize(reinterpret_cast(aligned_start), aligned_size, 0) == 0);
}
this->enable_thread_cache = enable_cache;
if (this->enable_thread_cache) {
R_ABORT_UNLESS(os::AllocateTlsSlot(std::addressof(this->tls_slot), ThreadDestroy));
}
this->initialized = true;
}
void StandardAllocator::Finalize() {
AMS_ABORT_UNLESS(this->initialized);
if (this->enable_thread_cache) {
os::FreeTlsSlot(this->tls_slot);
}
GetCentral(this->central_heap_storage)->Finalize();
this->initialized = false;
}
/* Allocates size bytes with DefaultAlignment; see the two-argument overload for the return contract. */
void *StandardAllocator::Allocate(size_t size) {
    AMS_ASSERT(this->initialized);

    void *block = this->Allocate(size, DefaultAlignment);
    return block;
}
void *StandardAllocator::Allocate(size_t size, size_t alignment) {
AMS_ASSERT(this->initialized);
impl::heap::TlsHeapCache *heap_cache = nullptr;
if (this->enable_thread_cache) {
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
if (!heap_cache) {
GetCache(GetCentral(this->central_heap_storage), this->tls_slot);
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
}
}
void *ptr = nullptr;
if (heap_cache) {
ptr = heap_cache->Allocate(size, alignment);
if (ptr) {
return ptr;
}
impl::heap::CachedHeap cache;
cache.Reset(heap_cache);
cache.Query(impl::AllocQuery_FinalizeCache);
os::SetTlsValue(this->tls_slot, 0);
}
return GetCentral(this->central_heap_storage)->Allocate(size, alignment);
}
void StandardAllocator::Free(void *ptr) {
AMS_ASSERT(this->initialized);
if (ptr == nullptr) {
return;
}
if (this->enable_thread_cache) {
impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
if (heap_cache) {
heap_cache->Free(ptr);
return;
}
}
auto err = GetCentral(this->central_heap_storage)->Free(ptr);
AMS_ASSERT(err == 0);
}
void *StandardAllocator::Reallocate(void *ptr, size_t new_size) {
AMS_ASSERT(this->initialized);
if (new_size > impl::MaxSize) {
return nullptr;
}
if (ptr == nullptr) {
return this->Allocate(new_size);
}
if (new_size == 0) {
this->Free(ptr);
return nullptr;
}
size_t aligned_new_size = util::AlignUp(new_size, DefaultAlignment);
impl::heap::TlsHeapCache *heap_cache = nullptr;
if (this->enable_thread_cache) {
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
if (!heap_cache) {
GetCache(GetCentral(this->central_heap_storage), this->tls_slot);
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
}
}
void *p = nullptr;
impl::errno_t err;
if (heap_cache) {
err = heap_cache->Reallocate(ptr, aligned_new_size, std::addressof(p));
} else {
err = GetCentral(this->central_heap_storage)->Reallocate(ptr, aligned_new_size, std::addressof(p));
}
if (err == 0) {
return p;
} else {
return nullptr;
}
}
size_t StandardAllocator::Shrink(void *ptr, size_t new_size) {
AMS_ASSERT(this->initialized);
if (this->enable_thread_cache) {
impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
if (heap_cache) {
if (heap_cache->Shrink(ptr, new_size) == 0) {
return heap_cache->GetAllocationSize(ptr);
} else {
return 0;
}
}
}
if (GetCentral(this->central_heap_storage)->Shrink(ptr, new_size) == 0) {
return GetCentral(this->central_heap_storage)->GetAllocationSize(ptr);
} else {
return 0;
}
}
void StandardAllocator::ClearThreadCache() const {
if (this->enable_thread_cache) {
impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
impl::heap::CachedHeap cache;
cache.Reset(heap_cache);
cache.Query(impl::AllocQuery_ClearCache);
cache.Release();
}
}
void StandardAllocator::CleanUpManagementArea() const {
AMS_ASSERT(this->initialized);
auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_UnifyFreeList);
AMS_ASSERT(err == 0);
}
size_t StandardAllocator::GetSizeOf(const void *ptr) const {
AMS_ASSERT(this->initialized);
if (!util::IsAligned(reinterpret_cast(ptr), DefaultAlignment)) {
return 0;
}
impl::heap::TlsHeapCache *heap_cache = nullptr;
if (this->enable_thread_cache) {
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
if (!heap_cache) {
GetCache(GetCentral(this->central_heap_storage), this->tls_slot);
heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot));
}
}
if (heap_cache) {
return heap_cache->GetAllocationSize(ptr);
} else {
return GetCentral(this->central_heap_storage)->GetAllocationSize(ptr);
}
}
size_t StandardAllocator::GetTotalFreeSize() const {
size_t size = 0;
auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_FreeSizeMapped, std::addressof(size));
if (err != 0) {
err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_FreeSize, std::addressof(size));
}
AMS_ASSERT(err == 0);
return size;
}
size_t StandardAllocator::GetAllocatableSize() const {
size_t size = 0;
auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSizeMapped, std::addressof(size));
if (err != 0) {
err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSize, std::addressof(size));
}
AMS_ASSERT(err == 0);
return size;
}
/* Invokes callback (with user_data) for the allocated pointers tracked by the central heap. */
/* The calling thread's cache is cleared first */
/* (presumably so cached blocks are not reported as allocated -- confirm against the heap implementation). */
void StandardAllocator::WalkAllocatedBlocks(WalkCallback callback, void *user_data) const {
    AMS_ASSERT(this->initialized);

    this->ClearThreadCache();

    impl::heap::CentralHeap *central = GetCentral(this->central_heap_storage);
    central->WalkAllocatedPointers(callback, user_data);
}
void StandardAllocator::Dump() const {
AMS_ASSERT(this->initialized);
size_t tmp;
auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSizeMapped, std::addressof(tmp));
if (err == 0) {
GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_Dump, impl::DumpMode_Spans | impl::DumpMode_Pointers, 1);
} else {
GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_Dump, impl::DumpMode_All, 1);
}
}
/* Summarizes the current set of allocated blocks. */
/* Walks every allocated block (clearing the calling thread's cache first, via WalkAllocatedBlocks), */
/* counting blocks, summing their sizes, and feeding each (pointer value, size) pair into SHA-1. */
/* The result's hash field receives the leading sizeof(hash) bytes of the SHA-1 digest. */
StandardAllocator::AllocatorHash StandardAllocator::Hash() const {
    AMS_ASSERT(this->initialized);

    InternalHash internal_hash;
    internal_hash.allocated_count = 0;
    internal_hash.allocated_size  = 0;
    internal_hash.sha1.Initialize();

    this->WalkAllocatedBlocks(InternalHashCallback, static_cast<void *>(std::addressof(internal_hash)));

    AllocatorHash alloc_hash;
    alloc_hash.allocated_count = internal_hash.allocated_count;
    alloc_hash.allocated_size  = internal_hash.allocated_size;

    char temp_hash[crypto::Sha1Generator::HashSize];
    /* The copy below takes sizeof(hash) bytes out of the digest, so the digest must be at least that large. */
    static_assert(sizeof(alloc_hash.hash) <= sizeof(temp_hash));
    internal_hash.sha1.GetHash(temp_hash, sizeof(temp_hash));
    std::memcpy(std::addressof(alloc_hash.hash), temp_hash, sizeof(alloc_hash.hash));

    return alloc_hash;
}
}