diff --git a/stratosphere/tma/Makefile b/stratosphere/tma/Makefile index 6f361d647..7dc12f262 100644 --- a/stratosphere/tma/Makefile +++ b/stratosphere/tma/Makefile @@ -29,7 +29,7 @@ DEFINES := -DDISABLE_IPC #--------------------------------------------------------------------------------- # options for code generation #--------------------------------------------------------------------------------- -ARCH := -march=armv8-a -mtune=cortex-a57 -mtp=soft -fPIE +ARCH := -march=armv8-a+crc+crypto -mtune=cortex-a57 -mtp=soft -fPIE CFLAGS := -g -Wall -O2 -ffunction-sections \ $(ARCH) $(DEFINES) diff --git a/stratosphere/tma/source/crc.h b/stratosphere/tma/source/crc.h new file mode 100644 index 000000000..1fc37cffb --- /dev/null +++ b/stratosphere/tma/source/crc.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +/* Code taken from Yazen Ghannam , licensed GPLv2. */ + +#define CRC32X(crc, value) __asm__("crc32x %w[c], %w[c], %x[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32W(crc, value) __asm__("crc32w %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32H(crc, value) __asm__("crc32h %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32B(crc, value) __asm__("crc32b %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32CX(crc, value) __asm__("crc32cx %w[c], %w[c], %x[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32CW(crc, value) __asm__("crc32cw %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32CH(crc, value) __asm__("crc32ch %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) +#define CRC32CB(crc, value) __asm__("crc32cb %w[c], %w[c], %w[v]":[c]"+r"(crc):[v]"r"(value)) + +static inline uint16_t __get_unaligned_le16(const uint8_t *p) +{ + return p[0] | p[1] << 8; +} + +static inline uint32_t __get_unaligned_le32(const uint8_t *p) +{ + return p[0] | p[1] << 8 | p[2] << 16 | p[3] << 24; +} + +static inline uint64_t __get_unaligned_le64(const uint8_t *p) +{ + return (uint64_t)__get_unaligned_le32(p + 4) << 32 | + __get_unaligned_le32(p); +} + +static inline uint16_t get_unaligned_le16(const void *p) +{ + return __get_unaligned_le16((const uint8_t *)p); +} + +static inline uint32_t get_unaligned_le32(const void *p) +{ + return __get_unaligned_le32((const uint8_t *)p); +} + +static inline uint64_t get_unaligned_le64(const void *p) +{ + return __get_unaligned_le64((const uint8_t *)p); +} + + +static u32 crc32_arm64_le_hw(const u8 *p, unsigned int len) { + u32 crc = 0xFFFFFFFF; + + s64 length = len; + + while ((length -= sizeof(u64)) >= 0) { + CRC32X(crc, get_unaligned_le64(p)); + p += sizeof(u64); + } + + /* The following is more efficient than the straight loop */ + if (length & sizeof(u32)) { + CRC32W(crc, get_unaligned_le32(p)); + p += sizeof(u32); + } + if (length & sizeof(u16)) { + CRC32H(crc, get_unaligned_le16(p)); + p += sizeof(u16); + } + if (length & sizeof(u8)) + CRC32B(crc, *p); + + return crc ^ 0xFFFFFFFF; +} + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/stratosphere/tma/source/tma_conn_packet.hpp b/stratosphere/tma/source/tma_conn_packet.hpp new file mode 100644 index 000000000..03aaffc5a --- /dev/null +++ b/stratosphere/tma/source/tma_conn_packet.hpp @@ -0,0 +1,262 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#include +#include "tma_conn_result.hpp" +#include "tma_conn_service_ids.hpp" +#include "crc.h" + +class TmaPacket { + public: + struct Header { + u32 service_id; + u32 task_id; + u16 command; + u8 is_continuation; + u8 version; + u32 body_len; + u32 reserved[4]; /* This is where N's header ends. */ + u32 body_checksum; + u32 header_checksum; + }; + + static_assert(sizeof(Header) == 0x28, "Packet::Header definition!"); + + static constexpr u32 MaxBodySize = 0xE000; + static constexpr u32 MaxPacketSize = MaxBodySize + sizeof(Header); + + private: + std::unique_ptr buffer = std::make_unique(MaxPacketSize); + u32 offset = 0; + + Header *GetHeader() const { + return reinterpret_cast
(buffer.get()); + } + + u8 *GetBody(u32 ofs) const { + return reinterpret_cast(buffer.get() + sizeof(Header) + ofs); + } + public: + TmaPacket() { + memset(buffer.get(), 0, MaxPacketSize); + } + + /* Implicit ~TmaPacket() */ + + /* These allow reading a packet in. */ + void CopyHeaderFrom(Header *hdr) { + *GetHeader() = *hdr; + } + + TmaConnResult CopyBodyFrom(void *body, size_t size) { + if (size >= MaxBodySize) { + return TmaConnResult::PacketOverflow; + } + + memcpy(GetBody(0), body, size); + + return TmaConnResult::Success; + } + + void CopyHeaderTo(void *out) { + memcpy(out, buffer.get(), sizeof(Header)); + } + + void CopyBodyTo(void *out) const { + memcpy(out, buffer.get() + sizeof(Header), GetBodyLength()); + } + + bool IsHeaderValid() { + Header *hdr = GetHeader(); + return crc32_arm64_le_hw(reinterpret_cast(hdr), sizeof(*hdr) - sizeof(hdr->header_checksum)) == hdr->header_checksum; + } + + bool IsBodyValid() const { + const u32 body_len = GetHeader()->body_len; + if (body_len == 0) { + return GetHeader()->body_checksum == 0; + } else { + return crc32_arm64_le_hw(GetBody(0), body_len) == GetHeader()->body_checksum; + } + } + + void SetChecksums() { + Header *hdr = GetHeader(); + if (hdr->body_len) { + hdr->body_checksum = crc32_arm64_le_hw(GetBody(0), hdr->body_len); + } else { + hdr->body_checksum = 0; + } + hdr->header_checksum = crc32_arm64_le_hw(reinterpret_cast(hdr), sizeof(*hdr) - sizeof(hdr->header_checksum)); + } + + u32 GetBodyLength() const { + return GetHeader()->body_len; + } + + u32 GetLength() const { + return GetBodyLength() + sizeof(Header); + } + + u32 GetBodyAvailableLength() const { + return MaxPacketSize - this->offset; + } + + void SetServiceId(TmaService srv) { + GetHeader()->service_id = static_cast(srv); + } + + TmaService GetServiceId() const { + return static_cast(GetHeader()->service_id); + } + + void SetTaskId(u32 id) { + GetHeader()->task_id = id; + } + + u32 GetTaskId() const { + return GetHeader()->task_id; + } + + void SetCommand(u16 cmd) { + GetHeader()->command = cmd; + } + + u16 GetCommand() const { + return GetHeader()->command; + } + + void SetContinuation(bool c) { + GetHeader()->is_continuation = c ? 1 : 0; + } + + bool GetContinuation() const { + return GetHeader()->is_continuation == 1; + } + + void SetVersion(u8 v) { + GetHeader()->version = v; + } + + u8 GetVersion() const { + return GetHeader()->version; + } + + void ClearOffset() { + this->offset = 0; + } + + TmaConnResult Write(const void *data, size_t size) { + if (size > GetBodyAvailableLength()) { + return TmaConnResult::PacketOverflow; + } + + memcpy(GetBody(this->offset), data, size); + this->offset += size; + GetHeader()->body_len = this->offset; + + return TmaConnResult::Success; + } + + TmaConnResult Read(void *data, size_t size) { + if (size > GetBodyAvailableLength()) { + return TmaConnResult::PacketOverflow; + } + + memcpy(data, GetBody(this->offset), size); + this->offset += size; + + return TmaConnResult::Success; + } + + template + TmaConnResult Write(const T &t) { + return Write(&t, sizeof(T)); + } + + template + TmaConnResult Read(const T &t) { + return Read(&t, sizeof(T)); + } + + TmaConnResult WriteString(const char *s) { + return Write(s, strlen(s) + 1); + } + + size_t WriteFormat(const char *format, ...) { + va_list va_arg; + va_start(va_arg, format); + const size_t available = GetBodyAvailableLength(); + const int written = vsnprintf(reinterpret_cast(GetBody(this->offset)), available, format, va_arg); + + size_t total_written; + if (static_cast(written) < available) { + this->offset += written; + *GetBody(this->offset++) = 0; + total_written = written + 1; + } else { + this->offset += available; + total_written = available; + } + + GetHeader()->body_len = this->offset; + return total_written; + } + + TmaConnResult ReadString(char *buf, size_t buf_size, size_t *out_size) { + TmaConnResult res = TmaConnResult::Success; + + size_t available = GetBodyAvailableLength(); + size_t ofs = 0; + while (ofs < buf_size) { + if (ofs >= available) { + res = TmaConnResult::PacketOverflow; + break; + } + if (ofs == buf_size) { + res = TmaConnResult::BufferOverflow; + break; + } + + buf[ofs] = static_cast(*GetBody(this->offset++)); + + if (buf[ofs++] == '\x00') { + break; + } + } + + /* Finish reading the string if the user buffer is too small. */ + if (res == TmaConnResult::BufferOverflow) { + u8 cur = *GetBody(this->offset); + while (cur != 0) { + if (ofs >= available) { + res = TmaConnResult::PacketOverflow; + break; + } + cur = *GetBody(this->offset++); + ofs++; + } + } + + if (out_size != nullptr) { + *out_size = ofs; + } + + return res; + } +}; diff --git a/stratosphere/tma/source/tma_conn_result.hpp b/stratosphere/tma/source/tma_conn_result.hpp new file mode 100644 index 000000000..7d91923af --- /dev/null +++ b/stratosphere/tma/source/tma_conn_result.hpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +enum class TmaConnResult : u32 { + Success = 0, + NotImplemented, + GeneralFailure, + ConnectionFailure, + AlreadyConnected, + WrongConnectionVersion, + PacketOverflow, + BufferOverflow, + Disconnected, + ServiceAlreadyRegistered, + ServiceUnknown, + Timeout, + NotInitialized, +}; diff --git a/stratosphere/tma/source/tma_conn_service_ids.hpp b/stratosphere/tma/source/tma_conn_service_ids.hpp new file mode 100644 index 000000000..7dc8ee331 --- /dev/null +++ b/stratosphere/tma/source/tma_conn_service_ids.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#pragma once + +#include "tma_conn_result.hpp" + +/* This is just python's hash function, but official TMA code uses it. */ +static constexpr u32 HashServiceName(const char *name) { + u32 h = *name; + u32 len = 0; + + while (*name) { + h = (1000003 * h) ^ *name; + name++; + len++; + } + + return h ^ len; +} + +enum class TmaService : u32 { + Invalid = 0, + TestService = HashServiceName("AtmosphereTestService"), /* Temporary service, will be used to debug communications. */ +}; diff --git a/stratosphere/tma/source/tma_main.cpp b/stratosphere/tma/source/tma_main.cpp index a1b291e1d..e4c23777b 100644 --- a/stratosphere/tma/source/tma_main.cpp +++ b/stratosphere/tma/source/tma_main.cpp @@ -22,12 +22,14 @@ #include #include +#include "tma_usb_comms.hpp" + extern "C" { extern u32 __start__; u32 __nx_applet_type = AppletType_None; - #define INNER_HEAP_SIZE 0x20000 + #define INNER_HEAP_SIZE 0x100000 size_t nx_inner_heap_size = INNER_HEAP_SIZE; char nx_inner_heap[INNER_HEAP_SIZE]; @@ -71,10 +73,7 @@ void __appExit(void) { smExit(); } -int main(int argc, char **argv) -{ - consoleDebugInit(debugDevice_SVC); - +void PmThread(void *arg) { /* Setup psc module. */ Result rc; PscPmModule tma_module = {0}; @@ -99,6 +98,32 @@ int main(int argc, char **argv) fatalSimple(rc); } } +} + +int main(int argc, char **argv) +{ + consoleDebugInit(debugDevice_SVC); + Thread pm_thread = {0}; + if (R_FAILED(threadCreate(&pm_thread, &PmThread, NULL, 0x4000, 0x15, 0))) { + /* TODO: Panic. */ + } + if (R_FAILED(threadStart(&pm_thread))) { + /* TODO: Panic. */ + } + + TmaUsbComms::Initialize(); + TmaPacket *packet = new TmaPacket(); + usbDsWaitReady(U64_MAX); + packet->Write(0xCAFEBABEDEADCAFEUL); + packet->Write(0xCCCCCCCCCCCCCCCCUL); + TmaUsbComms::SendPacket(packet); + packet->ClearOffset(); + while (true) { + if (TmaUsbComms::ReceivePacket(packet) == TmaConnResult::Success) { + TmaUsbComms::SendPacket(packet); + } + } + return 0; } diff --git a/stratosphere/tma/source/tma_usb_comms.cpp b/stratosphere/tma/source/tma_usb_comms.cpp new file mode 100644 index 000000000..440218bef --- /dev/null +++ b/stratosphere/tma/source/tma_usb_comms.cpp @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "tma_usb_comms.hpp" + +/* TODO: Is this actually allowed? */ +#define ATMOSPHERE_INTERFACE_PROTOCOL 0xFC + +static std::atomic g_initialized = false; +static UsbDsInterface *g_interface; +static UsbDsEndpoint *g_endpoint_in, *g_endpoint_out; + +/* USB State Change Tracking. */ +static HosThread g_state_change_thread; +static WaitableManagerBase *g_state_change_manager = nullptr; +static void (*g_state_change_callback)(void *arg, u32 state); +static void *g_state_change_arg; + +/* USB Send/Receive mutexes. */ +static HosMutex g_send_mutex; +static HosMutex g_recv_mutex; + +/* Static arrays to do USB DMA into. */ +static constexpr size_t DmaBufferAlign = 0x1000; +static constexpr size_t HeaderBufferSize = DmaBufferAlign; +static constexpr size_t DataBufferSize = 0x18000; +static __attribute__((aligned(DmaBufferAlign))) u8 g_header_buffer[HeaderBufferSize]; +static __attribute__((aligned(DmaBufferAlign))) u8 g_recv_data_buf[DataBufferSize]; +static __attribute__((aligned(DmaBufferAlign))) u8 g_send_data_buf[DataBufferSize]; + +/* Taken from libnx usb comms. */ +static Result _usbCommsInterfaceInit1x() +{ + Result rc = 0; + + struct usb_interface_descriptor interface_descriptor = { + .bLength = USB_DT_INTERFACE_SIZE, + .bDescriptorType = USB_DT_INTERFACE, + .bInterfaceNumber = 4, + .bInterfaceClass = USB_CLASS_VENDOR_SPEC, + .bInterfaceSubClass = USB_CLASS_VENDOR_SPEC, + .bInterfaceProtocol = ATMOSPHERE_INTERFACE_PROTOCOL, + }; + + struct usb_endpoint_descriptor endpoint_descriptor_in = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + .bEndpointAddress = USB_ENDPOINT_IN, + .bmAttributes = USB_TRANSFER_TYPE_BULK, + .wMaxPacketSize = 0x200, + }; + + struct usb_endpoint_descriptor endpoint_descriptor_out = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + .bEndpointAddress = USB_ENDPOINT_OUT, + .bmAttributes = USB_TRANSFER_TYPE_BULK, + .wMaxPacketSize = 0x200, + }; + + if (R_FAILED(rc)) return rc; + + //Setup interface. + rc = usbDsGetDsInterface(&g_interface, &interface_descriptor, "usb"); + if (R_FAILED(rc)) return rc; + + //Setup endpoints. + rc = usbDsInterface_GetDsEndpoint(g_interface, &g_endpoint_in, &endpoint_descriptor_in);//device->host + if (R_FAILED(rc)) return rc; + + rc = usbDsInterface_GetDsEndpoint(g_interface, &g_endpoint_out, &endpoint_descriptor_out);//host->device + if (R_FAILED(rc)) return rc; + + return rc; +} + +static Result _usbCommsInterfaceInit5x() { + Result rc = 0; + + u8 iManufacturer, iProduct, iSerialNumber; + static const u16 supported_langs[1] = {0x0409}; + // Send language descriptor + rc = usbDsAddUsbLanguageStringDescriptor(NULL, supported_langs, sizeof(supported_langs)/sizeof(u16)); + // Send manufacturer + if (R_SUCCEEDED(rc)) rc = usbDsAddUsbStringDescriptor(&iManufacturer, "Nintendo"); + // Send product + if (R_SUCCEEDED(rc)) rc = usbDsAddUsbStringDescriptor(&iProduct, "Nintendo Switch"); + // Send serial number + if (R_SUCCEEDED(rc)) rc = usbDsAddUsbStringDescriptor(&iSerialNumber, "SerialNumber"); + + // Send device descriptors + struct usb_device_descriptor device_descriptor = { + .bLength = USB_DT_DEVICE_SIZE, + .bDescriptorType = USB_DT_DEVICE, + .bcdUSB = 0x0110, + .bDeviceClass = 0x00, + .bDeviceSubClass = 0x00, + .bDeviceProtocol = 0x00, + .bMaxPacketSize0 = 0x40, + .idVendor = 0x057e, + .idProduct = 0x3000, + .bcdDevice = 0x0100, + .iManufacturer = iManufacturer, + .iProduct = iProduct, + .iSerialNumber = iSerialNumber, + .bNumConfigurations = 0x01 + }; + // Full Speed is USB 1.1 + if (R_SUCCEEDED(rc)) rc = usbDsSetUsbDeviceDescriptor(UsbDeviceSpeed_Full, &device_descriptor); + + // High Speed is USB 2.0 + device_descriptor.bcdUSB = 0x0200; + if (R_SUCCEEDED(rc)) rc = usbDsSetUsbDeviceDescriptor(UsbDeviceSpeed_High, &device_descriptor); + + // Super Speed is USB 3.0 + device_descriptor.bcdUSB = 0x0300; + // Upgrade packet size to 512 + device_descriptor.bMaxPacketSize0 = 0x09; + if (R_SUCCEEDED(rc)) rc = usbDsSetUsbDeviceDescriptor(UsbDeviceSpeed_Super, &device_descriptor); + + // Define Binary Object Store + u8 bos[0x16] = { + 0x05, // .bLength + USB_DT_BOS, // .bDescriptorType + 0x16, 0x00, // .wTotalLength + 0x02, // .bNumDeviceCaps + + // USB 2.0 + 0x07, // .bLength + USB_DT_DEVICE_CAPABILITY, // .bDescriptorType + 0x02, // .bDevCapabilityType + 0x02, 0x00, 0x00, 0x00, // dev_capability_data + + // USB 3.0 + 0x0A, // .bLength + USB_DT_DEVICE_CAPABILITY, // .bDescriptorType + 0x03, // .bDevCapabilityType + 0x00, 0x0E, 0x00, 0x03, 0x00, 0x00, 0x00 + }; + if (R_SUCCEEDED(rc)) rc = usbDsSetBinaryObjectStore(bos, sizeof(bos)); + + if (R_FAILED(rc)) return rc; + + struct usb_interface_descriptor interface_descriptor = { + .bLength = USB_DT_INTERFACE_SIZE, + .bDescriptorType = USB_DT_INTERFACE, + .bInterfaceNumber = 4, + .bNumEndpoints = 2, + .bInterfaceClass = USB_CLASS_VENDOR_SPEC, + .bInterfaceSubClass = USB_CLASS_VENDOR_SPEC, + .bInterfaceProtocol = ATMOSPHERE_INTERFACE_PROTOCOL, + }; + + struct usb_endpoint_descriptor endpoint_descriptor_in = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + .bEndpointAddress = USB_ENDPOINT_IN, + .bmAttributes = USB_TRANSFER_TYPE_BULK, + .wMaxPacketSize = 0x40, + }; + + struct usb_endpoint_descriptor endpoint_descriptor_out = { + .bLength = USB_DT_ENDPOINT_SIZE, + .bDescriptorType = USB_DT_ENDPOINT, + .bEndpointAddress = USB_ENDPOINT_OUT, + .bmAttributes = USB_TRANSFER_TYPE_BULK, + .wMaxPacketSize = 0x40, + }; + + struct usb_ss_endpoint_companion_descriptor endpoint_companion = { + .bLength = sizeof(struct usb_ss_endpoint_companion_descriptor), + .bDescriptorType = USB_DT_SS_ENDPOINT_COMPANION, + .bMaxBurst = 0x0F, + .bmAttributes = 0x00, + .wBytesPerInterval = 0x00, + }; + + rc = usbDsRegisterInterface(&g_interface); + if (R_FAILED(rc)) return rc; + + interface_descriptor.bInterfaceNumber = g_interface->interface_index; + endpoint_descriptor_in.bEndpointAddress += interface_descriptor.bInterfaceNumber + 1; + endpoint_descriptor_out.bEndpointAddress += interface_descriptor.bInterfaceNumber + 1; + + // Full Speed Config + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Full, &interface_descriptor, USB_DT_INTERFACE_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Full, &endpoint_descriptor_in, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Full, &endpoint_descriptor_out, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + + // High Speed Config + endpoint_descriptor_in.wMaxPacketSize = 0x200; + endpoint_descriptor_out.wMaxPacketSize = 0x200; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_High, &interface_descriptor, USB_DT_INTERFACE_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_High, &endpoint_descriptor_in, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_High, &endpoint_descriptor_out, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + + // Super Speed Config + endpoint_descriptor_in.wMaxPacketSize = 0x400; + endpoint_descriptor_out.wMaxPacketSize = 0x400; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Super, &interface_descriptor, USB_DT_INTERFACE_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Super, &endpoint_descriptor_in, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Super, &endpoint_companion, USB_DT_SS_ENDPOINT_COMPANION_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Super, &endpoint_descriptor_out, USB_DT_ENDPOINT_SIZE); + if (R_FAILED(rc)) return rc; + rc = usbDsInterface_AppendConfigurationData(g_interface, UsbDeviceSpeed_Super, &endpoint_companion, USB_DT_SS_ENDPOINT_COMPANION_SIZE); + if (R_FAILED(rc)) return rc; + + //Setup endpoints. + rc = usbDsInterface_RegisterEndpoint(g_interface, &g_endpoint_in, endpoint_descriptor_in.bEndpointAddress); + if (R_FAILED(rc)) return rc; + + rc = usbDsInterface_RegisterEndpoint(g_interface, &g_endpoint_out, endpoint_descriptor_out.bEndpointAddress); + if (R_FAILED(rc)) return rc; + + return rc; +} + + +/* Actual function implementations. */ +TmaConnResult TmaUsbComms::Initialize() { + TmaConnResult res = TmaConnResult::Success; + + if (g_initialized) { + std::abort(); + } + + Result rc = usbDsInitialize(); + + /* Perform interface setup. */ + if (R_SUCCEEDED(rc)) { + if (GetRuntimeFirmwareVersion() >= FirmwareVersion_500) { + rc = _usbCommsInterfaceInit5x(); + } else { + rc = _usbCommsInterfaceInit1x(); + } + } + + /* Start the state change thread. */ + /*if (R_SUCCEEDED(rc)) { + rc = g_state_change_thread.Initialize(&TmaUsbComms::UsbStateChangeThreadFunc, nullptr, 0x4000, 38); + if (R_SUCCEEDED(rc)) { + rc = g_state_change_thread.Start(); + } + }*/ + + /* Enable USB communication. */ + if (R_SUCCEEDED(rc)) { + rc = usbDsInterface_EnableInterface(g_interface); + } + if (R_SUCCEEDED(rc) && GetRuntimeFirmwareVersion() >= FirmwareVersion_500) { + rc = usbDsEnable(); + } + + + if (R_FAILED(rc)) { + /* TODO: Should I not abort here? */ + std::abort(); + + // /* Cleanup, just in case. */ + // TmaUsbComms::Finalize(); + // res = TmaConnResult::Failure; + } + + g_initialized = true; + + return res; +} + +TmaConnResult TmaUsbComms::Finalize() { + Result rc = 0; + /* We must have initialized before calling finalize. */ + if (!g_initialized) { + std::abort(); + } + + /* Kill the state change thread. */ + g_state_change_manager->RequestStop(); + if (R_FAILED(g_state_change_thread.Join())) { + std::abort(); + } + + CancelComms(); + if (R_SUCCEEDED(rc)) { + usbDsExit(); + } + + g_initialized = false; + + return R_SUCCEEDED(rc) ? TmaConnResult::Success : TmaConnResult::ConnectionFailure; +} + +void TmaUsbComms::CancelComms() { + if (!g_initialized) { + return; + } + + usbDsEndpoint_Cancel(g_endpoint_in); + usbDsEndpoint_Cancel(g_endpoint_out); +} + +void TmaUsbComms::SetStateChangeCallback(void (*callback)(void *, u32), void *arg) { + g_state_change_callback = callback; + g_state_change_arg = arg; +} + +Result TmaUsbComms::UsbXfer(UsbDsEndpoint *ep, size_t *out_xferd, void *buf, size_t size) { + Result rc = 0; + u32 urbId = 0; + u32 total_xferd = 0; + UsbDsReportData reportdata; + + if (size) { + /* Start transfer. */ + rc = usbDsEndpoint_PostBufferAsync(ep, buf, size, &urbId); + if (R_FAILED(rc)) return rc; + + /* Wait for transfer to complete. */ + eventWait(&ep->CompletionEvent, U64_MAX); + eventClear(&ep->CompletionEvent); + + rc = usbDsEndpoint_GetReportData(ep, &reportdata); + if (R_FAILED(rc)) return rc; + + rc = usbDsParseReportData(&reportdata, urbId, NULL, &total_xferd); + if (R_FAILED(rc)) return rc; + } + + if (out_xferd) *out_xferd = total_xferd; + + return rc; +} + +TmaConnResult TmaUsbComms::ReceivePacket(TmaPacket *packet) { + std::scoped_lock lk{g_recv_mutex}; + TmaConnResult res = TmaConnResult::Success; + + if (!g_initialized || packet == nullptr) { + return TmaConnResult::GeneralFailure; + } + + /* Read the header. */ + size_t read = 0; + if (R_SUCCEEDED(UsbXfer(g_endpoint_out, &read, g_header_buffer, sizeof(TmaPacket::Header)))) { + packet->CopyHeaderFrom(reinterpret_cast(g_header_buffer)); + } else { + res = TmaConnResult::GeneralFailure; + } + + /* Validate the read header data. */ + if (res == TmaConnResult::Success) { + if (read != sizeof(TmaPacket::Header) || !packet->IsHeaderValid()) { + res = TmaConnResult::GeneralFailure; + } + } + + /* Read the body! */ + if (res == TmaConnResult::Success) { + const u32 body_len = packet->GetBodyLength(); + if (0 < body_len) { + if (body_len <= sizeof(g_recv_data_buf)) { + if (R_SUCCEEDED(UsbXfer(g_endpoint_out, &read, g_recv_data_buf, body_len))) { + if (read == body_len) { + res = packet->CopyBodyFrom(g_recv_data_buf, body_len); + } else { + res = TmaConnResult::GeneralFailure; + } + } + } else { + res = TmaConnResult::GeneralFailure; + } + } + } + + /* Validate the body. */ + if (res == TmaConnResult::Success) { + if (!packet->IsBodyValid()) { + res = TmaConnResult::GeneralFailure; + } + } + + if (res == TmaConnResult::Success) { + packet->ClearOffset(); + } + + return res; +} + +TmaConnResult TmaUsbComms::SendPacket(TmaPacket *packet) { + std::scoped_lock lk{g_send_mutex}; + TmaConnResult res = TmaConnResult::Success; + + if (!g_initialized || packet == nullptr) { + return TmaConnResult::GeneralFailure; + } + + /* Ensure our packets have the correct checksums. */ + packet->SetChecksums(); + + /* Send the packet. */ + size_t written = 0; + const u32 body_len = packet->GetBodyLength(); + if (body_len <= sizeof(g_send_data_buf)) { + /* Copy header to send buffer. */ + packet->CopyHeaderTo(g_send_data_buf); + + /* Send the packet header. */ + if (R_SUCCEEDED(UsbXfer(g_endpoint_in, &written, g_send_data_buf, sizeof(TmaPacket::Header)))) { + if (written == sizeof(TmaPacket::Header)) { + res = TmaConnResult::Success; + } else { + res = TmaConnResult::GeneralFailure; + } + } else { + res = TmaConnResult::GeneralFailure; + } + + if (res == TmaConnResult::Success) { + /* Copy body to send buffer. */ + packet->CopyBodyTo(g_send_data_buf); + + + /* Send the packet body. */ + if (R_SUCCEEDED(UsbXfer(g_endpoint_in, &written, g_send_data_buf, body_len))) { + if (written == body_len) { + res = TmaConnResult::Success; + } else { + res = TmaConnResult::GeneralFailure; + } + } else { + res = TmaConnResult::GeneralFailure; + } + } + } else { + res = TmaConnResult::GeneralFailure; + } + + return res; +} diff --git a/stratosphere/tma/source/tma_usb_comms.hpp b/stratosphere/tma/source/tma_usb_comms.hpp new file mode 100644 index 000000000..9815e7634 --- /dev/null +++ b/stratosphere/tma/source/tma_usb_comms.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2018 Atmosphère-NX + * + * This program is free software; you can redistribute it and/or modify it + * under the terms and conditions of the GNU General Public License, + * version 2, as published by the Free Software Foundation. + * + * This program is distributed in the hope it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for + * more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include +#include + +#include "tma_conn_result.hpp" +#include "tma_conn_packet.hpp" + +class TmaUsbComms { + private: + static void UsbStateChangeThreadFunc(void *arg); + static Result UsbXfer(UsbDsEndpoint *ep, size_t *out_xferd, void *buf, size_t size); + public: + static TmaConnResult Initialize(); + static TmaConnResult Finalize(); + static void CancelComms(); + static TmaConnResult ReceivePacket(TmaPacket *packet); + static TmaConnResult SendPacket(TmaPacket *packet); + + static void SetStateChangeCallback(void (*callback)(void *, u32), void *arg); +}; \ No newline at end of file