diff --git a/libraries/libstratosphere/include/stratosphere/os.hpp b/libraries/libstratosphere/include/stratosphere/os.hpp index 48f33f894..e0690a20b 100644 --- a/libraries/libstratosphere/include/stratosphere/os.hpp +++ b/libraries/libstratosphere/include/stratosphere/os.hpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/libraries/libstratosphere/include/stratosphere/os/os_sdk_condition_variable.hpp b/libraries/libstratosphere/include/stratosphere/os/os_sdk_condition_variable.hpp index 3d5e38b8a..c7f8b21f5 100644 --- a/libraries/libstratosphere/include/stratosphere/os/os_sdk_condition_variable.hpp +++ b/libraries/libstratosphere/include/stratosphere/os/os_sdk_condition_variable.hpp @@ -17,6 +17,7 @@ #pragma once #include #include +#include #include namespace ams::os { @@ -27,20 +28,21 @@ namespace ams::os { impl::InternalConditionVariableStorage _storage; }; - void Initialize() { + ALWAYS_INLINE void Initialize() { GetReference(this->_storage).Initialize(); } void Wait(SdkMutexType &mutex); bool TimedWait(SdkMutexType &mutex, TimeSpan timeout); - /* TODO: SdkRecursiveMutexType */ + void Wait(SdkRecursiveMutexType &mutex); + bool TimedWait(SdkRecursiveMutexType &mutex, TimeSpan timeout); - void Signal() { + ALWAYS_INLINE void Signal() { GetReference(this->_storage).Signal(); } - void Broadcast() { + ALWAYS_INLINE void Broadcast() { GetReference(this->_storage).Broadcast(); } }; @@ -48,26 +50,32 @@ namespace ams::os { class SdkConditionVariable { private: - SdkConditionVariableType cv; + SdkConditionVariableType m_cv; public: - constexpr SdkConditionVariable() : cv{{0}} { /* ... */ } + constexpr SdkConditionVariable() : m_cv{{0}} { /* ... */ } - void Wait(SdkMutex &m) { - return this->cv.Wait(m.mutex); + ALWAYS_INLINE void Wait(SdkMutex &m) { + return m_cv.Wait(m.m_mutex); } - bool TimedWait(SdkMutex &m, TimeSpan timeout) { - return this->cv.TimedWait(m.mutex, timeout); + ALWAYS_INLINE bool TimedWait(SdkMutex &m, TimeSpan timeout) { + return m_cv.TimedWait(m.m_mutex, timeout); } - /* TODO: SdkRecursiveMutexType */ - - void Signal() { - return this->cv.Signal(); + ALWAYS_INLINE void Wait(SdkRecursiveMutex &m) { + return m_cv.Wait(m.m_mutex); } - void Broadcast() { - return this->cv.Broadcast(); + ALWAYS_INLINE bool TimedWait(SdkRecursiveMutex &m, TimeSpan timeout) { + return m_cv.TimedWait(m.m_mutex, timeout); + } + + ALWAYS_INLINE void Signal() { + return m_cv.Signal(); + } + + ALWAYS_INLINE void Broadcast() { + return m_cv.Broadcast(); } }; diff --git a/libraries/libstratosphere/include/stratosphere/os/os_sdk_mutex.hpp b/libraries/libstratosphere/include/stratosphere/os/os_sdk_mutex.hpp index 675c12ac2..98311944f 100644 --- a/libraries/libstratosphere/include/stratosphere/os/os_sdk_mutex.hpp +++ b/libraries/libstratosphere/include/stratosphere/os/os_sdk_mutex.hpp @@ -42,15 +42,15 @@ namespace ams::os { private: friend class SdkConditionVariable; private: - SdkMutexType mutex; + SdkMutexType m_mutex; public: - constexpr SdkMutex() : mutex{{0}} { /* ... */ } + constexpr SdkMutex() : m_mutex{{0}} { /* ... */ } - ALWAYS_INLINE void Lock() { return os::LockSdkMutex(std::addressof(this->mutex)); } - ALWAYS_INLINE bool TryLock() { return os::TryLockSdkMutex(std::addressof(this->mutex)); } - ALWAYS_INLINE void Unlock() { return os::UnlockSdkMutex(std::addressof(this->mutex)); } + ALWAYS_INLINE void Lock() { return os::LockSdkMutex(std::addressof(m_mutex)); } + ALWAYS_INLINE bool TryLock() { return os::TryLockSdkMutex(std::addressof(m_mutex)); } + ALWAYS_INLINE void Unlock() { return os::UnlockSdkMutex(std::addressof(m_mutex)); } - ALWAYS_INLINE bool IsLockedByCurrentThread() const { return os::IsSdkMutexLockedByCurrentThread(std::addressof(this->mutex)); } + ALWAYS_INLINE bool IsLockedByCurrentThread() const { return os::IsSdkMutexLockedByCurrentThread(std::addressof(m_mutex)); } ALWAYS_INLINE void lock() { return this->Lock(); } ALWAYS_INLINE bool try_lock() { return this->TryLock(); } diff --git a/libraries/libstratosphere/include/stratosphere/os/os_sdk_recursive_mutex.hpp b/libraries/libstratosphere/include/stratosphere/os/os_sdk_recursive_mutex.hpp new file mode 100644 index 000000000..55d270beb --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/os/os_sdk_recursive_mutex.hpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once +#include +#include + +namespace ams::os { + + class SdkConditionVariable; + + struct SdkRecursiveMutexType { + union { + s32 _arr[sizeof(impl::InternalCriticalSectionStorage) / sizeof(s32)]; + impl::InternalCriticalSectionStorage _storage; + }; + u32 recursive_count; + }; + static_assert(std::is_trivial::value); + + void InitializeSdkRecursiveMutex(SdkRecursiveMutexType *rmutex); + + void LockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex); + bool TryLockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex); + void UnlockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex); + + bool IsSdkRecursiveMutexLockedByCurrentThread(const SdkRecursiveMutexType *rmutex); + + class SdkRecursiveMutex { + private: + friend class SdkConditionVariable; + private: + SdkRecursiveMutexType m_mutex; + public: + constexpr SdkRecursiveMutex() : m_mutex{{0}, 0} { /* ... */ } + + ALWAYS_INLINE void Lock() { return os::LockSdkRecursiveMutex(std::addressof(m_mutex)); } + ALWAYS_INLINE bool TryLock() { return os::TryLockSdkRecursiveMutex(std::addressof(m_mutex)); } + ALWAYS_INLINE void Unlock() { return os::UnlockSdkRecursiveMutex(std::addressof(m_mutex)); } + + ALWAYS_INLINE bool IsLockedByCurrentThread() const { return os::IsSdkRecursiveMutexLockedByCurrentThread(std::addressof(m_mutex)); } + + ALWAYS_INLINE void lock() { return this->Lock(); } + ALWAYS_INLINE bool try_lock() { return this->TryLock(); } + ALWAYS_INLINE void unlock() { return this->Unlock(); } + }; + +} diff --git a/libraries/libstratosphere/source/os/os_sdk_condition_variable.cpp b/libraries/libstratosphere/source/os/os_sdk_condition_variable.cpp index 78bff6612..2dcea6ba5 100644 --- a/libraries/libstratosphere/source/os/os_sdk_condition_variable.cpp +++ b/libraries/libstratosphere/source/os/os_sdk_condition_variable.cpp @@ -45,4 +45,58 @@ namespace ams::os { return status == ConditionVariableStatus::Success; } + void SdkConditionVariableType::Wait(SdkRecursiveMutexType &mutex) { + /* Check preconditions. */ + AMS_ABORT_UNLESS(os::IsSdkRecursiveMutexLockedByCurrentThread(std::addressof(mutex))); + AMS_ABORT_UNLESS(mutex.recursive_count == 1); + + /* Decrement the mutex's recursive count. */ + --mutex.recursive_count; + + /* Wait on the mutex. */ + GetReference(this->_storage).Wait(GetPointer(mutex._storage)); + + /* Increment the mutex's recursive count. */ + ++mutex.recursive_count; + + /* Check that the mutex's recursive count is valid. */ + AMS_ABORT_UNLESS(mutex.recursive_count != 0); + } + + bool SdkConditionVariableType::TimedWait(SdkRecursiveMutexType &mutex, TimeSpan timeout) { + /* Check preconditions. */ + AMS_ASSERT(timeout.GetNanoSeconds() >= 0); + AMS_ABORT_UNLESS(os::IsSdkRecursiveMutexLockedByCurrentThread(std::addressof(mutex))); + + /* Handle zero timeout by unlocking and re-locking. */ + if (timeout == TimeSpan(0)) { + /* NOTE: Nintendo doesn't check recursive_count here...seems really suspicious? */ + /* Not sure that this is correct, or if they just forgot to check. */ + GetReference(mutex._storage).Leave(); + GetReference(mutex._storage).Enter(); + return false; + } + + /* Check that the mutex is held exactly once. */ + AMS_ABORT_UNLESS(mutex.recursive_count == 1); + + /* Decrement the mutex's recursive count. */ + --mutex.recursive_count; + + /* Create timeout helper. */ + impl::TimeoutHelper timeout_helper(timeout); + + /* Perform timed wait. */ + auto status = GetReference(this->_storage).TimedWait(GetPointer(mutex._storage), timeout_helper); + + /* Increment the mutex's recursive count. */ + ++mutex.recursive_count; + + /* Check that the mutex's recursive count is valid. */ + AMS_ABORT_UNLESS(mutex.recursive_count != 0); + + /* Return whether we succeeded. */ + return status == ConditionVariableStatus::Success; + } + } diff --git a/libraries/libstratosphere/source/os/os_sdk_mutex.cpp b/libraries/libstratosphere/source/os/os_sdk_mutex.cpp index e7dbeb56a..f0fc508fa 100644 --- a/libraries/libstratosphere/source/os/os_sdk_mutex.cpp +++ b/libraries/libstratosphere/source/os/os_sdk_mutex.cpp @@ -18,25 +18,36 @@ namespace ams::os { void InitializeSdkMutex(SdkMutexType *mutex) { + /* Initialize the critical section. */ GetReference(mutex->_storage).Initialize(); } bool IsSdkMutexLockedByCurrentThread(const SdkMutexType *mutex) { + /* Check whether the critical section is held. */ return GetReference(mutex->_storage).IsLockedByCurrentThread(); } void LockSdkMutex(SdkMutexType *mutex) { + /* Check pre-conditions. */ AMS_ABORT_UNLESS(!IsSdkMutexLockedByCurrentThread(mutex)); + + /* Enter the critical section. */ return GetReference(mutex->_storage).Enter(); } bool TryLockSdkMutex(SdkMutexType *mutex) { + /* Check pre-conditions. */ AMS_ABORT_UNLESS(!IsSdkMutexLockedByCurrentThread(mutex)); + + /* Try to enter the critical section. */ return GetReference(mutex->_storage).TryEnter(); } void UnlockSdkMutex(SdkMutexType *mutex) { + /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsSdkMutexLockedByCurrentThread(mutex)); + + /* Leave the critical section. */ return GetReference(mutex->_storage).Leave(); } diff --git a/libraries/libstratosphere/source/os/os_sdk_recursive_mutex.cpp b/libraries/libstratosphere/source/os/os_sdk_recursive_mutex.cpp new file mode 100644 index 000000000..92f1072b1 --- /dev/null +++ b/libraries/libstratosphere/source/os/os_sdk_recursive_mutex.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018-2020 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include + +namespace ams::os { + + void InitializeSdkRecursiveMutex(SdkRecursiveMutexType *rmutex) { + /* Initialize the critical section. */ + GetReference(rmutex->_storage).Initialize(); + + /* Set recursive count. */ + rmutex->recursive_count = 0; + } + + bool IsSdkRecursiveMutexLockedByCurrentThread(const SdkRecursiveMutexType *rmutex) { + /* Check whether the critical section is held. */ + return GetReference(rmutex->_storage).IsLockedByCurrentThread(); + } + + void LockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex) { + /* If we don't hold the mutex, enter the critical section. */ + if (!IsSdkRecursiveMutexLockedByCurrentThread(rmutex)) { + GetReference(rmutex->_storage).Enter(); + } + + /* Increment (and check) recursive count. */ + ++rmutex->recursive_count; + AMS_ABORT_UNLESS(rmutex->recursive_count != 0); + } + + bool TryLockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex) { + /* If we don't hold the mutex, try to enter the critical section. */ + if (!IsSdkRecursiveMutexLockedByCurrentThread(rmutex)) { + if (!GetReference(rmutex->_storage).TryEnter()) { + return false; + } + } + + /* Increment (and check) recursive count. */ + ++rmutex->recursive_count; + AMS_ABORT_UNLESS(rmutex->recursive_count != 0); + + return true; + } + + void UnlockSdkRecursiveMutex(SdkRecursiveMutexType *rmutex) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsSdkRecursiveMutexLockedByCurrentThread(rmutex)); + + /* Decrement recursive count, and leave critical section if we no longer hold the mutex. */ + if ((--rmutex->recursive_count) == 0) { + GetReference(rmutex->_storage).Leave(); + } + } + +}