1
0
Fork 0
mirror of https://github.com/Atmosphere-NX/Atmosphere.git synced 2024-11-26 22:02:15 +00:00

kern/ipc: implement most of reply

This commit is contained in:
Michael Scire 2020-07-10 13:42:36 -07:00
parent b29dc76b20
commit 1b429918de

View file

@ -48,13 +48,31 @@ namespace ams::kern {
}
}
public:
ReceiveList(const u32 *dst_msg, const ipc::MessageBuffer::MessageHeader &dst_header, const ipc::MessageBuffer::SpecialHeader &dst_special_header, size_t msg_size, size_t out_offset, s32 dst_recv_list_idx) {
ReceiveList(const u32 *dst_msg, uintptr_t dst_address, const KProcessPageTable &dst_page_table, const ipc::MessageBuffer::MessageHeader &dst_header, const ipc::MessageBuffer::SpecialHeader &dst_special_header, size_t msg_size, size_t out_offset, s32 dst_recv_list_idx, bool is_tls) {
this->recv_list_count = dst_header.GetReceiveListCount();
this->msg_buffer_end = reinterpret_cast<uintptr_t>(dst_msg) + sizeof(u32) * out_offset;
this->msg_buffer_space_end = reinterpret_cast<uintptr_t>(dst_msg) + msg_size;
this->msg_buffer_end = dst_address + sizeof(u32) * out_offset;
this->msg_buffer_space_end = dst_address + msg_size;
const u32 *recv_list = dst_msg + dst_recv_list_idx;
__builtin_memcpy(this->data, recv_list, GetEntryCount(dst_header) * ipc::MessageBuffer::ReceiveListEntry::GetDataSize());
const auto entry_count = GetEntryCount(dst_header);
if (is_tls) {
__builtin_memcpy(this->data, recv_list, entry_count * ipc::MessageBuffer::ReceiveListEntry::GetDataSize());
} else {
uintptr_t page_addr = util::AlignDown(dst_address, PageSize);
uintptr_t cur_addr = dst_address + dst_recv_list_idx * sizeof(u32);
for (size_t i = 0; i < entry_count * ipc::MessageBuffer::ReceiveListEntry::GetDataSize() / sizeof(u32); ++i) {
if (page_addr != util::AlignDown(cur_addr, PageSize)) {
KPhysicalAddress phys_addr;
dst_page_table.GetPhysicalAddress(std::addressof(phys_addr), KProcessAddress(cur_addr));
recv_list = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(phys_addr));
page_addr = util::AlignDown(cur_addr, PageSize);
}
this->data[i] = *(recv_list++);
cur_addr += sizeof(u32);
}
}
}
constexpr bool IsIndex() const {
@ -513,8 +531,8 @@ namespace ams::kern {
R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), svc::ResultMessageTooLarge());
/* Get the receive list. */
const s32 dst_recv_list_idx = static_cast<s32>(ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header));
ReceiveList dst_recv_list(dst_msg_ptr, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx);
const s32 dst_recv_list_idx = ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header);
ReceiveList dst_recv_list(dst_msg_ptr, dst_message_buffer, dst_page_table, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx, !dst_user);
/* Ensure that the source special header isn't invalid. */
const bool src_has_special_header = src_header.GetHasSpecialHeader();
@ -631,10 +649,186 @@ namespace ams::kern {
return ResultSuccess();
}
ALWAYS_INLINE Result SendMessage(uintptr_t src_message_buffer, size_t src_buffer_size, KPhysicalAddress src_message_paddr, KThread &dst_thread, uintptr_t dst_message_buffer, size_t dst_buffer_size, KServerSession *session, KSessionRequest *request) {
ALWAYS_INLINE Result ProcessSendMessageReceiveMapping(KProcessPageTable &dst_page_table, KProcessAddress client_address, KProcessAddress server_address, size_t size, KMemoryState src_state) {
MESOSPHERE_UNIMPLEMENTED();
}
ALWAYS_INLINE Result ProcessSendMessagePointerDescriptors(int &offset, int &pointer_key, KProcessPageTable &dst_page_table, KProcessPageTable &src_page_table, const ipc::MessageBuffer &dst_msg, const ipc::MessageBuffer &src_msg, const ReceiveList &dst_recv_list, bool dst_user) {
MESOSPHERE_UNIMPLEMENTED();
}
ALWAYS_INLINE Result SendMessage(uintptr_t src_message_buffer, size_t src_buffer_size, KPhysicalAddress src_message_paddr, KThread &dst_thread, uintptr_t dst_message_buffer, size_t dst_buffer_size, KServerSession *session, KSessionRequest *request) {
/* Prepare variables for send. */
KThread &src_thread = GetCurrentThread();
KProcess &dst_process = *(dst_thread.GetOwnerProcess());
KProcess &src_process = *(src_thread.GetOwnerProcess());
auto &dst_page_table = dst_process.GetPageTable();
auto &src_page_table = src_process.GetPageTable();
/* Determine the message buffers. */
u32 *dst_msg_ptr, *src_msg_ptr;
bool dst_user, src_user;
if (dst_message_buffer) {
/* NOTE: Nintendo does not check the result of this GetPhysicalAddress call. */
KPhysicalAddress dst_message_paddr;
dst_page_table.GetPhysicalAddress(std::addressof(dst_message_paddr), dst_message_buffer);
dst_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(dst_message_paddr));
dst_user = true;
} else {
dst_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(dst_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
dst_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
dst_message_buffer = GetInteger(dst_thread.GetThreadLocalRegionAddress());
dst_user = false;
}
if (src_message_buffer) {
src_msg_ptr = GetPointer<u32>(KPageTable::GetHeapVirtualAddress(src_message_paddr));
src_user = true;
} else {
src_msg_ptr = static_cast<ams::svc::ThreadLocalRegion *>(src_thread.GetThreadLocalRegionHeapAddress())->message_buffer;
src_buffer_size = sizeof(ams::svc::ThreadLocalRegion{}.message_buffer);
src_message_buffer = GetInteger(src_thread.GetThreadLocalRegionAddress());
src_user = false;
}
/* Parse the headers. */
const ipc::MessageBuffer dst_msg(dst_msg_ptr, dst_buffer_size);
const ipc::MessageBuffer src_msg(src_msg_ptr, src_buffer_size);
const ipc::MessageBuffer::MessageHeader dst_header(dst_msg);
const ipc::MessageBuffer::MessageHeader src_header(src_msg);
const ipc::MessageBuffer::SpecialHeader dst_special_header(dst_msg, dst_header);
const ipc::MessageBuffer::SpecialHeader src_special_header(src_msg, src_header);
/* Get the end of the source message. */
const size_t src_end_offset = ipc::MessageBuffer::GetRawDataIndex(src_header, src_special_header) + src_header.GetRawCount();
/* Declare variables for processing. */
int offset = 0;
int pointer_key = 0;
bool processed_special_data = false;
/* Set up a guard to make sure that we end up in a clean state on error. */
auto cleanup_guard = SCOPE_GUARD {
/* Cleanup special data. */
if (processed_special_data) {
if (src_header.GetHasSpecialHeader()) {
CleanupSpecialData(dst_process, dst_msg_ptr, dst_buffer_size);
}
} else {
CleanupServerHandles(src_message_buffer, src_buffer_size, src_message_paddr);
}
/* Cleanup mappings. */
CleanupMap(request, std::addressof(src_process), std::addressof(dst_page_table));
};
/* Ensure that the headers fit. */
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(src_header, src_special_header) <= src_buffer_size, svc::ResultInvalidCombination());
R_UNLESS(ipc::MessageBuffer::GetMessageBufferSize(dst_header, dst_special_header) <= dst_buffer_size, svc::ResultInvalidCombination());
/* Ensure the receive list offset is after the end of raw data. */
if (dst_header.GetReceiveListOffset()) {
R_UNLESS(dst_header.GetReceiveListOffset() >= ipc::MessageBuffer::GetRawDataIndex(dst_header, dst_special_header) + dst_header.GetRawCount(), svc::ResultInvalidCombination());
}
/* Ensure that the destination buffer is big enough to receive the source. */
R_UNLESS(dst_buffer_size >= src_end_offset * sizeof(u32), svc::ResultMessageTooLarge());
/* Replies must have no buffers. */
R_UNLESS(src_header.GetSendCount() == 0, svc::ResultInvalidCombination());
R_UNLESS(src_header.GetReceiveCount() == 0, svc::ResultInvalidCombination());
R_UNLESS(src_header.GetExchangeCount() == 0, svc::ResultInvalidCombination());
/* Get the receive list. */
const s32 dst_recv_list_idx = ipc::MessageBuffer::GetReceiveListIndex(dst_header, dst_special_header);
ReceiveList dst_recv_list(dst_msg_ptr, dst_message_buffer, dst_page_table, dst_header, dst_special_header, dst_buffer_size, src_end_offset, dst_recv_list_idx, !dst_user);
/* Handle any receive buffers. */
for (size_t i = 0; i < request->GetReceiveCount(); ++i) {
R_TRY(ProcessSendMessageReceiveMapping(dst_page_table, request->GetReceiveClientAddress(i), request->GetReceiveServerAddress(i), request->GetReceiveSize(i), request->GetReceiveMemoryState(i)));
}
/* Handle any exchange buffers. */
for (size_t i = 0; i < request->GetReceiveCount(); ++i) {
R_TRY(ProcessSendMessageReceiveMapping(dst_page_table, request->GetExchangeClientAddress(i), request->GetExchangeServerAddress(i), request->GetExchangeSize(i), request->GetExchangeMemoryState(i)));
}
/* Set the header. */
offset = dst_msg.Set(src_header);
/* Process any special data. */
MESOSPHERE_ASSERT(GetCurrentThreadPointer() == std::addressof(src_thread));
processed_special_data = true;
if (src_header.GetHasSpecialHeader()) {
R_TRY(ProcessMessageSpecialData<true>(offset, dst_process, src_process, src_thread, dst_msg, src_msg, src_special_header));
}
/* Process any pointer buffers. */
for (auto i = 0; i < src_header.GetPointerCount(); ++i) {
R_TRY(ProcessSendMessagePointerDescriptors(offset, pointer_key, dst_page_table, src_page_table, dst_msg, src_msg, dst_recv_list, dst_user && dst_header.GetReceiveListCount() == ipc::MessageBuffer::MessageHeader::ReceiveListCountType_ToMessageBuffer));
}
/* Clear any map alias buffers. */
for (auto i = 0; i < src_header.GetMapAliasCount(); ++i) {
offset = dst_msg.Set(offset, ipc::MessageBuffer::MapAliasDescriptor());
}
/* Process any raw data. */
if (const auto raw_count = src_header.GetRawCount(); raw_count != 0) {
/* Get the offset and size. */
const size_t offset_words = offset * sizeof(u32);
const size_t raw_size = raw_count * sizeof(u32);
/* Fast case is TLS -> TLS, do raw memcpy if we can. */
if (!dst_user && !src_user) {
std::memcpy(dst_msg_ptr + offset, src_msg_ptr + offset, raw_size);
} else if (src_user) {
/* Determine how much fast size we can copy. */
const size_t max_fast_size = std::min<size_t>(offset_words + raw_size, PageSize);
const size_t fast_size = max_fast_size - offset_words;
/* Determine the dst permission. User buffer should be unmapped + read, TLS should be user readable. */
const KMemoryPermission dst_perm = static_cast<KMemoryPermission>(dst_user ? KMemoryPermission_NotMapped | KMemoryPermission_KernelReadWrite : KMemoryPermission_UserReadWrite);
/* Perform the fast part of the copy. */
R_TRY(dst_page_table.CopyMemoryFromKernelToLinear(dst_message_buffer + offset_words, fast_size,
KMemoryState_FlagReferenceCounted, KMemoryState_FlagReferenceCounted,
dst_perm,
KMemoryAttribute_Uncached, KMemoryAttribute_None,
reinterpret_cast<uintptr_t>(src_msg_ptr) + offset_words));
/* If the fast part of the copy didn't get everything, perform the slow part of the copy. */
if (fast_size < raw_size) {
R_TRY(src_page_table.CopyMemoryFromLinearToLinear(dst_page_table, dst_message_buffer + max_fast_size, raw_size - fast_size,
KMemoryState_FlagReferenceCounted, KMemoryState_FlagReferenceCounted,
dst_perm,
KMemoryAttribute_Uncached, KMemoryAttribute_None,
src_message_buffer + max_fast_size,
KMemoryState_FlagReferenceCounted, KMemoryState_FlagReferenceCounted,
static_cast<KMemoryPermission>(KMemoryPermission_NotMapped | KMemoryPermission_KernelRead),
KMemoryAttribute_AnyLocked | KMemoryAttribute_Uncached | KMemoryAttribute_Locked, KMemoryAttribute_AnyLocked | KMemoryAttribute_Locked));
}
} else /* if (dst_user) */ {
/* The destination is a user buffer, so it should be unmapped + readable. */
constexpr KMemoryPermission DestinationPermission = static_cast<KMemoryPermission>(KMemoryPermission_NotMapped | KMemoryPermission_KernelReadWrite);
/* Copy the memory. */
R_TRY(src_page_table.CopyMemoryFromUserToLinear(dst_message_buffer + offset_words, raw_size,
KMemoryState_FlagReferenceCounted, KMemoryState_FlagReferenceCounted,
DestinationPermission,
KMemoryAttribute_Uncached, KMemoryAttribute_None,
src_message_buffer + offset_words));
}
}
/* We succeeded. Perform cleanup with validation. */
cleanup_guard.Cancel();
return CleanupMap(request, std::addressof(src_process), std::addressof(dst_page_table));
}
ALWAYS_INLINE void ReplyAsyncError(KProcess *to_process, uintptr_t to_msg_buf, size_t to_msg_buf_size, Result result) {
/* Convert the buffer to a physical address. */
KPhysicalAddress phys_addr;