Scripting: reimplement protocol over plan UDP using boost::asio

This commit is contained in:
Weiyi Wang 2019-01-29 17:28:50 -05:00
parent bad2e084e3
commit d765a73a53
9 changed files with 152 additions and 134 deletions

View file

@ -1,22 +1,22 @@
import zmq
import struct import struct
import random import random
import enum import enum
import socket
CURRENT_REQUEST_VERSION = 1 CURRENT_REQUEST_VERSION = 1
MAX_REQUEST_DATA_SIZE = 32 MAX_REQUEST_DATA_SIZE = 32
MAX_PACKET_SIZE = 48
class RequestType(enum.IntEnum): class RequestType(enum.IntEnum):
ReadMemory = 1, ReadMemory = 1,
WriteMemory = 2 WriteMemory = 2
CITRA_PORT = "45987" CITRA_PORT = 45987
class Citra: class Citra:
def __init__(self, address="127.0.0.1", port=CITRA_PORT): def __init__(self, address="127.0.0.1", port=CITRA_PORT):
self.context = zmq.Context() self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket = self.context.socket(zmq.REQ) self.address = address
self.socket.connect("tcp://" + address + ":" + port)
def is_connected(self): def is_connected(self):
return self.socket is not None return self.socket is not None
@ -45,9 +45,9 @@ class Citra:
request_data = struct.pack("II", read_address, temp_read_size) request_data = struct.pack("II", read_address, temp_read_size)
request, request_id = self._generate_header(RequestType.ReadMemory, len(request_data)) request, request_id = self._generate_header(RequestType.ReadMemory, len(request_data))
request += request_data request += request_data
self.socket.send(request) self.socket.sendto(request, (self.address, CITRA_PORT))
raw_reply = self.socket.recv() raw_reply = self.socket.recv(MAX_PACKET_SIZE)
reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.ReadMemory) reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.ReadMemory)
if reply_data: if reply_data:
@ -77,9 +77,9 @@ class Citra:
request_data += write_contents[:temp_write_size] request_data += write_contents[:temp_write_size]
request, request_id = self._generate_header(RequestType.WriteMemory, len(request_data)) request, request_id = self._generate_header(RequestType.WriteMemory, len(request_data))
request += request_data request += request_data
self.socket.send(request) self.socket.sendto(request, (self.address, CITRA_PORT))
raw_reply = self.socket.recv() raw_reply = self.socket.recv(MAX_PACKET_SIZE)
reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.WriteMemory) reply_data = self._read_and_validate_header(raw_reply, request_id, RequestType.WriteMemory)
if None != reply_data: if None != reply_data:

View file

@ -439,8 +439,8 @@ if (ENABLE_SCRIPTING)
rpc/rpc_server.h rpc/rpc_server.h
rpc/server.cpp rpc/server.cpp
rpc/server.h rpc/server.h
rpc/zmq_server.cpp rpc/udp_server.cpp
rpc/zmq_server.h rpc/udp_server.h
) )
endif() endif()

View file

