diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_event.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_event.hpp index 15362e8a3..2689bcf61 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_event.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_event.hpp @@ -46,6 +46,8 @@ namespace ams::kern { static void PostDestroy(uintptr_t arg); + virtual KProcess *GetOwner() const override { return this->owner; } + KReadableEvent &GetReadableEvent() { return this->readable_event; } KWritableEvent &GetWritableEvent() { return this->writable_event; } }; diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_handle_table.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_handle_table.hpp index 8900797ec..907c9d08d 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_handle_table.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_handle_table.hpp @@ -155,21 +155,7 @@ namespace ams::kern { } } - template - ALWAYS_INLINE KScopedAutoObject GetObjectForIpc(ams::svc::Handle handle) const { - static_assert(!std::is_base_of::value); - - /* Handle pseudo-handles. */ - if constexpr (std::is_base_of::value) { - if (handle == ams::svc::PseudoHandle::CurrentProcess) { - return GetCurrentProcessPointer(); - } - } else if constexpr (std::is_base_of::value) { - if (handle == ams::svc::PseudoHandle::CurrentThread) { - return GetCurrentThreadPointer(); - } - } - + ALWAYS_INLINE KScopedAutoObject GetObjectForIpcWithoutPseudoHandle(ams::svc::Handle handle) const { /* Lock and look up in table. */ KScopedDisableDispatch dd; KScopedSpinLock lk(this->lock); @@ -178,15 +164,20 @@ namespace ams::kern { if (obj->DynamicCast() != nullptr) { return nullptr; } - if constexpr (std::is_same::value) { - return obj; - } else { - if (auto *obj = this->GetObjectImpl(handle); obj != nullptr) { - return obj->DynamicCast(); - } else { - return nullptr; - } + + return obj; + } + + ALWAYS_INLINE KScopedAutoObject GetObjectForIpc(ams::svc::Handle handle, KThread *cur_thread) const { + /* Handle pseudo-handles. */ + if (handle == ams::svc::PseudoHandle::CurrentProcess) { + return static_cast(static_cast(cur_thread->GetOwnerProcess())); } + if (handle == ams::svc::PseudoHandle::CurrentThread) { + return static_cast(cur_thread); + } + + return GetObjectForIpcWithoutPseudoHandle(handle); } ALWAYS_INLINE KScopedAutoObject GetObjectByIndex(ams::svc::Handle *out_handle, size_t index) const { diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_session_request.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_session_request.hpp index ab9c6325a..7696e0024 100644 --- a/libraries/libmesosphere/include/mesosphere/kern_k_session_request.hpp +++ b/libraries/libmesosphere/include/mesosphere/kern_k_session_request.hpp @@ -128,6 +128,11 @@ namespace ams::kern { size_t GetSize() const { return this->size; } KProcess *GetServerProcess() const { return this->server; } + void SetServerProcess(KProcess *process) { + this->server = process; + this->server->Open(); + } + void ClearThread() { this->thread = nullptr; } void ClearEvent() { this->event = nullptr; } diff --git a/libraries/libmesosphere/source/kern_k_server_session.cpp b/libraries/libmesosphere/source/kern_k_server_session.cpp index c72bf9af8..07218098a 100644 --- a/libraries/libmesosphere/source/kern_k_server_session.cpp +++ b/libraries/libmesosphere/source/kern_k_server_session.cpp @@ -17,8 +17,246 @@ namespace ams::kern { + namespace ipc { + + using MessageBuffer = ams::svc::ipc::MessageBuffer; + + } + namespace { + class ReceiveList { + private: + u32 data[ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountMax * ipc::MessageBuffer::ReceiveListEntry::GetDataSize() / sizeof(u32)]; + s32 recv_list_count; + uintptr_t msg_buffer_end; + uintptr_t msg_buffer_space_end; + public: + static constexpr int GetEntryCount(const ipc::MessageBuffer::MessageHeader &header) { + const auto count = header.GetReceiveListCount(); + switch (count) { + case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_None: + return 0; + case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer: + return 0; + case ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToSingleBuffer: + return 1; + default: + return count - ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset; + } + } + public: + ReceiveList(const u32 *dst_msg, const ipc::MessageBuffer::MessageHeader &dst_header, const ipc::MessageBuffer::SpecialHeader &dst_special_header, size_t msg_size, size_t out_offset, s32 dst_recv_list_idx) { + this->recv_list_count = dst_header.GetReceiveListCount(); + this->msg_buffer_end = reinterpret_cast(dst_msg) + sizeof(u32) * out_offset; + this->msg_buffer_space_end = reinterpret_cast(dst_msg) + msg_size; + + const u32 *recv_list = dst_msg + dst_recv_list_idx; + __builtin_memcpy(this->data, recv_list, GetEntryCount(dst_header) * ipc::MessageBuffer::ReceiveListEntry::GetDataSize()); + } + + constexpr bool IsIndex() const { + return this->recv_list_count > ipc::MessageBuffer::MessageHeader::ReceiveListCountType_CountOffset; + } + }; + + template + ALWAYS_INLINE Result ProcessMessageSpecialData(int &offset, KProcess &dst_process, KProcess &src_process, KThread &src_thread, const ipc::MessageBuffer &dst_msg, const ipc::MessageBuffer &src_msg, const ipc::MessageBuffer::SpecialHeader &src_special_header) { + /* Copy the special header to the destination. */ + offset = dst_msg.Set(src_special_header); + + /* Copy the process ID. */ + if (src_special_header.GetHasProcessId()) { + /* TODO: Atmosphere mitm extension support. */ + offset = dst_msg.SetProcessId(offset, src_process.GetId()); + } + + /* Prepare to process handles. */ + auto &dst_handle_table = dst_process.GetHandleTable(); + auto &src_handle_table = src_process.GetHandleTable(); + Result result = ResultSuccess(); + + /* Process copy handles. */ + for (auto i = 0; i < src_special_header.GetCopyHandleCount(); ++i) { + /* Get the handles. */ + const ams::svc::Handle src_handle = src_msg.GetHandle(offset); + ams::svc::Handle dst_handle = ams::svc::InvalidHandle; + + /* If we're in a success state, try to move the handle to the new table. */ + if (R_SUCCEEDED(result) && src_handle != ams::svc::InvalidHandle) { + KScopedAutoObject obj = src_handle_table.GetObjectForIpc(src_handle, std::addressof(src_thread)); + if (obj.IsNotNull()) { + Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe()); + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = ams::svc::InvalidHandle; + } + } else { + result = svc::ResultInvalidHandle(); + } + } + + /* Set the handle. */ + offset = dst_msg.SetHandle(offset, dst_handle); + } + + /* Process move handles. */ + if constexpr (MoveHandleAllowed) { + for (auto i = 0; i < src_special_header.GetMoveHandleCount(); ++i) { + /* Get the handles. */ + const ams::svc::Handle src_handle = src_msg.GetHandle(offset); + ams::svc::Handle dst_handle = ams::svc::InvalidHandle; + + /* Whether or not we've succeeded, we need to remove the handles from the source table. */ + if (src_handle != ams::svc::InvalidHandle) { + if (R_SUCCEEDED(result)) { + KScopedAutoObject obj = src_handle_table.GetObjectForIpcWithoutPseudoHandle(src_handle); + if (obj.IsNotNull()) { + Result add_result = dst_handle_table.Add(std::addressof(dst_handle), obj.GetPointerUnsafe()); + + src_handle_table.Remove(src_handle); + + if (R_FAILED(add_result)) { + result = add_result; + dst_handle = ams::svc::InvalidHandle; + } + } else { + result = svc::ResultInvalidHandle(); + } + } else { + src_handle_table.Remove(src_handle); + } + } + + /* Set the handle. */ + offset = dst_msg.SetHandle(offset, dst_handle); + } + } + + return result; + } + + ALWAYS_INLINE Result ReceiveMessage(bool &recv_list_broken, uintptr_t dst_message_buffer, size_t dst_buffer_size, KPhysicalAddress dst_message_paddr, KThread &src_thread, uintptr_t src_message_buffer, size_t src_buffer_size, KServerSession *session, KSessionRequest *request) { + /* Prepare variables for receive. */ + const KThread &dst_thread = GetCurrentThread(); + KProcess &dst_process = *(dst_thread.GetOwnerProcess()); + KProcess &src_process = *(src_thread.GetOwnerProcess()); + auto &dst_page_table = dst_process.GetPageTable(); + auto &src_page_table = src_process.GetPageTable(); + + /* The receive list is initially not broken. */ + recv_list_broken = false; + + /* Set the server process for the request. */ + request->SetServerProcess(std::addressof(dst_process)); + + /* Determine the message buffers. */ + u32 *dst_msg_ptr, *src_msg_ptr; + bool dst_user, src_user; + + if (dst_message_buffer) { + dst_msg_ptr = GetPointer(KPageTable::GetHeapVirtualAddress(dst_message_paddr)); + dst_user = true; + } else { + dst_msg_ptr = static_cast(dst_thread.GetThreadLocalRegionHeapAddress())->message_buffer; + dst_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer); + dst_message_buffer = GetInteger(dst_thread.GetThreadLocalRegionAddress()); + dst_user = false; + } + + if (src_message_buffer) { + /* NOTE: Nintendo does not check the result of this GetPhysicalAddress call. */ + KPhysicalAddress src_message_paddr; + src_page_table.GetPhysicalAddress(std::addressof(src_message_paddr), src_message_buffer); + + src_msg_ptr = GetPointer(KPageTable::GetHeapVirtualAddress(src_message_paddr)); + src_user = true; + } else { + src_msg_ptr = static_cast(src_thread.GetThreadLocalRegionHeapAddress())->message_buffer; + src_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer); + src_message_buffer = GetInteger(src_thread.GetThreadLocalRegionAddress()); + src_user = false; + } + + /* Parse the headers. */ + const ipc::MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size); + const ipc::MessageBuffer src_msg(src_msg_ptr, src_buffer_size); + const ipc::MessageBuffer::MessageHeader dst_header(dst_msg); + const ipc::MessageBuffer::MessageHeader src_header(src_msg); + const ipc::MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header); + const ipc::MessageBuffer::SpecialHeader src_special_header(src_msg, src_header); + + /* Get the end of the source message. */ + const size_t src_end_offset = ipc::MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount(); + + /* Ensure that the headers fit. */ + R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= dst_buffer_size, svc::ResultInvalidCombination()); + R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= src_buffer_size, svc::ResultInvalidCombination()); + + /* Ensure the receive list offset is after the end of raw data. */ + if (dst_header.GetReceiveListOffset()) { + R_UNLESS(dst_header.GetReceiveListOffset() >= ipc::MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + dst_header.GetRawCount(), svc::ResultInvalidCombination()); + } + + /* Ensure that the destination buffer is big enough to receive the source. */ + R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), svc::ResultMessageTooLarge()); + + /* Get the receive list. */ + const s32 dst_recv_list_idx = static_cast(ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header)); + ReceiveList dst_recv_list(dst_msg_ptr, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx); + + /* Ensure that the source special header isn't invalid. */ + const bool src_has_special_header = src_header.GetHasSpecialHeader(); + if (src_has_special_header) { + /* Sending move handles from client -> server is not allowed. */ + R_UNLESS(src_special_header.GetMoveHandleCount() == 0, svc::ResultInvalidCombination()); + } + + /* Prepare for further processing. */ + int pointer_key = 0; + int offset = dst_msg.Set(src_header); + + /* Set up a guard to make sure that we end up in a clean state on error. */ + auto cleanup_guard = SCOPE_GUARD { + /* TODO */ + MESOSPHERE_UNIMPLEMENTED(); + }; + + /* Process any special data. */ + if (src_header.GetHasSpecialHeader()) { + /* After we process, make sure we track whether the receive list is broken. */ + ON_SCOPE_EXIT { if (offset > dst_recv_list_idx) { recv_list_broken = true; } }; + + /* Process special data. */ + R_TRY(ProcessMessageSpecialData(offset, dst_process, src_process, src_thread, dst_msg, src_msg, src_special_header)); + } + + /* Process any pointer buffers. */ + for (auto i = 0; i < src_header.GetPointerCount(); ++i) { + MESOSPHERE_UNIMPLEMENTED(); + } + + /* Process any map alias buffers. */ + for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) { + MESOSPHERE_UNIMPLEMENTED(); + } + + /* Process any raw data. */ + if (src_header.GetRawCount()) { + MESOSPHERE_UNIMPLEMENTED(); + } + + /* TODO: Remove this when done, as these variables will be used by unimplemented stuff above. */ + static_cast(dst_page_table); + static_cast(dst_user); + static_cast(src_user); + static_cast(pointer_key); + + /* We succeeded! */ + cleanup_guard.Cancel(); + return ResultSuccess(); + } + ALWAYS_INLINE void ReplyAsyncError(KProcess *to_process, uintptr_t to_msg_buf, size_t to_msg_buf_size, Result result) { /* Convert the buffer to a physical address. */ KPhysicalAddress phys_addr; @@ -28,7 +266,7 @@ namespace ams::kern { u32 *to_msg = GetPointer(KPageTable::GetHeapVirtualAddress(phys_addr)); /* Set the error. */ - ams::svc::ipc::MessageBuffer msg(to_msg, to_msg_buf_size); + ipc::MessageBuffer msg(to_msg, to_msg_buf_size); msg.SetAsyncResult(result); } @@ -44,8 +282,54 @@ namespace ams::kern { this->parent->Close(); } - Result KServerSession::ReceiveRequest(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) { - MESOSPHERE_UNIMPLEMENTED(); + Result KServerSession::ReceiveRequest(uintptr_t server_message, uintptr_t server_buffer_size, KPhysicalAddress server_message_paddr) { + MESOSPHERE_ASSERT_THIS(); + + /* Lock the session. */ + KScopedLightLock lk(this->lock); + + /* Get the request and client thread. */ + KSessionRequest *request; + KScopedAutoObject client_thread; + { + KScopedSchedulerLock sl; + + /* Ensure that we can service the request. */ + R_UNLESS(!this->parent->IsClientClosed(), svc::ResultSessionClosed()); + + /* Ensure we aren't already servicing a request. */ + R_UNLESS(this->current_request == nullptr, svc::ResultNotFound()); + + /* Ensure we have a request to service. */ + R_UNLESS(!this->request_list.empty(), svc::ResultNotFound()); + + /* Pop the first request from the list. */ + request = std::addressof(this->request_list.front()); + this->request_list.pop_front(); + + /* Get the thread for the request. */ + client_thread = KScopedAutoObject(request->GetThread()); + R_UNLESS(client_thread.IsNotNull(), svc::ResultSessionClosed()); + } + + /* Set the request as our current. */ + this->current_request = request; + + /* Get the client address. */ + uintptr_t client_message = request->GetAddress(); + size_t client_buffer_size = request->GetSize(); + bool recv_list_broken = false; + + /* Receive the message. */ + Result result = ReceiveMessage(recv_list_broken, server_message, server_buffer_size, server_message_paddr, *client_thread.GetPointerUnsafe(), client_message, client_buffer_size, this, request); + + /* Handle cleanup on receive failure. */ + if (R_FAILED(result)) { + /* TODO */ + MESOSPHERE_UNIMPLEMENTED(); + } + + return result; } Result KServerSession::SendReply(uintptr_t message, uintptr_t buffer_size, KPhysicalAddress message_paddr) { diff --git a/libraries/libvapours/include/vapours/svc/ipc/svc_message_buffer.hpp b/libraries/libvapours/include/vapours/svc/ipc/svc_message_buffer.hpp index 1c379ca74..8c0c81aaf 100644 --- a/libraries/libvapours/include/vapours/svc/ipc/svc_message_buffer.hpp +++ b/libraries/libvapours/include/vapours/svc/ipc/svc_message_buffer.hpp @@ -439,34 +439,34 @@ namespace ams::svc::ipc { return index + (spc.GetHeaderSize() / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) { + ALWAYS_INLINE s32 SetHandle(s32 index, const ::ams::svc::Handle &hnd) const { static_assert(util::IsAligned(sizeof(hnd), sizeof(*this->buffer))); __builtin_memcpy(this->buffer + index, std::addressof(hnd), sizeof(hnd)); return index + (sizeof(hnd) / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) { + ALWAYS_INLINE s32 SetProcessId(s32 index, const u64 pid) const { static_assert(util::IsAligned(sizeof(pid), sizeof(*this->buffer))); __builtin_memcpy(this->buffer + index, std::addressof(pid), sizeof(pid)); return index + (sizeof(pid) / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) { + ALWAYS_INLINE s32 Set(s32 index, const MapAliasDescriptor &desc) const { __builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize()); return index + (desc.GetDataSize() / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) { + ALWAYS_INLINE s32 Set(s32 index, const PointerDescriptor &desc) const { __builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize()); return index + (desc.GetDataSize() / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) { + ALWAYS_INLINE s32 Set(s32 index, const ReceiveListEntry &desc) const { __builtin_memcpy(this->buffer + index, desc.GetData(), desc.GetDataSize()); return index + (desc.GetDataSize() / sizeof(*this->buffer)); } - ALWAYS_INLINE s32 Set(s32 index, const u32 val) { + ALWAYS_INLINE s32 Set(s32 index, const u32 val) const { static_assert(util::IsAligned(sizeof(val), sizeof(*this->buffer))); __builtin_memcpy(this->buffer + index, std::addressof(val), sizeof(val)); return index + (sizeof(val) / sizeof(*this->buffer)); @@ -521,7 +521,7 @@ namespace ams::svc::ipc { } } - static constexpr ALWAYS_INLINE s32 GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) { + static constexpr ALWAYS_INLINE size_t GetMessageBufferSize(const MessageHeader &hdr, const SpecialHeader &spc) { /* Get the size of the plain message. */ size_t msg_size = GetReceiveListIndex(hdr, spc) * sizeof(util::BitPack32);