From c0d5140ef0548416fc1f0a52ea3af0743f130145 Mon Sep 17 00:00:00 2001 From: Michael Scire Date: Sun, 27 Mar 2022 14:36:31 -0700 Subject: [PATCH] strat: add windows socket api, linux/macos TODO --- .../socket_platform_types_translation.hpp | 99 +++ .../stratosphere/socket/socket_api.hpp | 2 + .../stratosphere/socket/socket_constants.hpp | 2 + .../stratosphere/socket/socket_errno.hpp | 150 +++- .../stratosphere/socket/socket_options.hpp | 86 +- .../stratosphere/socket/socket_types.hpp | 35 +- .../htclow_socket_discovery_manager.cpp | 6 +- .../htclow/driver/htclow_socket_driver.cpp | 6 +- .../source/socket/impl/socket_api.hpp | 2 + .../socket/impl/socket_api.os.horizon.cpp | 22 + .../socket/impl/socket_api.os.windows.cpp | 741 +++++++++++++++++ ..._platform_types_translation.os.windows.cpp | 776 ++++++++++++++++++ .../source/socket/socket_api.cpp | 4 + .../libstratosphere/source/time/time_api.cpp | 4 +- tests/TestSocket/Makefile | 51 ++ tests/TestSocket/source/test.cpp | 145 ++++ tests/TestSocket/unit_test.mk | 155 ++++ 17 files changed, 2258 insertions(+), 28 deletions(-) create mode 100644 libraries/libstratosphere/include/stratosphere/socket/impl/socket_platform_types_translation.hpp create mode 100644 libraries/libstratosphere/source/socket/impl/socket_api.os.windows.cpp create mode 100644 libraries/libstratosphere/source/socket/impl/socket_platform_types_translation.os.windows.cpp create mode 100644 tests/TestSocket/Makefile create mode 100644 tests/TestSocket/source/test.cpp create mode 100644 tests/TestSocket/unit_test.mk diff --git a/libraries/libstratosphere/include/stratosphere/socket/impl/socket_platform_types_translation.hpp b/libraries/libstratosphere/include/stratosphere/socket/impl/socket_platform_types_translation.hpp new file mode 100644 index 000000000..1586af239 --- /dev/null +++ b/libraries/libstratosphere/include/stratosphere/socket/impl/socket_platform_types_translation.hpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#pragma once +#include +#include +#include +#include +#include + +namespace ams::socket::impl { + + #if defined(ATMOSPHERE_OS_WINDOWS) + class PosixWinSockConverter { + private: + struct SocketData { + SOCKET winsock; + bool exempt; + bool shutdown; + + constexpr SocketData() : winsock(static_cast(INVALID_SOCKET)), exempt(), shutdown() { /* ... */ } + }; + private: + os::SdkMutex m_mutex{}; + SocketData m_data[MaxSocketsPerClient]{}; + private: + static constexpr int GetInitialIndex(SOCKET winsock) { + /* The lower 2 bits of a winsock are always zero; Nintendo uses the upper bits as a hashmap index into m_data. */ + return (winsock >> 2) % MaxSocketsPerClient; + } + public: + constexpr PosixWinSockConverter() = default; + + s32 AcquirePosixHandle(SOCKET winsock, bool exempt = false); + s32 GetShutdown(bool &shutdown, s32 posix); + s32 GetSocketExempt(bool &exempt, s32 posix); + SOCKET PosixToWinsockSocket(s32 posix); + void ReleaseAllPosixHandles(); + void ReleasePosixHandle(s32 posix); + s32 SetShutdown(s32 posix, bool shutdown); + s32 SetSocketExempt(s32 posix, bool exempt); + s32 WinsockToPosixSocket(SOCKET winsock); + }; + + s32 MapProtocolValue(Protocol protocol); + Protocol MapProtocolValue(s32 protocol); + + s32 MapTypeValue(Type type); + Type MapTypeValue(s32 type); + + s8 MapFamilyValue(Family family); + Family MapFamilyValue(s8 family); + + s32 MapMsgFlagValue(MsgFlag flag); + MsgFlag MapMsgFlagValue(s32 flag); + + u32 MapAddrInfoFlagValue(AddrInfoFlag flag); + AddrInfoFlag MapAddrInfoFlagValue(u32 flag); + + u32 MapShutdownMethodValue(ShutdownMethod how); + ShutdownMethod MapShutdownMethodValue(u32 how); + + u32 MapFcntlFlagValue(FcntlFlag flag); + FcntlFlag MapFcntlFlagValue(u32 flag); + + s32 MapLevelValue(Level level); + Level MapLevelValue(s32 level); + + s32 MapOptionValue(Level level, Option option); + Option MapOptionValue(s32 level, s32 option); + + s32 MapErrnoValue(Errno error); + Errno MapErrnoValue(s32 error); + + #endif + + #define AMS_SOCKET_IMPL_DECLARE_CONVERSION(AMS, PLATFORM) \ + void CopyToPlatform(PLATFORM *dst, const AMS *src); \ + void CopyFromPlatform(AMS *dst, const PLATFORM *src); + + AMS_SOCKET_IMPL_DECLARE_CONVERSION(SockAddrIn, sockaddr_in); + AMS_SOCKET_IMPL_DECLARE_CONVERSION(TimeVal, timeval); + AMS_SOCKET_IMPL_DECLARE_CONVERSION(Linger, linger); + + #undef AMS_SOCKET_IMPL_DECLARE_CONVERSION + +} diff --git a/libraries/libstratosphere/include/stratosphere/socket/socket_api.hpp b/libraries/libstratosphere/include/stratosphere/socket/socket_api.hpp index 07acdf968..4523842ed 100644 --- a/libraries/libstratosphere/include/stratosphere/socket/socket_api.hpp +++ b/libraries/libstratosphere/include/stratosphere/socket/socket_api.hpp @@ -49,6 +49,8 @@ namespace ams::socket { s32 Accept(s32 desc, SockAddr *out_address, SockLenT *out_addr_len); s32 Bind(s32 desc, const SockAddr *address, SockLenT len); + s32 Connect(s32 desc, const SockAddr *address, SockLenT len); + s32 GetSockName(s32 desc, SockAddr *out_address, SockLenT *out_addr_len); s32 SetSockOpt(s32 desc, Level level, Option option_name, const void *option_value, SockLenT option_size); diff --git a/libraries/libstratosphere/include/stratosphere/socket/socket_constants.hpp b/libraries/libstratosphere/include/stratosphere/socket/socket_constants.hpp index f7ac4fdd5..c742dede5 100644 --- a/libraries/libstratosphere/include/stratosphere/socket/socket_constants.hpp +++ b/libraries/libstratosphere/include/stratosphere/socket/socket_constants.hpp @@ -21,6 +21,8 @@ namespace ams::socket { constexpr inline s32 InvalidSocket = -1; constexpr inline s32 SocketError = -1; + constexpr inline u32 MaxSocketsPerClient = 0x80; + constexpr inline auto DefaultTcpAutoBufferSizeMax = 192_KB; constexpr inline auto MinTransferMemorySize = (2 * DefaultTcpAutoBufferSizeMax + 128_KB); constexpr inline auto MinSocketAllocatorSize = 128_KB; diff --git a/libraries/libstratosphere/include/stratosphere/socket/socket_errno.hpp b/libraries/libstratosphere/include/stratosphere/socket/socket_errno.hpp index f4bb45198..5a1036e90 100644 --- a/libraries/libstratosphere/include/stratosphere/socket/socket_errno.hpp +++ b/libraries/libstratosphere/include/stratosphere/socket/socket_errno.hpp @@ -19,21 +19,143 @@ namespace ams::socket { enum class Errno : u32 { - ESuccess = 0, + ESuccess = 0, + EPerm = 1, + ENoEnt = 2, + ESrch = 3, + EIntr = 4, + EIo = 5, + ENxIo = 6, + E2Big = 7, + ENoExec = 8, + EBadf = 9, + EChild = 10, + EAgain = 11, + EWouldBlock = EAgain, + ENoMem = 12, + EAcces = 13, + EFault = 14, + ENotBlk = 15, + EBusy = 16, + EExist = 17, + EXDev = 18, + ENoDev = 19, + ENotDir = 20, + EIsDir = 21, + EInval = 22, + ENFile = 23, + EMFile = 24, + ENotTy = 25, + ETxtBsy = 26, + EFBig = 27, + ENoSpc = 28, + ESPipe = 29, + ERofs = 30, + EMLink = 31, + EPipe = 32, + EDom = 33, + ERange = 34, + EDeadLk = 35, + EDeadLock = EDeadLk, + ENameTooLong = 36, + ENoLck = 37, + ENoSys = 38, + ENotEmpty = 39, + ELoop = 40, + ENoMsg = 42, + EIdrm = 43, + EChrng = 44, + EL2NSync = 45, + EL3Hlt = 46, + EL3Rst = 47, + ELnrng = 48, + EUnatch = 49, + ENoCsi = 50, + EL2Hlt = 51, + EBade = 52, + EBadr = 53, + EXFull = 54, + ENoAno = 55, + EBadRqc = 56, + EBadSsl = 57, + EBFont = 59, + ENoStr = 60, + ENoData = 61, + ETime = 62, + ENoSr = 63, + ENoNet = 64, + ENoPkg = 65, + ERemote = 66, + ENoLink = 67, + EAdv = 68, + ESrmnt = 69, + EComm = 70, + EProto = 71, + EMultiHop = 72, + EDotDot = 73, + EBadMsg = 74, + EOverflow = 75, + ENotUnuq = 76, + EBadFd = 77, + ERemChg = 78, + ELibAcc = 79, + ELibBad = 80, + ELibScn = 81, + ELibMax = 82, + ELibExec = 83, + EIlSeq = 84, + ERestart = 85, + EStrPipe = 86, + EUsers = 87, + ENotSock = 88, + EDestAddrReq = 89, + EMsgSize = 90, + EPrototype = 91, + ENoProtoOpt = 92, + EProtoNoSupport = 93, + ESocktNoSupport = 94, + EOpNotSupp = 95, + ENotSup = EOpNotSupp, + EPfNoSupport = 96, + EAfNoSupport = 97, + EAddrInUse = 98, + EAddrNotAvail = 99, + ENetDown = 100, + ENetUnreach = 101, + ENetReset = 102, + EConnAborted = 103, + EConnReset = 104, + ENoBufs = 105, + EIsConn = 106, + ENotConn = 107, + EShutDown = 108, + ETooManyRefs = 109, + ETimedOut = 110, + EConnRefused = 111, + EHostDown = 112, + EHostUnreach = 113, + EAlready = 114, + EInProgress = 115, + EStale = 116, + EUClean = 117, + ENotNam = 118, + ENAvail = 119, + EIsNam = 120, + ERemoteIo = 121, + EDQuot = 122, + ENoMedium = 123, + EMediumType = 124, + ECanceled = 125, + ENoKey = 126, + EKeyExpired = 127, + EKeyRevoked = 128, + EKeyRejected = 129, + EOwnerDead = 130, + ENotRecoverable = 131, + ERfKill = 132, + EHwPoison = 133, /* ... */ - EAgain = 11, - ENoMem = 12, - /* ... */ - EFault = 14, - /* ... */ - EInval = 22, - /* ... */ - ENoSpc = 28, - /* ... */ - EL3Hlt = 46, - /* ... */ - EOpNotSupp = 95, - ENotSup = EOpNotSupp, + EProcLim = 156, }; enum class HErrno : s32 { diff --git a/libraries/libstratosphere/include/stratosphere/socket/socket_options.hpp b/libraries/libstratosphere/include/stratosphere/socket/socket_options.hpp index fd4780100..485fdbd66 100644 --- a/libraries/libstratosphere/include/stratosphere/socket/socket_options.hpp +++ b/libraries/libstratosphere/include/stratosphere/socket/socket_options.hpp @@ -29,10 +29,88 @@ namespace ams::socket { }; enum class Option : u32 { - So_Debug = (1 << 0), - /* ... */ - So_ReuseAddr = (1 << 2), - /* ... */ + /* ==================================== */ + So_Debug = (1 << 0), + So_AcceptConn = (1 << 1), + So_ReuseAddr = (1 << 2), + So_KeepAlive = (1 << 3), + So_DontRoute = (1 << 4), + So_Broadcast = (1 << 5), + So_UseLoopback = (1 << 6), + So_Linger = (1 << 7), + So_OobInline = (1 << 8), + So_ReusePort = (1 << 9), + + So_SndBuf = (1 << 12) | 0x01, + So_RcvBuf = (1 << 12) | 0x02, + So_SndLoWat = (1 << 12) | 0x03, + So_RcvLoWat = (1 << 12) | 0x04, + So_SndTimeo = (1 << 12) | 0x05, + So_RcvTimeo = (1 << 12) | 0x06, + So_Error = (1 << 12) | 0x07, + So_Type = (1 << 12) | 0x08, + So_Label = (1 << 12) | 0x09, + So_PeerLabel = (1 << 12) | 0x10, + So_ListenQLimit = (1 << 12) | 0x11, + So_ListenQLen = (1 << 12) | 0x12, + So_ListenIncQLen = (1 << 12) | 0x13, + So_SetFib = (1 << 12) | 0x14, + So_User_Cookie = (1 << 12) | 0x15, + So_Protocol = (1 << 12) | 0x16, + + So_Nn_Shutdown_Exempt = (1 << 16), + + So_Vendor = (1u << 31), + So_Nn_Linger = So_Vendor | 0x01, + /* ==================================== */ + + /* ==================================== */ + Ip_Options = 1, + Ip_HdrIncl = 2, + Ip_Tos = 3, + Ip_Ttl = 4, + Ip_RecvOpts = 5, + Ip_Multicast_If = 9, + Ip_Multicast_Ttl = 10, + Ip_Multicast_Loop = 11, + Ip_Add_Membership = 12, + Ip_Drop_Membership = 13, + Ip_Multicast_Vif = 14, + Ip_Rsvp_On = 15, + Ip_Rsvp_Off = 16, + Ip_Rsvp_Vif_On = 17, + Ip_Rsvp_Vif_Off = 18, + Ip_PortRange = 19, + Ip_Faith = 22, + Ip_OnesBcast = 23, + Ip_BindAny = 24, + + Ip_RecvTtl = 65, + Ip_MinTtl = 66, + Ip_DontFrag = 67, + Ip_RecvTos = 68, + + Ip_Add_Source_Membership = 70, + Ip_Drop_Source_Membership = 71, + Ip_Block_Source = 72, + Ip_Unblock_Source = 73, + /* ==================================== */ + + /* ==================================== */ + Tcp_NoDelay = (1 << 0), + Tcp_MaxSeg = (1 << 1), + Tcp_NoPush = (1 << 2), + Tcp_NoOpt = (1 << 3), + Tcp_Md5Sig = (1 << 4), + Tcp_Info = (1 << 5), + Tcp_Congestion = (1 << 6), + Tcp_KeepInit = (1 << 7), + Tcp_KeepIdle = (1 << 8), + Tcp_KeepIntvl = (1 << 9), + Tcp_KeepCnt = (1 << 10), + + Tcp_Vendor = So_Vendor, + /* ==================================== */ }; } diff --git a/libraries/libstratosphere/include/stratosphere/socket/socket_types.hpp b/libraries/libstratosphere/include/stratosphere/socket/socket_types.hpp index 367c5f14e..8547cf821 100644 --- a/libraries/libstratosphere/include/stratosphere/socket/socket_types.hpp +++ b/libraries/libstratosphere/include/stratosphere/socket/socket_types.hpp @@ -42,6 +42,8 @@ namespace ams::socket { IpProto_Udp = 17, + IpProto_None = 59, + IpProto_UdpLite = 136, IpProto_Raw = 255, @@ -80,12 +82,29 @@ namespace ams::socket { }; enum class MsgFlag : s32 { - MsgFlag_None = (0 << 0), + Msg_None = (0 << 0), + + Msg_Oob = (1 << 0), + Msg_Peek = (1 << 1), + Msg_DontRoute = (1 << 2), /* ... */ - MsgFlag_WaitAll = (1 << 6), + Msg_Trunc = (1 << 4), + Msg_CTrunc = (1 << 5), + Msg_WaitAll = (1 << 6), + Msg_DontWait = (1 << 7), /* ... */ }; + enum class FcntlCommand : u32 { + F_GetFl = 3, + F_SetFl = 4, + }; + + enum class FcntlFlag : u32 { + None = (0 << 0), + O_NonBlock = (1 << 11), + }; + enum class ShutdownMethod : u32 { Shut_Rd = 0, Shut_Wr = 1, @@ -140,6 +159,16 @@ namespace ams::socket { AddrInfo *ai_next; }; + struct TimeVal { + long tv_sec; + long tv_usec; + }; + + struct Linger { + int l_onoff; + int l_linger; + }; + #define AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS(__ENUM__) \ constexpr inline __ENUM__ operator | (__ENUM__ lhs, __ENUM__ rhs) { return static_cast<__ENUM__>(static_cast>(lhs) | static_cast>(rhs)); } \ constexpr inline __ENUM__ operator |=(__ENUM__ &lhs, __ENUM__ rhs) { return lhs = lhs | rhs; } \ @@ -151,6 +180,8 @@ namespace ams::socket { AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS(Type) AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS(AddrInfoFlag) + AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS(MsgFlag) + AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS(FcntlFlag) #undef AMS_SOCKET_IMPL_DEFINE_ENUM_OPERATORS diff --git a/libraries/libstratosphere/source/htclow/driver/htclow_socket_discovery_manager.cpp b/libraries/libstratosphere/source/htclow/driver/htclow_socket_discovery_manager.cpp index 5cb90ee7a..f864efc1d 100644 --- a/libraries/libstratosphere/source/htclow/driver/htclow_socket_discovery_manager.cpp +++ b/libraries/libstratosphere/source/htclow/driver/htclow_socket_discovery_manager.cpp @@ -100,7 +100,7 @@ namespace ams::htclow::driver { TmipcHeader header; socket::SockAddr recv_sockaddr; socket::SockLenT recv_sockaddr_len = sizeof(recv_sockaddr); - const auto recv_res = socket::RecvFrom(m_socket, std::addressof(header), sizeof(header), socket::MsgFlag::MsgFlag_None, std::addressof(recv_sockaddr), std::addressof(recv_sockaddr_len)); + const auto recv_res = socket::RecvFrom(m_socket, std::addressof(header), sizeof(header), socket::MsgFlag::Msg_None, std::addressof(recv_sockaddr), std::addressof(recv_sockaddr_len)); /* Check that our receive was valid. */ R_UNLESS(recv_res >= 0, htclow::ResultSocketReceiveFromError()); @@ -126,7 +126,7 @@ namespace ams::htclow::driver { } if (header.data_len > 0) { - const auto body_res = socket::RecvFrom(m_socket, packet_data, header.data_len, socket::MsgFlag::MsgFlag_None, std::addressof(recv_sockaddr), std::addressof(recv_sockaddr_len)); + const auto body_res = socket::RecvFrom(m_socket, packet_data, header.data_len, socket::MsgFlag::Msg_None, std::addressof(recv_sockaddr), std::addressof(recv_sockaddr_len)); R_UNLESS(body_res >= 0, htclow::ResultSocketReceiveFromError()); R_UNLESS(recv_sockaddr_len == sizeof(recv_sockaddr), htclow::ResultSocketReceiveFromError()); @@ -139,7 +139,7 @@ namespace ams::htclow::driver { const auto len = MakeBeaconResponsePacket(packet_data, sizeof(packet_data)); /* Send the beacon response data. */ - const auto send_res = socket::SendTo(m_socket, packet_data, len, socket::MsgFlag::MsgFlag_None, std::addressof(recv_sockaddr), sizeof(recv_sockaddr)); + const auto send_res = socket::SendTo(m_socket, packet_data, len, socket::MsgFlag::Msg_None, std::addressof(recv_sockaddr), sizeof(recv_sockaddr)); R_UNLESS(send_res >= 0, htclow::ResultSocketSendToError()); } diff --git a/libraries/libstratosphere/source/htclow/driver/htclow_socket_driver.cpp b/libraries/libstratosphere/source/htclow/driver/htclow_socket_driver.cpp index 23b479acb..63afcf78e 100644 --- a/libraries/libstratosphere/source/htclow/driver/htclow_socket_driver.cpp +++ b/libraries/libstratosphere/source/htclow/driver/htclow_socket_driver.cpp @@ -175,7 +175,7 @@ namespace ams::htclow::driver { }; /* Send the auto-connect packet. */ - socket::SendTo(desc, auto_connect_packet, len, socket::MsgFlag::MsgFlag_None, reinterpret_cast(std::addressof(sockaddr)), sizeof(sockaddr)); + socket::SendTo(desc, auto_connect_packet, len, socket::MsgFlag::Msg_None, reinterpret_cast(std::addressof(sockaddr)), sizeof(sockaddr)); } Result SocketDriver::Open() { @@ -247,7 +247,7 @@ namespace ams::htclow::driver { /* Repeatedly send data until it's all sent. */ ssize_t cur_sent; for (ssize_t sent = 0; sent < src_size; sent += cur_sent) { - cur_sent = socket::Send(m_client_socket, static_cast(src) + sent, src_size - sent, socket::MsgFlag::MsgFlag_None); + cur_sent = socket::Send(m_client_socket, static_cast(src) + sent, src_size - sent, socket::MsgFlag::Msg_None); R_UNLESS(cur_sent > 0, htclow::ResultSocketSendError()); } @@ -261,7 +261,7 @@ namespace ams::htclow::driver { /* Repeatedly receive data until it's all sent. */ ssize_t cur_recv; for (ssize_t received = 0; received < dst_size; received += cur_recv) { - cur_recv = socket::Recv(m_client_socket, static_cast(dst) + received, dst_size - received, socket::MsgFlag::MsgFlag_None); + cur_recv = socket::Recv(m_client_socket, static_cast(dst) + received, dst_size - received, socket::MsgFlag::Msg_None); R_UNLESS(cur_recv > 0, htclow::ResultSocketReceiveError()); } diff --git a/libraries/libstratosphere/source/socket/impl/socket_api.hpp b/libraries/libstratosphere/source/socket/impl/socket_api.hpp index f1dcc5512..85ffb55e4 100644 --- a/libraries/libstratosphere/source/socket/impl/socket_api.hpp +++ b/libraries/libstratosphere/source/socket/impl/socket_api.hpp @@ -47,6 +47,8 @@ namespace ams::socket::impl { s32 Accept(s32 desc, SockAddr *out_address, SockLenT *out_addr_len); s32 Bind(s32 desc, const SockAddr *address, SockLenT len); + s32 Connect(s32 desc, const SockAddr *address, SockLenT len); + s32 GetSockName(s32 desc, SockAddr *out_address, SockLenT *out_addr_len); s32 SetSockOpt(s32 desc, Level level, Option option_name, const void *option_value, SockLenT option_size); diff --git a/libraries/libstratosphere/source/socket/impl/socket_api.os.horizon.cpp b/libraries/libstratosphere/source/socket/impl/socket_api.os.horizon.cpp index 676817e75..1ebf01565 100644 --- a/libraries/libstratosphere/source/socket/impl/socket_api.os.horizon.cpp +++ b/libraries/libstratosphere/source/socket/impl/socket_api.os.horizon.cpp @@ -505,6 +505,28 @@ namespace ams::socket::impl { return result; } + s32 Connect(s32 desc, const SockAddr *address, SockLenT len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Check input. */ + if (address == nullptr || len == 0) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + + /* Perform the call. */ + Errno error = Errno::ESuccess; + int result = ::bsdConnect(desc, ConvertForLibnx(address), len); + TranslateResultToBsdError(error, result); + + if (result < 0) { + socket::impl::SetLastError(error); + } + + return result; + } + s32 GetSockName(s32 desc, SockAddr *out_address, SockLenT *out_addr_len) { /* Check pre-conditions. */ AMS_ABORT_UNLESS(IsInitialized()); diff --git a/libraries/libstratosphere/source/socket/impl/socket_api.os.windows.cpp b/libraries/libstratosphere/source/socket/impl/socket_api.os.windows.cpp new file mode 100644 index 000000000..10fdd319e --- /dev/null +++ b/libraries/libstratosphere/source/socket/impl/socket_api.os.windows.cpp @@ -0,0 +1,741 @@ +/* + * Copyright (c) Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include +#include "socket_api.hpp" +#include "socket_allocator.hpp" + +#include + +#include + +namespace ams::socket::impl { + + extern PosixWinSockConverter g_posix_winsock_converter; + + namespace { + + constinit util::Atomic g_init_counter = 0; + + ALWAYS_INLINE bool IsInitialized() { + return g_init_counter > 0; + } + + class FcntlState { + private: + FcntlFlag m_flags[MaxSocketsPerClient]{}; + os::SdkRecursiveMutex m_mutexes[MaxSocketsPerClient]{}; + public: + constexpr FcntlState() = default; + public: + void ClearFlag(int fd, FcntlFlag flag) { + std::scoped_lock lk(m_mutexes[fd]); + + m_flags[fd] &= ~flag; + } + + void ClearFlags(int fd) { + std::scoped_lock lk(m_mutexes[fd]); + + m_flags[fd] = FcntlFlag::None; + } + + FcntlFlag GetFlags(int fd) { + std::scoped_lock lk(m_mutexes[fd]); + + return m_flags[fd]; + } + + int GetFlagsInt(int fd) { + return static_cast(this->GetFlags(fd)); + } + + os::SdkRecursiveMutex &GetSocketLock(int fd) { + return m_mutexes[fd]; + } + + bool IsFlagClear(int fd, FcntlFlag flag) { + return !this->IsFlagSet(fd, flag); + } + + bool IsFlagSet(int fd, FcntlFlag flag) { + std::scoped_lock lk(m_mutexes[fd]); + + return (m_flags[fd] & flag) != static_cast(0); + } + + bool IsSocketBlocking(int fd) { + return !this->IsSocketNonBlocking(fd); + } + + bool IsSocketNonBlocking(int fd) { + return this->IsFlagSet(fd, FcntlFlag::O_NonBlock); + } + + void SetFlag(int fd, FcntlFlag flag) { + std::scoped_lock lk(m_mutexes[fd]); + + m_flags[fd] |= flag; + } + }; + + constinit FcntlState g_fcntl_state; + + void TransmuteWsaError() { + switch (::WSAGetLastError()) { + case WSAEFAULT: ::WSASetLastError(WSAEINVAL); break; + case WSAENOTSOCK: ::WSASetLastError(WSAEBADF); break; + case WSAETIMEDOUT: ::WSASetLastError(WSAEWOULDBLOCK); break; + } + } + + template + void TransmuteWsaError(T res) { + if (static_cast(res) == SOCKET_ERROR) { + TransmuteWsaError(); + } + } + + } + + #define AMS_SOCKET_IMPL_SCOPED_MAKE_NON_BLOCKING(_cond, _fd) \ + /* If the socket is blocking and we need to make it non-blocking, do so. */ \ + int nonblock_##__LINE__ = 1; \ + bool set_nonblock_##__LINE__ = false; \ + if (_cond && g_fcntl_state.IsSocketBlocking(_fd)) { \ + if (const auto res = ::ioctlsocket(handle, FIONBIO, reinterpret_cast(std::addressof( nonblock_##__LINE__ ))); res == SOCKET_ERROR) { \ + TransmuteWsaError(); \ + return res; \ + } \ + \ + set_nonblock_##__LINE__ = true; \ + } \ + \ + ON_SCOPE_EXIT { \ + /* Preserve last error. */ \ + const auto last_err = socket::impl::GetLastError(); \ + ON_SCOPE_EXIT { socket::impl::SetLastError(last_err); }; \ + \ + /* Restore non-blocking state. */ \ + if (set_nonblock_##__LINE__) { \ + nonblock_##__LINE__ = 0; \ + \ + while (true) { \ + const auto restore_res = ::ioctlsocket(handle, FIONBIO, reinterpret_cast(std::addressof( nonblock_##__LINE__ ))); \ + TransmuteWsaError(restore_res); \ + if (!(restore_res == SOCKET_ERROR && socket::impl::GetLastError() == Errno::EInProgress)) { \ + break; \ + } \ + \ + os::SleepThread(TimeSpan::FromMilliSeconds(1)); \ + } \ + } \ + } + + #define AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(expr) ({ const auto res = (expr); TransmuteWsaError(res); res; }) + + + void *Alloc(size_t size) { + return ::std::malloc(size); + } + + void *Calloc(size_t num, size_t size) { + const size_t total_size = size * num; + void *buf = Alloc(size); + if (buf != nullptr) { + std::memset(buf, 0, total_size); + } + return buf; + } + + void Free(void *ptr) { + return ::std::free(ptr); + } + + Errno GetLastError() { + if (AMS_LIKELY(IsInitialized())) { + return MapErrnoValue(::WSAGetLastError()); + } else { + return Errno::EInval; + } + } + + void SetLastError(Errno err) { + if (AMS_LIKELY(IsInitialized())) { + ::WSASetLastError(MapErrnoValue(err)); + } + } + + u32 InetHtonl(u32 host) { + return ::htonl(host); + } + + u16 InetHtons(u16 host) { + return ::htons(host); + } + + u32 InetNtohl(u32 net) { + return ::ntohl(net); + } + + u16 InetNtohs(u16 net) { + return ::ntohs(net); + } + + Result Initialize(const Config &config) { + AMS_UNUSED(config); + + /* Increment init counter. */ + ++g_init_counter; + + /* Initialize winsock. */ + WSADATA wsa_data; + WORD wVersionRequested = MAKEWORD(2, 2); + + const auto res = ::WSAStartup(wVersionRequested, std::addressof(wsa_data)); + AMS_ABORT_UNLESS(res == 0); + + /* Initialize time services. */ + R_ABORT_UNLESS(time::Initialize()); + + R_SUCCEED(); + } + + Result Finalize() { + /* Check pre-conditions. */ + --g_init_counter; + AMS_ABORT_UNLESS(g_init_counter >= 0); + + /* Cleanup WSA. */ + ::WSACleanup(); + + /* Finalize time services. */ + time::Finalize(); + + /* Release all posix handles. */ + g_posix_winsock_converter.ReleaseAllPosixHandles(); + + R_SUCCEED(); + } + + ssize_t RecvFromInternal(s32 desc, void *buffer, size_t buffer_size, MsgFlag flags, SockAddr *out_address, SockLenT *out_addr_len) { + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Convert the sockaddr. */ + sockaddr sa = {}; + socklen_t addr_len = sizeof(sa); + + /* Perform the call. */ + const auto res = ::recvfrom(handle, static_cast(buffer), static_cast(buffer_size), MapMsgFlagValue(flags), std::addressof(sa), std::addressof(addr_len)); + if (res == SOCKET_ERROR) { + if (::WSAGetLastError() == WSAESHUTDOWN) { + ::WSASetLastError(WSAENETDOWN); + } else { + TransmuteWsaError(); + } + } + + /* Set output. */ + if (out_address != nullptr && out_addr_len != nullptr) { + if (addr_len > static_cast(sizeof(*out_address))) { + addr_len = sizeof(*out_address); + } + + if (*out_addr_len != 0) { + if (static_cast(*out_addr_len) > addr_len) { + *out_addr_len = addr_len; + } + + SockAddr sa_pl = {}; + CopyFromPlatform(reinterpret_cast(std::addressof(sa_pl)), reinterpret_cast(std::addressof(sa))); + std::memcpy(out_address, std::addressof(sa_pl), *out_addr_len); + } + } + + return res; + } + + ssize_t RecvFrom(s32 desc, void *buffer, size_t buffer_size, MsgFlag flags, SockAddr *out_address, SockLenT *out_addr_len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* If the flags have DontWait set, clear WaitAll. */ + if ((flags & MsgFlag::Msg_DontWait) == MsgFlag::Msg_DontWait) { + flags &= ~MsgFlag::Msg_WaitAll; + } + + /* If the flags haev WaitAll set but the socket is non-blocking, clear WaitAll. */ + if ((flags & MsgFlag::Msg_WaitAll) == MsgFlag::Msg_WaitAll && g_fcntl_state.IsSocketNonBlocking(desc)) { + flags &= ~MsgFlag::Msg_WaitAll; + } + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } else if (buffer_size == 0) { + return 0; + } else if (buffer == nullptr) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } else if (buffer_size > std::numeric_limits::max()) { + socket::impl::SetLastError(Errno::EFault); + return -1; + } + + /* Handle blocking vs non-blocking. */ + if ((flags & MsgFlag::Msg_DontWait) == MsgFlag::Msg_DontWait) { + return RecvFromInternal(desc, buffer, buffer_size, flags, out_address, out_addr_len); + } else { + /* Lock the socket. */ + std::scoped_lock lk(g_fcntl_state.GetSocketLock(desc)); + + /* Clear don't wait from the flags. */ + flags &= MsgFlag::Msg_DontWait; + + /* If the socket is blocking, we need to make it non-blocking. */ + AMS_SOCKET_IMPL_SCOPED_MAKE_NON_BLOCKING(true, desc); + + /* Do the recv from. */ + return RecvFromInternal(desc, buffer, buffer_size, flags, out_address, out_addr_len); + } + } + + ssize_t Recv(s32 desc, void *buffer, size_t buffer_size, MsgFlag flags) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } else if (buffer_size == 0) { + return 0; + } else if (buffer == nullptr) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } else if (buffer_size > std::numeric_limits::max()) { + socket::impl::SetLastError(Errno::EFault); + return -1; + } + + /* If the socket is blocking, we need to make it non-blocking. */ + AMS_SOCKET_IMPL_SCOPED_MAKE_NON_BLOCKING(((flags & MsgFlag::Msg_DontWait) == MsgFlag::Msg_DontWait), desc); + + /* Perform the call. */ + return AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(::recv(handle, static_cast(buffer), static_cast(buffer_size), MapMsgFlagValue(flags & ~MsgFlag::Msg_DontWait))); + } + + ssize_t SendTo(s32 desc, const void *buffer, size_t buffer_size, MsgFlag flags, const SockAddr *address, SockLenT len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Clear don't wait from flags. */ + flags &= ~MsgFlag::Msg_DontWait; + + /* Convert the sockaddr. */ + sockaddr sa = {}; + socket::impl::CopyToPlatform(reinterpret_cast(std::addressof(sa)), reinterpret_cast(address)); + + /* Perform the call. */ + const auto res = ::sendto(handle, static_cast(buffer), static_cast(buffer_size), MapMsgFlagValue(flags), address != nullptr ? std::addressof(sa) : nullptr, static_cast(len)); + if (res == SOCKET_ERROR) { + if (::WSAGetLastError() == WSAESHUTDOWN) { + ::WSASetLastError(109); + } else { + TransmuteWsaError(); + } + } + + return res; + } + + ssize_t Send(s32 desc, const void *buffer, size_t buffer_size, MsgFlag flags) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Perform the call. */ + return AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(::send(handle, static_cast(buffer), static_cast(buffer_size), MapMsgFlagValue(flags))); + } + + s32 Shutdown(s32 desc, ShutdownMethod how) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Perform the call. */ + const auto res = ::shutdown(handle, MapShutdownMethodValue(how)); + g_posix_winsock_converter.SetShutdown(desc, true); + TransmuteWsaError(res); + return res; + } + + s32 Socket(Family domain, Type type, Protocol protocol, bool exempt) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + const auto res = ::socket(MapFamilyValue(domain), MapTypeValue(type), MapProtocolValue(protocol)); + TransmuteWsaError(res); + + s32 posix_socket = -1; + if (res != static_cast::type>(SOCKET_ERROR)) { + if (posix_socket = g_posix_winsock_converter.AcquirePosixHandle(res, exempt); posix_socket < 0) { + /* Preserve last error. */ + const auto last_err = socket::impl::GetLastError(); + ON_SCOPE_EXIT { socket::impl::SetLastError(last_err); }; + + /* Close the socket. */ + ::closesocket(res); + } + } + + return posix_socket; + } + + s32 Socket(Family domain, Type type, Protocol protocol) { + return Socket(domain, type, protocol, false); + } + + s32 SocketExempt(Family domain, Type type, Protocol protocol) { + return Socket(domain, type, protocol, true); + } + + s32 Accept(s32 desc, SockAddr *out_address, SockLenT *out_addr_len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Check shutdown. */ + bool is_shutdown = false; + if (const auto res = g_posix_winsock_converter.GetShutdown(is_shutdown, desc); res == SOCKET_ERROR || (res == 0 && is_shutdown)) { + socket::impl::SetLastError(Errno::EConnAborted); + return -1; + } + + /* Accept. */ + sockaddr sa = {}; + socklen_t sa_len = sizeof(sa); + const auto res = ::accept(handle, std::addressof(sa), std::addressof(sa_len)); + if (res == static_cast::type>(SOCKET_ERROR)) { + if (::WSAGetLastError() == WSAEOPNOTSUPP) { + ::WSASetLastError(WSAEINVAL); + } else { + TransmuteWsaError(); + } + } + + /* Set output. */ + if (out_address != nullptr && out_addr_len != nullptr) { + if (sa_len > static_cast(sizeof(*out_address))) { + sa_len = sizeof(*out_address); + } + + if (*out_addr_len != 0) { + if (static_cast(*out_addr_len) > sa_len) { + *out_addr_len = sa_len; + } + + SockAddr sa_pl = {}; + CopyFromPlatform(reinterpret_cast(std::addressof(sa_pl)), reinterpret_cast(std::addressof(sa))); + std::memcpy(out_address, std::addressof(sa_pl), *out_addr_len); + } + + *out_addr_len = sa_len; + } + + if (res == static_cast::type>(SOCKET_ERROR)) { + return res; + } + + s32 fd = -1; + bool is_exempt = false; + if (g_posix_winsock_converter.GetSocketExempt(is_exempt, desc) == 0) { + fd = g_posix_winsock_converter.AcquirePosixHandle(res, is_exempt); + } + + if (fd < 0) { + /* Preserve last error. */ + const auto last_err = socket::impl::GetLastError(); + ON_SCOPE_EXIT { socket::impl::SetLastError(last_err); }; + + ::closesocket(res); + + return SOCKET_ERROR; + } + + return fd; + } + + s32 Bind(s32 desc, const SockAddr *address, SockLenT len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } else if (address == nullptr) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + + /* Convert the sockaddr. */ + sockaddr sa = {}; + socket::impl::CopyToPlatform(reinterpret_cast(std::addressof(sa)), reinterpret_cast(address)); + + return AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(::bind(handle, std::addressof(sa), static_cast(len))); + } + + s32 Connect(s32 desc, const SockAddr *address, SockLenT len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Convert the sockaddr. */ + sockaddr sa = {}; + if (address != nullptr) { + if (reinterpret_cast(address)->sin_port == 0) { + socket::impl::SetLastError(Errno::EAddrNotAvail); + return -1; + } + + socket::impl::CopyToPlatform(reinterpret_cast(std::addressof(sa)), reinterpret_cast(address)); + } + + const auto res = ::connect(handle, address != nullptr ? std::addressof(sa) : nullptr, len); + if (res == SOCKET_ERROR) { + const auto wsa_err = ::WSAGetLastError(); + if (wsa_err == WSAEWOULDBLOCK) { + ::WSASetLastError(WSAEINPROGRESS); + } else if (wsa_err != WSAETIMEDOUT) { + TransmuteWsaError(); + } + } + + return res; + } + + s32 GetSockName(s32 desc, SockAddr *out_address, SockLenT *out_addr_len) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* We may end up preserving the last wsa error. */ + const auto last_err = ::WSAGetLastError(); + + /* Do the call. */ + sockaddr sa = {}; + + auto res = ::getsockname(handle, out_address != nullptr ? std::addressof(sa) : nullptr, reinterpret_cast(out_addr_len)); + if (res == SOCKET_ERROR) { + if (::WSAGetLastError() == WSAEINVAL) { + ::WSASetLastError(last_err); + + sa = {}; + res = 0; + } else { + TransmuteWsaError(); + } + } + + /* Copy out. */ + if (out_address != nullptr) { + CopyFromPlatform(reinterpret_cast(out_address), reinterpret_cast(std::addressof(sa))); + } + + return res; + } + + s32 SetSockOpt(s32 desc, Level level, Option option_name, const void *option_value, SockLenT option_size) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + union SocketOptionValue { + linger option_linger; + DWORD option_timeout_ms; + DWORD option_exempt; + }; + + SocketOptionValue sockopt_value = {}; + socklen_t option_value_length = option_size; + const char *p_option_value = nullptr; + + switch (option_name) { + case Option::So_Linger: + case Option::So_Nn_Linger: + { + if (option_value_length < static_cast(sizeof(sockopt_value.option_linger))) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + option_value_length = sizeof(sockopt_value.option_linger); + CopyToPlatform(std::addressof(sockopt_value.option_linger), reinterpret_cast(option_value)); + p_option_value = reinterpret_cast(std::addressof(sockopt_value.option_linger)); + } + break; + case Option::So_SndTimeo: + case Option::So_RcvTimeo: + { + if (option_value_length < static_cast(sizeof(sockopt_value.option_timeout_ms))) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + option_value_length = sizeof(sockopt_value.option_timeout_ms); + sockopt_value.option_timeout_ms = (reinterpret_cast(option_value)->tv_sec * 1000) + (reinterpret_cast(option_value)->tv_usec / 1000); + p_option_value = reinterpret_cast(std::addressof(sockopt_value.option_timeout_ms)); + } + break; + case Option::So_Nn_Shutdown_Exempt: + { + if (option_value_length < static_cast(sizeof(sockopt_value.option_exempt))) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + + return g_posix_winsock_converter.SetSocketExempt(desc, *reinterpret_cast(option_value) != 0); + } + break; + default: + p_option_value = reinterpret_cast(option_value); + break; + } + + return AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(::setsockopt(handle, MapLevelValue(level), MapOptionValue(level, option_name), p_option_value, option_value_length)); + } + + s32 Listen(s32 desc, s32 backlog) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Convert socket. */ + SOCKET handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + + /* Check input. */ + if (handle == static_cast(socket::InvalidSocket)) { + socket::impl::SetLastError(Errno::EBadf); + return -1; + } + + /* Check shutdown. */ + bool is_shutdown = false; + if (const auto res = g_posix_winsock_converter.GetShutdown(is_shutdown, desc); res == SOCKET_ERROR || (res == 0 && is_shutdown)) { + socket::impl::SetLastError(Errno::EInval); + return -1; + } + + return AMS_SOCKET_IMPL_DO_WITH_TRANSMUTE(::listen(handle, backlog)); + } + + s32 Close(s32 desc) { + /* Check pre-conditions. */ + AMS_ABORT_UNLESS(IsInitialized()); + + /* Check that we can close. */ + static constinit os::SdkMutex s_close_lock; + SOCKET handle = static_cast(socket::InvalidSocket); + { + std::scoped_lock lk(s_close_lock); + + handle = g_posix_winsock_converter.PosixToWinsockSocket(desc); + if (handle == static_cast(socket::InvalidSocket)) { + return SOCKET_ERROR; + } + + g_posix_winsock_converter.ReleasePosixHandle(desc); + } + + /* Do the close. */ + const auto res = ::closesocket(handle); + g_fcntl_state.ClearFlags(desc); + + return res; + } + +} diff --git a/libraries/libstratosphere/source/socket/impl/socket_platform_types_translation.os.windows.cpp b/libraries/libstratosphere/source/socket/impl/socket_platform_types_translation.os.windows.cpp new file mode 100644 index 000000000..b03f8a58b --- /dev/null +++ b/libraries/libstratosphere/source/socket/impl/socket_platform_types_translation.os.windows.cpp @@ -0,0 +1,776 @@ +/* + * Copyright (c) Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ +#include +#include "socket_api.hpp" + +#include + +#include + +namespace ams::socket::impl { + + ///* TODO: Custom sys/* headers, probably. */ + #define AF_LINK 18 + + #define MSG_TRUNC 0x10 + #define MSG_CTRUNC 0x20 + #define MSG_DONTWAIT 0x80 + + #define SHUT_RD 0 + #define SHUT_WR 1 + #define SHUT_RDWR 2 + + #define O_NONBLOCK 4 + + #define TCP_MAXSEG 4 + + PosixWinSockConverter g_posix_winsock_converter; + + s32 PosixWinSockConverter::AcquirePosixHandle(SOCKET winsock, bool exempt) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Get initial index. */ + const auto initial_index = GetInitialIndex(winsock); + + /* Try to find an open index. */ + for (auto posix = initial_index; posix < static_cast(MaxSocketsPerClient); ++posix) { + if (m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + m_data[posix].winsock = winsock; + m_data[posix].exempt = exempt; + return posix; + } + } + + for (auto posix = 0; posix < initial_index; ++posix) { + if (m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + m_data[posix].winsock = winsock; + m_data[posix].exempt = exempt; + return posix; + } + } + + /* We're out of open handles. */ + socket::impl::SetLastError(Errno::EMFile); + return SOCKET_ERROR; + } + + s32 PosixWinSockConverter::GetShutdown(bool &shutdown, s32 posix) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Check input. */ + if (static_cast(posix) >= MaxSocketsPerClient || m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + socket::impl::SetLastError(Errno::EBadf); + return SOCKET_ERROR; + } + + /* Set the output. */ + shutdown = m_data[posix].shutdown; + return 0; + } + + s32 PosixWinSockConverter::GetSocketExempt(bool &exempt, s32 posix) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Check input. */ + if (static_cast(posix) >= MaxSocketsPerClient || m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + socket::impl::SetLastError(Errno::EBadf); + return SOCKET_ERROR; + } + + /* Set the output. */ + exempt = m_data[posix].exempt; + return 0; + } + + SOCKET PosixWinSockConverter::PosixToWinsockSocket(s32 posix) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Check input. */ + if (static_cast(posix) >= MaxSocketsPerClient || m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + socket::impl::SetLastError(Errno::EBadf); + return SOCKET_ERROR; + } + + return m_data[posix].winsock; + } + + void PosixWinSockConverter::ReleaseAllPosixHandles() { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + for (size_t i = 0; i < MaxSocketsPerClient; ++i) { + m_data[i] = SocketData{}; + } + } + + void PosixWinSockConverter::ReleasePosixHandle(s32 posix) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + AMS_ASSERT(static_cast(posix) < MaxSocketsPerClient); + + m_data[posix] = SocketData{}; + } + + s32 PosixWinSockConverter::SetShutdown(s32 posix, bool shutdown) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Check input. */ + if (static_cast(posix) >= MaxSocketsPerClient || m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + socket::impl::SetLastError(Errno::EBadf); + return SOCKET_ERROR; + } + + /* Set the shutdown. */ + m_data[posix].shutdown = shutdown; + return 0; + } + + s32 PosixWinSockConverter::SetSocketExempt(s32 posix, bool exempt) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Check input. */ + if (static_cast(posix) >= MaxSocketsPerClient || m_data[posix].winsock == static_cast(INVALID_SOCKET)) { + socket::impl::SetLastError(Errno::EBadf); + return SOCKET_ERROR; + } + + /* Set the exempt. */ + m_data[posix].exempt = exempt; + return 0; + } + + s32 PosixWinSockConverter::WinsockToPosixSocket(SOCKET winsock) { + /* Acquire exclusive access. */ + std::scoped_lock lk(m_mutex); + + /* Get initial index. */ + const auto initial_index = GetInitialIndex(winsock); + + /* Try to find an open index. */ + for (auto posix = initial_index; posix < static_cast(MaxSocketsPerClient); ++posix) { + if (m_data[posix].winsock == winsock) { + return posix; + } + } + + for (auto posix = 0; posix < initial_index; ++posix) { + if (m_data[posix].winsock == winsock) { + return posix; + } + } + + /* We failed to find the posix handle. */ + return -1; + } + + s32 MapProtocolValue(Protocol protocol) { + s32 mapped = -1; + + switch (protocol) { + case Protocol::IpProto_Ip: mapped = IPPROTO_IP; break; + case Protocol::IpProto_Icmp: mapped = IPPROTO_ICMP; break; + case Protocol::IpProto_Tcp: mapped = IPPROTO_TCP; break; + case Protocol::IpProto_Udp: mapped = IPPROTO_UDP; break; + case Protocol::IpProto_None: mapped = IPPROTO_NONE; break; + case Protocol::IpProto_UdpLite: mapped = IPPROTO_UDP; break; + case Protocol::IpProto_Raw: mapped = IPPROTO_RAW; break; + case Protocol::IpProto_Max: mapped = IPPROTO_MAX; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Protocol %d\n", static_cast(protocol)); + break; + } + + if (mapped == -1) { + AMS_SDK_LOG("WARNING: ams::Socket::Protocol %d is not supported by Win32/Win64.\n", static_cast(protocol)); + } + + return mapped; + } + + Protocol MapProtocolValue(s32 protocol) { + Protocol mapped = Protocol::IpProto_None; + + switch (protocol) { + case IPPROTO_IP: mapped = Protocol::IpProto_Ip; break; + case IPPROTO_ICMP: mapped = Protocol::IpProto_Icmp; break; + case IPPROTO_TCP: mapped = Protocol::IpProto_Tcp; break; + case IPPROTO_UDP: mapped = Protocol::IpProto_Udp; break; + case IPPROTO_NONE: mapped = Protocol::IpProto_None; break; + case IPPROTO_RAW: mapped = Protocol::IpProto_Raw; break; + case IPPROTO_MAX: mapped = Protocol::IpProto_Max; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket protocol %d\n", static_cast(protocol)); + break; + } + + return mapped; + } + + s32 MapTypeValue(Type type) { + s32 mapped = -1; + + switch (type) { + case Type::Sock_Default: mapped = 0; break; + case Type::Sock_Stream: mapped = SOCK_STREAM; break; + case Type::Sock_Dgram: mapped = SOCK_DGRAM; break; + case Type::Sock_Raw: mapped = SOCK_RAW; break; + case Type::Sock_SeqPacket: mapped = SOCK_SEQPACKET; break; + case Type::Sock_NonBlock: mapped = -1; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Type %d\n", static_cast(type)); + break; + } + + if (mapped == -1) { + AMS_SDK_LOG("WARNING: ams::Socket::Type %d is not supported by Win32/Win64.\n", static_cast(type)); + } + + return mapped; + } + + Type MapTypeValue(s32 type) { + Type mapped = Type::Sock_Default; + + switch (type) { + case 0: mapped = Type::Sock_Default; break; + case SOCK_STREAM: mapped = Type::Sock_Stream; break; + case SOCK_DGRAM: mapped = Type::Sock_Dgram; break; + case SOCK_RAW: mapped = Type::Sock_Raw; break; + case SOCK_SEQPACKET: mapped = Type::Sock_SeqPacket; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket type %d\n", static_cast(type)); + break; + } + + return mapped; + } + + s8 MapFamilyValue(Family family) { + s8 mapped = -1; + + switch (family) { + case Family::Af_Unspec: mapped = AF_UNSPEC; break; + case Family::Af_Inet: mapped = AF_INET; break; + case Family::Af_Route: mapped = -1; break; + case Family::Af_Link: mapped = AF_LINK; break; + case Family::Af_Inet6: mapped = AF_INET6; break; + case Family::Af_Max: mapped = AF_MAX; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Family %d\n", static_cast(family)); + break; + } + + if (mapped == -1) { + AMS_SDK_LOG("WARNING: ams::Socket::Family %d is not supported by Win32/Win64.\n", static_cast(family)); + } + + return mapped; + } + + Family MapFamilyValue(s8 family) { + Family mapped = Family::Af_Unspec; + + switch (family) { + case AF_UNSPEC:mapped = Family::Af_Unspec; break; + case AF_INET: mapped = Family::Af_Inet; break; + case AF_LINK: mapped = Family::Af_Link; break; + case AF_INET6: mapped = Family::Af_Inet6; break; + case AF_MAX: mapped = Family::Af_Max; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket family %d\n", static_cast(family)); + break; + } + + return mapped; + } + + s32 MapMsgFlagValue(MsgFlag flag) { + s32 mapped = 0; + + if ((flag & MsgFlag::Msg_Oob) != MsgFlag::Msg_None) { mapped |= MSG_OOB; } + if ((flag & MsgFlag::Msg_Peek) != MsgFlag::Msg_None) { mapped |= MSG_PEEK; } + if ((flag & MsgFlag::Msg_DontRoute) != MsgFlag::Msg_None) { mapped |= MSG_DONTROUTE; } + if ((flag & MsgFlag::Msg_Trunc) != MsgFlag::Msg_None) { mapped |= MSG_TRUNC; } + if ((flag & MsgFlag::Msg_CTrunc) != MsgFlag::Msg_None) { mapped |= MSG_CTRUNC; } + if ((flag & MsgFlag::Msg_WaitAll) != MsgFlag::Msg_None) { mapped |= MSG_WAITALL; } + if ((flag & MsgFlag::Msg_DontWait) != MsgFlag::Msg_None) { mapped |= MSG_DONTWAIT; } + + return mapped; + } + + MsgFlag MapMsgFlagValue(s32 flag) { + MsgFlag mapped = MsgFlag::Msg_None; + + if (flag & MSG_OOB) { mapped |= MsgFlag::Msg_Oob; } + if (flag & MSG_PEEK) { mapped |= MsgFlag::Msg_Peek; } + if (flag & MSG_DONTROUTE) { mapped |= MsgFlag::Msg_DontRoute; } + if (flag & MSG_TRUNC) { mapped |= MsgFlag::Msg_Trunc; } + if (flag & MSG_CTRUNC) { mapped |= MsgFlag::Msg_CTrunc; } + if (flag & MSG_WAITALL) { mapped |= MsgFlag::Msg_WaitAll; } + if (flag & MSG_DONTWAIT) { mapped |= MsgFlag::Msg_DontWait; } + + return mapped; + } + + u32 MapAddrInfoFlagValue(AddrInfoFlag flag) { + u32 mapped = 0; + + if ((flag & AddrInfoFlag::Ai_Passive) != AddrInfoFlag::Ai_None) { mapped |= AI_PASSIVE; } + if ((flag & AddrInfoFlag::Ai_CanonName) != AddrInfoFlag::Ai_None) { mapped |= AI_CANONNAME; } + if ((flag & AddrInfoFlag::Ai_NumericHost) != AddrInfoFlag::Ai_None) { mapped |= AI_NUMERICHOST; } + if ((flag & AddrInfoFlag::Ai_NumericServ) != AddrInfoFlag::Ai_None) { mapped |= AI_NUMERICSERV; } + if ((flag & AddrInfoFlag::Ai_AddrConfig) != AddrInfoFlag::Ai_None) { mapped |= AI_ADDRCONFIG; } + + return mapped; + } + + AddrInfoFlag MapAddrInfoFlagValue(u32 flag) { + AddrInfoFlag mapped = AddrInfoFlag::Ai_None; + + if (flag & AI_PASSIVE) { mapped |= AddrInfoFlag::Ai_Passive; } + if (flag & AI_CANONNAME) { mapped |= AddrInfoFlag::Ai_CanonName; } + if (flag & AI_NUMERICHOST) { mapped |= AddrInfoFlag::Ai_NumericHost; } + if (flag & AI_NUMERICSERV) { mapped |= AddrInfoFlag::Ai_NumericServ; } + if (flag & AI_ADDRCONFIG) { mapped |= AddrInfoFlag::Ai_AddrConfig; } + + return mapped; + } + + u32 MapShutdownMethodValue(ShutdownMethod how) { + u32 mapped = -1; + + switch (how) { + case ShutdownMethod::Shut_Rd: mapped = SHUT_RD; break; + case ShutdownMethod::Shut_Wr: mapped = SHUT_WR; break; + case ShutdownMethod::Shut_RdWr: mapped = SHUT_RDWR; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::ShutdownMethod %d\n", static_cast(how)); + break; + } + + return mapped; + } + + ShutdownMethod MapShutdownMethodValue(u32 how) { + ShutdownMethod mapped = static_cast(-1); + + switch (how) { + case SHUT_RD: mapped = ShutdownMethod::Shut_Rd; break; + case SHUT_WR: mapped = ShutdownMethod::Shut_Wr; break; + case SHUT_RDWR: mapped = ShutdownMethod::Shut_RdWr; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket shutdown %d\n", static_cast(how)); + break; + } + + return mapped; + } + + u32 MapFcntlFlagValue(FcntlFlag flag) { + u32 mapped = 0; + + switch (flag) { + case FcntlFlag::O_NonBlock: mapped = O_NONBLOCK; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::FcntlFlag %d\n", static_cast(flag)); + break; + } + + return mapped; + } + + FcntlFlag MapFcntlFlagValue(u32 flag) { + FcntlFlag mapped = FcntlFlag::None; + + switch (flag) { + case O_NONBLOCK: mapped = FcntlFlag::O_NonBlock; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket fcntl flag %d\n", static_cast(flag)); + break; + } + + return mapped; + } + + s32 MapLevelValue(Level level) { + s32 mapped = -1; + + switch (level) { + case Level::Sol_Socket: mapped = SOL_SOCKET; break; + case Level::Sol_Ip: mapped = IPPROTO_IP; break; + case Level::Sol_Icmp: mapped = IPPROTO_ICMP; break; + case Level::Sol_Tcp: mapped = IPPROTO_TCP; break; + case Level::Sol_Udp: mapped = IPPROTO_UDP; break; + case Level::Sol_UdpLite: mapped = IPPROTO_UDP; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Level %d\n", static_cast(level)); + break; + } + + return mapped; + } + + Level MapLevelValue(s32 level) { + Level mapped = static_cast(0); + + switch (level) { + case SOL_SOCKET: mapped = Level::Sol_Socket; break; + case IPPROTO_IP: mapped = Level::Sol_Ip; break; + case IPPROTO_ICMP: mapped = Level::Sol_Icmp; break; + case IPPROTO_TCP: mapped = Level::Sol_Tcp; break; + case IPPROTO_UDP: mapped = Level::Sol_Udp; break; + default: + AMS_SDK_LOG("WARNING: Invalid socket level %d\n", static_cast(level)); + break; + } + + return mapped; + } + + s32 MapOptionValue(Level level, Option option) { + s32 mapped = -1; + + switch (level) { + case Level::Sol_Socket: + switch (option) { + case Option::So_Debug: mapped = SO_DEBUG; break; + case Option::So_AcceptConn: mapped = SO_ACCEPTCONN; break; + case Option::So_ReuseAddr: mapped = SO_REUSEADDR; break; + case Option::So_KeepAlive: mapped = SO_KEEPALIVE; break; + case Option::So_DontRoute: mapped = SO_DONTROUTE; break; + case Option::So_Broadcast: mapped = SO_BROADCAST; break; + case Option::So_UseLoopback: mapped = SO_USELOOPBACK; break; + case Option::So_Linger: mapped = SO_LINGER; break; + case Option::So_OobInline: mapped = SO_OOBINLINE; break; + case Option::So_ReusePort: mapped = -1; break; + case Option::So_SndBuf: mapped = SO_SNDBUF; break; + case Option::So_RcvBuf: mapped = SO_RCVBUF; break; + case Option::So_SndLoWat: mapped = SO_SNDLOWAT; break; + case Option::So_RcvLoWat: mapped = SO_RCVLOWAT; break; + case Option::So_SndTimeo: mapped = SO_SNDTIMEO; break; + case Option::So_RcvTimeo: mapped = SO_RCVTIMEO; break; + case Option::So_Error: mapped = SO_ERROR; break; + case Option::So_Type: mapped = SO_TYPE; break; + case Option::So_Label: mapped = -1; break; + case Option::So_PeerLabel: mapped = -1; break; + case Option::So_ListenQLimit: mapped = -1; break; + case Option::So_ListenQLen: mapped = -1; break; + case Option::So_ListenIncQLen: mapped = -1; break; + case Option::So_SetFib: mapped = -1; break; + case Option::So_User_Cookie: mapped = -1; break; + case Option::So_Protocol: mapped = -1; break; + case Option::So_Vendor: mapped = -1; break; + case Option::So_Nn_Linger: mapped = -1; break; + case Option::So_Nn_Shutdown_Exempt: mapped = -1; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Option %d for Level::Sol_Socket\n", static_cast(option)); + break; + } + break; + case Level::Sol_Ip: + switch (option) { + case Option::Ip_Options: mapped = IP_OPTIONS; break; + case Option::Ip_HdrIncl: mapped = IP_HDRINCL; break; + case Option::Ip_Tos: mapped = IP_TOS; break; + case Option::Ip_Ttl: mapped = IP_TTL; break; + case Option::Ip_RecvOpts: mapped = -1; break; + case Option::Ip_Multicast_If: mapped = IP_MULTICAST_IF; break; + case Option::Ip_Multicast_Ttl: mapped = IP_MULTICAST_TTL; break; + case Option::Ip_Multicast_Loop: mapped = IP_MULTICAST_LOOP; break; + case Option::Ip_Add_Membership: mapped = IP_ADD_MEMBERSHIP; break; + case Option::Ip_Drop_Membership: mapped = IP_DROP_MEMBERSHIP; break; + case Option::Ip_Multicast_Vif: mapped = -1; break; + case Option::Ip_Rsvp_On: mapped = -1; break; + case Option::Ip_Rsvp_Off: mapped = -1; break; + case Option::Ip_Rsvp_Vif_On: mapped = -1; break; + case Option::Ip_Rsvp_Vif_Off: mapped = -1; break; + case Option::Ip_PortRange: mapped = -1; break; + case Option::Ip_Faith: mapped = -1; break; + case Option::Ip_OnesBcast: mapped = -1; break; + case Option::Ip_BindAny: mapped = -1; break; + case Option::Ip_RecvTtl: mapped = -1; break; + case Option::Ip_MinTtl: mapped = -1; break; + case Option::Ip_DontFrag: mapped = -1; break; + case Option::Ip_RecvTos: mapped = -1; break; + case Option::Ip_Add_Source_Membership: mapped = IP_ADD_SOURCE_MEMBERSHIP; break; + case Option::Ip_Drop_Source_Membership: mapped = IP_DROP_SOURCE_MEMBERSHIP; break; + case Option::Ip_Block_Source: mapped = IP_BLOCK_SOURCE; break; + case Option::Ip_Unblock_Source: mapped = IP_UNBLOCK_SOURCE; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Option %d for Level::Sol_Ip\n", static_cast(option)); + break; + } + break; + case Level::Sol_Tcp: + switch (option) { + case Option::Tcp_NoDelay: mapped = TCP_NODELAY; break; + case Option::Tcp_MaxSeg: mapped = TCP_MAXSEG; break; + case Option::Tcp_NoPush: mapped = -1; break; + case Option::Tcp_NoOpt: mapped = -1; break; + case Option::Tcp_Md5Sig: mapped = -1; break; + case Option::Tcp_Info: mapped = -1; break; + case Option::Tcp_Congestion: mapped = -1; break; + case Option::Tcp_KeepInit: mapped = -1; break; + case Option::Tcp_KeepIdle: mapped = -1; break; + case Option::Tcp_KeepIntvl: mapped = -1; break; + case Option::Tcp_KeepCnt: mapped = -1; break; + case Option::Tcp_Vendor: mapped = -1; break; + default: + AMS_SDK_LOG("WARNING: Invalid ams::Socket::Option %d for Level::Sol_Tcp\n", static_cast(option)); + break; + } + break; + default: + AMS_SDK_LOG("WARNING: Invalid option level %d\n", static_cast(level)); + break; + } + + if (mapped == -1) { + AMS_SDK_LOG("WARNING: ams::Socket::Option %d is not supported by Win32/Win64.\n", static_cast(option)); + } + + return mapped; + } + + Option MapOptionValue(s32 level, s32 option) { + Option mapped = static_cast