@ -13,6 +13,9 @@
namespace RPC { namespace RPC {
class Packet;
struct PacketHeader;
class RPCServer { class RPCServer {
public: public:
RPCServer(); RPCServer();

View file

@ -1,27 +1,30 @@
#include <functional> #include <functional>
#include "core/core.h" #include "core/core.h"
#include "core/rpc/packet.h"
#include "core/rpc/rpc_server.h" #include "core/rpc/rpc_server.h"
#include "core/rpc/server.h" #include "core/rpc/server.h"
#include "core/rpc/udp_server.h"
namespace RPC { namespace RPC {
Server::Server(RPCServer& rpc_server) : rpc_server(rpc_server) {} Server::Server(RPCServer& rpc_server) : rpc_server(rpc_server) {}
Server::~Server() = default;
void Server::Start() { void Server::Start() {
const auto callback = [this](std::unique_ptr<RPC::Packet> new_request) { const auto callback = [this](std::unique_ptr<Packet> new_request) {
NewRequestCallback(std::move(new_request)); NewRequestCallback(std::move(new_request));
}; };
try { try {
zmq_server = std::make_unique<ZMQServer>(callback); udp_server = std::make_unique<UDPServer>(callback);
} catch (...) { } catch (...) {
LOG_ERROR(RPC_Server, "Error starting ZeroMQ server"); LOG_ERROR(RPC_Server, "Error starting UDP server");
} }
} }
void Server::Stop() { void Server::Stop() {
zmq_server.reset(); udp_server.reset();
} }
void Server::NewRequestCallback(std::unique_ptr<RPC::Packet> new_request) { void Server::NewRequestCallback(std::unique_ptr<RPC::Packet> new_request) {

View file

@ -4,24 +4,25 @@
#pragma once #pragma once
#include "core/rpc/packet.h" #include <memory>
#include "core/rpc/zmq_server.h"
namespace RPC { namespace RPC {
class RPCServer; class RPCServer;
class ZMQServer; class UDPServer;
class Packet;
class Server { class Server {
public: public:
Server(RPCServer& rpc_server); Server(RPCServer& rpc_server);
~Server();
void Start(); void Start();
void Stop(); void Stop();
void NewRequestCallback(std::unique_ptr<RPC::Packet> new_request); void NewRequestCallback(std::unique_ptr<Packet> new_request);
private: private:
RPCServer& rpc_server; RPCServer& rpc_server;
std::unique_ptr<ZMQServer> zmq_server; std::unique_ptr<UDPServer> udp_server;
}; };
} // namespace RPC } // namespace RPC

100
src/core/rpc/udp_server.cpp Normal file
View file

@ -0,0 +1,100 @@
// Copyright 2019 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <thread>
#include <boost/asio.hpp>
#include "common/common_types.h"
#include "common/logging/log.h"
#include "core/rpc/packet.h"
#include "core/rpc/udp_server.h"
namespace RPC {
class UDPServer::Impl {
public:
explicit Impl(std::function<void(std::unique_ptr<Packet>)> new_request_callback)
// Use a random high port
// TODO: Make configurable or increment port number on failure
: socket(io_context, boost::asio::ip::udp::endpoint(boost::asio::ip::udp::v4(), 45987)),
new_request_callback(std::move(new_request_callback)) {
StartReceive();
worker_thread = std::thread([this] {
io_context.run();
this->new_request_callback(nullptr);
});
}
~Impl() {
io_context.stop();
worker_thread.join();
}
private:
void StartReceive() {
socket.async_receive_from(boost::asio::buffer(request_buffer), remote_endpoint,
[this](const boost::system::error_code& error, std::size_t size) {
HandleReceive(error, size);
});
}
void HandleReceive(const boost::system::error_code& error, std::size_t size) {
if (error) {
LOG_WARNING(RPC_Server, "Failed to receive data on UDP socket: {}", error.message());
} else if (size >= MIN_PACKET_SIZE && size <= MAX_PACKET_SIZE) {
PacketHeader header;
std::memcpy(&header, request_buffer.data(), sizeof(header));
if ((size - MIN_PACKET_SIZE) == header.packet_size) {
u8* data = request_buffer.data() + MIN_PACKET_SIZE;
std::function<void(Packet&)> send_reply_callback =
std::bind(&Impl::SendReply, this, remote_endpoint, std::placeholders::_1);
std::unique_ptr<Packet> new_packet =
std::make_unique<Packet>(header, data, send_reply_callback);
// Send the request to the upper layer for handling
new_request_callback(std::move(new_packet));
}
} else {
LOG_WARNING(RPC_Server, "Received message with wrong size: {}", size);
}
StartReceive();
}
void SendReply(boost::asio::ip::udp::endpoint endpoint, Packet& reply_packet) {
std::vector<u8> reply_buffer(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize());
auto reply_header = reply_packet.GetHeader();
std::memcpy(reply_buffer.data(), &reply_header, sizeof(reply_header));
std::memcpy(reply_buffer.data() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(),
reply_packet.GetPacketDataSize());
boost::system::error_code error;
socket.send_to(boost::asio::buffer(reply_buffer), endpoint, 0, error);
if (error) {
LOG_WARNING(RPC_Server, "Failed to send reply: {}", error.message());
} else {
LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})",
reply_packet.GetVersion(), reply_packet.GetId(),
static_cast<u32>(reply_packet.GetPacketType()),
reply_packet.GetPacketDataSize());
}
}
std::thread worker_thread;
boost::asio::io_context io_context;
boost::asio::ip::udp::socket socket;
std::array<u8, MAX_PACKET_SIZE> request_buffer;
boost::asio::ip::udp::endpoint remote_endpoint;
std::function<void(std::unique_ptr<Packet>)> new_request_callback;
};
UDPServer::UDPServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback)
: impl(std::make_unique<Impl>(new_request_callback)) {}
UDPServer::~UDPServer() = default;
} // namespace RPC

24
src/core/rpc/udp_server.h Normal file
View file

@ -0,0 +1,24 @@
// Copyright 2019 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include <functional>
#include <memory>
namespace RPC {
class Packet;
class UDPServer {
public:
explicit UDPServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback);
~UDPServer();
private:
class Impl;
std::unique_ptr<Impl> impl;
};
} // namespace RPC

View file

@ -1,79 +0,0 @@
#include "common/common_types.h"
#include "core/core.h"
#include "core/rpc/packet.h"
#include "core/rpc/zmq_server.h"
namespace RPC {
ZMQServer::ZMQServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback)
: zmq_context(std::move(std::make_unique<zmq::context_t>(1))),
zmq_socket(std::move(std::make_unique<zmq::socket_t>(*zmq_context, ZMQ_REP))),
new_request_callback(std::move(new_request_callback)) {
// Use a random high port
// TODO: Make configurable or increment port number on failure
zmq_socket->bind("tcp://127.0.0.1:45987");
LOG_INFO(RPC_Server, "ZeroMQ listening on port 45987");
worker_thread = std::thread(&ZMQServer::WorkerLoop, this);
}
ZMQServer::~ZMQServer() {
// Triggering the zmq_context destructor will cancel
// any blocking calls to zmq_socket->recv()
running = false;
zmq_context.reset();
worker_thread.join();
LOG_INFO(RPC_Server, "ZeroMQ stopped");
}
void ZMQServer::WorkerLoop() {
zmq::message_t request;
while (running) {
try {
if (zmq_socket->recv(&request, 0)) {
if (request.size() >= MIN_PACKET_SIZE && request.size() <= MAX_PACKET_SIZE) {
u8* request_buffer = static_cast<u8*>(request.data());
PacketHeader header;
std::memcpy(&header, request_buffer, sizeof(header));
if ((request.size() - MIN_PACKET_SIZE) == header.packet_size) {
u8* data = request_buffer + MIN_PACKET_SIZE;
std::function<void(Packet&)> send_reply_callback =
std::bind(&ZMQServer::SendReply, this, std::placeholders::_1);
std::unique_ptr<Packet> new_packet =
std::make_unique<Packet>(header, data, send_reply_callback);
// Send the request to the upper layer for handling
new_request_callback(std::move(new_packet));
}
}
}
} catch (...) {
LOG_WARNING(RPC_Server, "Failed to receive data on ZeroMQ socket");
}
}
std::unique_ptr<Packet> end_packet = nullptr;
new_request_callback(std::move(end_packet));
// Destroying the socket must be done by this thread.
zmq_socket.reset();
}
void ZMQServer::SendReply(Packet& reply_packet) {
if (running) {
auto reply_buffer =
std::make_unique<u8[]>(MIN_PACKET_SIZE + reply_packet.GetPacketDataSize());
auto reply_header = reply_packet.GetHeader();
std::memcpy(reply_buffer.get(), &reply_header, sizeof(reply_header));
std::memcpy(reply_buffer.get() + (4 * sizeof(u32)), reply_packet.GetPacketData().data(),
reply_packet.GetPacketDataSize());
zmq_socket->send(reply_buffer.get(), MIN_PACKET_SIZE + reply_packet.GetPacketDataSize());
LOG_INFO(RPC_Server, "Sent reply version({}) id=({}) type=({}) size=({})",
reply_packet.GetVersion(), reply_packet.GetId(),
static_cast<u32>(reply_packet.GetPacketType()), reply_packet.GetPacketDataSize());
}
}
}; // namespace RPC

View file

@ -1,34 +0,0 @@
// Copyright 2018 Citra Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include <functional>
#include <thread>
#define ZMQ_STATIC
#include <zmq.hpp>
namespace RPC {
class Packet;
class ZMQServer {
public:
explicit ZMQServer(std::function<void(std::unique_ptr<Packet>)> new_request_callback);
~ZMQServer();
private:
void WorkerLoop();
void SendReply(Packet& request);
std::thread worker_thread;
std::atomic_bool running = true;
std::unique_ptr<zmq::context_t> zmq_context;
std::unique_ptr<zmq::socket_t> zmq_socket;
std::function<void(std::unique_ptr<Packet>)> new_request_callback;
};
} // namespace RPC