2
1
Fork 0
mirror of https://github.com/yuzu-emu/yuzu.git synced 2024-07-04 23:31:19 +01:00

kernel: remove most SessionRequestManager handling from KServerSession

This commit is contained in:
Liam 2022-10-15 21:57:40 -04:00
parent 3efb8eb2dc
commit fca195b4fb
6 changed files with 119 additions and 138 deletions

View file

@ -86,13 +86,13 @@ public:
u32 num_domain_objects{}; u32 num_domain_objects{};
const bool always_move_handles{ const bool always_move_handles{
(static_cast<u32>(flags) & static_cast<u32>(Flags::AlwaysMoveHandles)) != 0}; (static_cast<u32>(flags) & static_cast<u32>(Flags::AlwaysMoveHandles)) != 0};
if (!ctx.Session()->IsDomain() || always_move_handles) { if (!ctx.Session()->GetSessionRequestManager()->IsDomain() || always_move_handles) {
num_handles_to_move = num_objects_to_move; num_handles_to_move = num_objects_to_move;
} else { } else {
num_domain_objects = num_objects_to_move; num_domain_objects = num_objects_to_move;
} }
if (ctx.Session()->IsDomain()) { if (ctx.Session()->GetSessionRequestManager()->IsDomain()) {
raw_data_size += raw_data_size +=
static_cast<u32>(sizeof(DomainMessageHeader) / sizeof(u32) + num_domain_objects); static_cast<u32>(sizeof(DomainMessageHeader) / sizeof(u32) + num_domain_objects);
ctx.write_size += num_domain_objects; ctx.write_size += num_domain_objects;
@ -125,7 +125,8 @@ public:
if (!ctx.IsTipc()) { if (!ctx.IsTipc()) {
AlignWithPadding(); AlignWithPadding();
if (ctx.Session()->IsDomain() && ctx.HasDomainMessageHeader()) { if (ctx.Session()->GetSessionRequestManager()->IsDomain() &&
ctx.HasDomainMessageHeader()) {
IPC::DomainMessageHeader domain_header{}; IPC::DomainMessageHeader domain_header{};
domain_header.num_objects = num_domain_objects; domain_header.num_objects = num_domain_objects;
PushRaw(domain_header); PushRaw(domain_header);
@ -145,7 +146,7 @@ public:
template <class T> template <class T>
void PushIpcInterface(std::shared_ptr<T> iface) { void PushIpcInterface(std::shared_ptr<T> iface) {
if (context->Session()->IsDomain()) { if (context->Session()->GetSessionRequestManager()->IsDomain()) {
context->AddDomainObject(std::move(iface)); context->AddDomainObject(std::move(iface));
} else { } else {
kernel.CurrentProcess()->GetResourceLimit()->Reserve( kernel.CurrentProcess()->GetResourceLimit()->Reserve(
@ -386,7 +387,7 @@ public:
template <class T> template <class T>
std::weak_ptr<T> PopIpcInterface() { std::weak_ptr<T> PopIpcInterface() {
ASSERT(context->Session()->IsDomain()); ASSERT(context->Session()->GetSessionRequestManager()->IsDomain());
ASSERT(context->GetDomainMessageHeader().input_object_count > 0); ASSERT(context->GetDomainMessageHeader().input_object_count > 0);
return context->GetDomainHandler<T>(Pop<u32>() - 1); return context->GetDomainHandler<T>(Pop<u32>() - 1);
} }

View file

@ -19,6 +19,7 @@
#include "core/hle/kernel/k_server_session.h" #include "core/hle/kernel/k_server_session.h"
#include "core/hle/kernel/k_thread.h" #include "core/hle/kernel/k_thread.h"
#include "core/hle/kernel/kernel.h" #include "core/hle/kernel/kernel.h"
#include "core/hle/kernel/service_thread.h"
#include "core/memory.h" #include "core/memory.h"
namespace Kernel { namespace Kernel {
@ -56,16 +57,103 @@ bool SessionRequestManager::HasSessionRequestHandler(const HLERequestContext& co
} }
} }
Result SessionRequestManager::CompleteSyncRequest(KServerSession* server_session,
HLERequestContext& context) {
Result result = ResultSuccess;
// If the session has been converted to a domain, handle the domain request
if (this->HasSessionRequestHandler(context)) {
if (IsDomain() && context.HasDomainMessageHeader()) {
result = HandleDomainSyncRequest(server_session, context);
// If there is no domain header, the regular session handler is used
} else if (this->HasSessionHandler()) {
// If this manager has an associated HLE handler, forward the request to it.
result = this->SessionHandler().HandleSyncRequest(*server_session, context);
}
} else {
ASSERT_MSG(false, "Session handler is invalid, stubbing response!");
IPC::ResponseBuilder rb(context, 2);
rb.Push(ResultSuccess);
}
if (convert_to_domain) {
ASSERT_MSG(!IsDomain(), "ServerSession is already a domain instance.");
this->ConvertToDomain();
convert_to_domain = false;
}
return result;
}
Result SessionRequestManager::HandleDomainSyncRequest(KServerSession* server_session,
HLERequestContext& context) {
if (!context.HasDomainMessageHeader()) {
return ResultSuccess;
}
// Set domain handlers in HLE context, used for domain objects (IPC interfaces) as inputs
context.SetSessionRequestManager(server_session->GetSessionRequestManager());
// If there is a DomainMessageHeader, then this is CommandType "Request"
const auto& domain_message_header = context.GetDomainMessageHeader();
const u32 object_id{domain_message_header.object_id};
switch (domain_message_header.command) {
case IPC::DomainMessageHeader::CommandType::SendMessage:
if (object_id > this->DomainHandlerCount()) {
LOG_CRITICAL(IPC,
"object_id {} is too big! This probably means a recent service call "
"needed to return a new interface!",
object_id);
ASSERT(false);
return ResultSuccess; // Ignore error if asserts are off
}
if (auto strong_ptr = this->DomainHandler(object_id - 1).lock()) {
return strong_ptr->HandleSyncRequest(*server_session, context);
} else {
ASSERT(false);
return ResultSuccess;
}
case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: {
LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id);
this->CloseDomainHandler(object_id - 1);
IPC::ResponseBuilder rb{context, 2};
rb.Push(ResultSuccess);
return ResultSuccess;
}
}
LOG_CRITICAL(IPC, "Unknown domain command={}", domain_message_header.command.Value());
ASSERT(false);
return ResultSuccess;
}
Result SessionRequestManager::QueueSyncRequest(KSession* parent,
std::shared_ptr<HLERequestContext>&& context) {
// Ensure we have a session request handler
if (this->HasSessionRequestHandler(*context)) {
if (auto strong_ptr = this->GetServiceThread().lock()) {
strong_ptr->QueueSyncRequest(*parent, std::move(context));
} else {
ASSERT_MSG(false, "strong_ptr is nullptr!");
}
} else {
ASSERT_MSG(false, "handler is invalid!");
}
return ResultSuccess;
}
void SessionRequestHandler::ClientConnected(KServerSession* session) { void SessionRequestHandler::ClientConnected(KServerSession* session) {
session->ClientConnected(shared_from_this()); session->GetSessionRequestManager()->SetSessionHandler(shared_from_this());
// Ensure our server session is tracked globally. // Ensure our server session is tracked globally.
kernel.RegisterServerObject(session); kernel.RegisterServerObject(session);
} }
void SessionRequestHandler::ClientDisconnected(KServerSession* session) { void SessionRequestHandler::ClientDisconnected(KServerSession* session) {}
session->ClientDisconnected();
}
HLERequestContext::HLERequestContext(KernelCore& kernel_, Core::Memory::Memory& memory_, HLERequestContext::HLERequestContext(KernelCore& kernel_, Core::Memory::Memory& memory_,
KServerSession* server_session_, KThread* thread_) KServerSession* server_session_, KThread* thread_)
@ -126,7 +214,7 @@ void HLERequestContext::ParseCommandBuffer(const KHandleTable& handle_table, u32
// Padding to align to 16 bytes // Padding to align to 16 bytes
rp.AlignWithPadding(); rp.AlignWithPadding();
if (Session()->IsDomain() && if (Session()->GetSessionRequestManager()->IsDomain() &&
((command_header->type == IPC::CommandType::Request || ((command_header->type == IPC::CommandType::Request ||
command_header->type == IPC::CommandType::RequestWithContext) || command_header->type == IPC::CommandType::RequestWithContext) ||
!incoming)) { !incoming)) {
@ -135,7 +223,7 @@ void HLERequestContext::ParseCommandBuffer(const KHandleTable& handle_table, u32
if (incoming || domain_message_header) { if (incoming || domain_message_header) {
domain_message_header = rp.PopRaw<IPC::DomainMessageHeader>(); domain_message_header = rp.PopRaw<IPC::DomainMessageHeader>();
} else { } else {
if (Session()->IsDomain()) { if (Session()->GetSessionRequestManager()->IsDomain()) {
LOG_WARNING(IPC, "Domain request has no DomainMessageHeader!"); LOG_WARNING(IPC, "Domain request has no DomainMessageHeader!");
} }
} }
@ -228,12 +316,12 @@ Result HLERequestContext::WriteToOutgoingCommandBuffer(KThread& requesting_threa
// Write the domain objects to the command buffer, these go after the raw untranslated data. // Write the domain objects to the command buffer, these go after the raw untranslated data.
// TODO(Subv): This completely ignores C buffers. // TODO(Subv): This completely ignores C buffers.
if (Session()->IsDomain()) { if (server_session->GetSessionRequestManager()->IsDomain()) {
current_offset = domain_offset - static_cast<u32>(outgoing_domain_objects.size()); current_offset = domain_offset - static_cast<u32>(outgoing_domain_objects.size());
for (const auto& object : outgoing_domain_objects) { for (auto& object : outgoing_domain_objects) {
server_session->AppendDomainHandler(object); server_session->GetSessionRequestManager()->AppendDomainHandler(std::move(object));
cmd_buf[current_offset++] = cmd_buf[current_offset++] = static_cast<u32_le>(
static_cast<u32_le>(server_session->NumDomainRequestHandlers()); server_session->GetSessionRequestManager()->DomainHandlerCount());
} }
} }

View file

@ -121,6 +121,10 @@ public:
is_domain = true; is_domain = true;
} }
void ConvertToDomainOnRequestEnd() {
convert_to_domain = true;
}
std::size_t DomainHandlerCount() const { std::size_t DomainHandlerCount() const {
return domain_handlers.size(); return domain_handlers.size();
} }
@ -164,7 +168,12 @@ public:
bool HasSessionRequestHandler(const HLERequestContext& context) const; bool HasSessionRequestHandler(const HLERequestContext& context) const;
Result HandleDomainSyncRequest(KServerSession* server_session, HLERequestContext& context);
Result CompleteSyncRequest(KServerSession* server_session, HLERequestContext& context);
Result QueueSyncRequest(KSession* parent, std::shared_ptr<HLERequestContext>&& context);
private: private:
bool convert_to_domain{};
bool is_domain{}; bool is_domain{};
SessionRequestHandlerPtr session_handler; SessionRequestHandlerPtr session_handler;
std::vector<SessionRequestHandlerPtr> domain_handlers; std::vector<SessionRequestHandlerPtr> domain_handlers;

View file

@ -22,7 +22,6 @@
#include "core/hle/kernel/k_thread.h" #include "core/hle/kernel/k_thread.h"
#include "core/hle/kernel/k_thread_queue.h" #include "core/hle/kernel/k_thread_queue.h"
#include "core/hle/kernel/kernel.h" #include "core/hle/kernel/kernel.h"
#include "core/hle/kernel/service_thread.h"
#include "core/memory.h" #include "core/memory.h"
namespace Kernel { namespace Kernel {
@ -74,101 +73,17 @@ bool KServerSession::IsSignaled() const {
return !m_request_list.empty() && m_current_request == nullptr; return !m_request_list.empty() && m_current_request == nullptr;
} }
void KServerSession::AppendDomainHandler(SessionRequestHandlerPtr handler) {
manager->AppendDomainHandler(std::move(handler));
}
std::size_t KServerSession::NumDomainRequestHandlers() const {
return manager->DomainHandlerCount();
}
Result KServerSession::HandleDomainSyncRequest(Kernel::HLERequestContext& context) {
if (!context.HasDomainMessageHeader()) {
return ResultSuccess;
}
// Set domain handlers in HLE context, used for domain objects (IPC interfaces) as inputs
context.SetSessionRequestManager(manager);
// If there is a DomainMessageHeader, then this is CommandType "Request"
const auto& domain_message_header = context.GetDomainMessageHeader();
const u32 object_id{domain_message_header.object_id};
switch (domain_message_header.command) {
case IPC::DomainMessageHeader::CommandType::SendMessage:
if (object_id > manager->DomainHandlerCount()) {
LOG_CRITICAL(IPC,
"object_id {} is too big! This probably means a recent service call "
"to {} needed to return a new interface!",
object_id, name);
ASSERT(false);
return ResultSuccess; // Ignore error if asserts are off
}
if (auto strong_ptr = manager->DomainHandler(object_id - 1).lock()) {
return strong_ptr->HandleSyncRequest(*this, context);
} else {
ASSERT(false);
return ResultSuccess;
}
case IPC::DomainMessageHeader::CommandType::CloseVirtualHandle: {
LOG_DEBUG(IPC, "CloseVirtualHandle, object_id=0x{:08X}", object_id);
manager->CloseDomainHandler(object_id - 1);
IPC::ResponseBuilder rb{context, 2};
rb.Push(ResultSuccess);
return ResultSuccess;
}
}
LOG_CRITICAL(IPC, "Unknown domain command={}", domain_message_header.command.Value());
ASSERT(false);
return ResultSuccess;
}
Result KServerSession::QueueSyncRequest(KThread* thread, Core::Memory::Memory& memory) { Result KServerSession::QueueSyncRequest(KThread* thread, Core::Memory::Memory& memory) {
u32* cmd_buf{reinterpret_cast<u32*>(memory.GetPointer(thread->GetTLSAddress()))}; u32* cmd_buf{reinterpret_cast<u32*>(memory.GetPointer(thread->GetTLSAddress()))};
auto context = std::make_shared<HLERequestContext>(kernel, memory, this, thread); auto context = std::make_shared<HLERequestContext>(kernel, memory, this, thread);
context->PopulateFromIncomingCommandBuffer(kernel.CurrentProcess()->GetHandleTable(), cmd_buf); context->PopulateFromIncomingCommandBuffer(kernel.CurrentProcess()->GetHandleTable(), cmd_buf);
// Ensure we have a session request handler return manager->QueueSyncRequest(parent, std::move(context));
if (manager->HasSessionRequestHandler(*context)) {
if (auto strong_ptr = manager->GetServiceThread().lock()) {
strong_ptr->QueueSyncRequest(*parent, std::move(context));
} else {
ASSERT_MSG(false, "strong_ptr is nullptr!");
}
} else {
ASSERT_MSG(false, "handler is invalid!");
}
return ResultSuccess;
} }
Result KServerSession::CompleteSyncRequest(HLERequestContext& context) { Result KServerSession::CompleteSyncRequest(HLERequestContext& context) {
Result result = ResultSuccess; Result result = manager->CompleteSyncRequest(this, context);
// If the session has been converted to a domain, handle the domain request
if (manager->HasSessionRequestHandler(context)) {
if (IsDomain() && context.HasDomainMessageHeader()) {
result = HandleDomainSyncRequest(context);
// If there is no domain header, the regular session handler is used
} else if (manager->HasSessionHandler()) {
// If this ServerSession has an associated HLE handler, forward the request to it.
result = manager->SessionHandler().HandleSyncRequest(*this, context);
}
} else {
ASSERT_MSG(false, "Session handler is invalid, stubbing response!");
IPC::ResponseBuilder rb(context, 2);
rb.Push(ResultSuccess);
}
if (convert_to_domain) {
ASSERT_MSG(!IsDomain(), "ServerSession is already a domain instance.");
manager->ConvertToDomain();
convert_to_domain = false;
}
// The calling thread is waiting for this request to complete, so wake it up. // The calling thread is waiting for this request to complete, so wake it up.
context.GetThread().EndWait(result); context.GetThread().EndWait(result);

View file

@ -58,37 +58,8 @@ public:
} }
bool IsSignaled() const override; bool IsSignaled() const override;
void OnClientClosed(); void OnClientClosed();
void ClientConnected(SessionRequestHandlerPtr handler) {
if (manager) {
manager->SetSessionHandler(std::move(handler));
}
}
void ClientDisconnected() {
manager = nullptr;
}
/// Adds a new domain request handler to the collection of request handlers within
/// this ServerSession instance.
void AppendDomainHandler(SessionRequestHandlerPtr handler);
/// Retrieves the total number of domain request handlers that have been
/// appended to this ServerSession instance.
std::size_t NumDomainRequestHandlers() const;
/// Returns true if the session has been converted to a domain, otherwise False
bool IsDomain() const {
return manager && manager->IsDomain();
}
/// Converts the session to a domain at the end of the current command
void ConvertToDomain() {
convert_to_domain = true;
}
/// Gets the session request manager, which forwards requests to the underlying service /// Gets the session request manager, which forwards requests to the underlying service
std::shared_ptr<SessionRequestManager>& GetSessionRequestManager() { std::shared_ptr<SessionRequestManager>& GetSessionRequestManager() {
return manager; return manager;
@ -109,10 +80,6 @@ private:
/// Completes a sync request from the emulated application. /// Completes a sync request from the emulated application.
Result CompleteSyncRequest(HLERequestContext& context); Result CompleteSyncRequest(HLERequestContext& context);
/// Handles a SyncRequest to a domain, forwarding the request to the proper object or closing an
/// object handle.
Result HandleDomainSyncRequest(Kernel::HLERequestContext& context);
/// This session's HLE request handlers; if nullptr, this is not an HLE server /// This session's HLE request handlers; if nullptr, this is not an HLE server
std::shared_ptr<SessionRequestManager> manager; std::shared_ptr<SessionRequestManager> manager;

View file

@ -15,9 +15,10 @@
namespace Service::SM { namespace Service::SM {
void Controller::ConvertCurrentObjectToDomain(Kernel::HLERequestContext& ctx) { void Controller::ConvertCurrentObjectToDomain(Kernel::HLERequestContext& ctx) {
ASSERT_MSG(!ctx.Session()->IsDomain(), "Session is already a domain"); ASSERT_MSG(!ctx.Session()->GetSessionRequestManager()->IsDomain(),
"Session is already a domain");
LOG_DEBUG(Service, "called, server_session={}", ctx.Session()->GetId()); LOG_DEBUG(Service, "called, server_session={}", ctx.Session()->GetId());
ctx.Session()->ConvertToDomain(); ctx.Session()->GetSessionRequestManager()->ConvertToDomainOnRequestEnd();
IPC::ResponseBuilder rb{ctx, 3}; IPC::ResponseBuilder rb{ctx, 3};
rb.Push(ResultSuccess); rb.Push(ResultSuccess);