From 72c7f111e8d6fe82fd238eb7dd2e9b9018876f6b Mon Sep 17 00:00:00 2001 From: "NGnius (Graham)" Date: Tue, 29 Aug 2023 19:49:17 -0400 Subject: [PATCH] Fix some generation bugs to get Fantastic to compile --- Cargo.lock | 16 +++--- usdpl-back/Cargo.toml | 2 +- usdpl-back/src/rpc/mod.rs | 5 +- usdpl-back/src/rpc/registry.rs | 19 +++--- usdpl-back/src/rpc/websocket_stream.rs | 34 +++++++++++ usdpl-back/src/websockets.rs | 49 ++++++++++++---- usdpl-build/src/front/service_generator.rs | 67 ++++++++++++++++------ usdpl-front/src/wasm/streaming.rs | 5 +- 8 files changed, 148 insertions(+), 49 deletions(-) create mode 100644 usdpl-back/src/rpc/websocket_stream.rs diff --git a/Cargo.lock b/Cargo.lock index b8c205e..57d6178 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1152,9 +1152,9 @@ dependencies = [ [[package]] name = "ratchet_core" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "854bf6632d9f5c7fa7f77cbc332f2b0a8dfb2acc36c3f351fc36bf40f2759728" +checksum = "faed301a9f297e8cd3617a2bc79ed17eefa88d5873ed08517c96628b48d1f386" dependencies = [ "base64", "bitflags", @@ -1176,9 +1176,9 @@ dependencies = [ [[package]] name = "ratchet_deflate" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0b144cb23a76d810b25737f4b87943fdfd7772b423bdc15c2b3820849207adc" +checksum = "77238362df52f64482e0bd1c413d2d3d0e20052056ba4d88918ef2e962c86f11" dependencies = [ "bytes", "flate2", @@ -1190,9 +1190,9 @@ dependencies = [ [[package]] name = "ratchet_ext" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67f97bb0776d195720319a1e9f08fa343fe3f9f0b7ebf9d97d5926ce50b8e1ad" +checksum = "35f5bf3bd015a94b77730229e895e03af945627984ee5c4f95d40fd9227ea36b" dependencies = [ "bytes", "http", @@ -1201,9 +1201,9 @@ dependencies = [ [[package]] name = "ratchet_rs" -version = "0.3.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7dba456fc23026b46ce0936d109ce3e73b4a592baf0dda0f83d49886c5e5f83" +checksum = "62d326d7cd4227a7f58b36c1efb16b348f7e2e43e1d1ef032e9b094ff6cec583" dependencies = [ "ratchet_core", "ratchet_deflate", diff --git a/usdpl-back/Cargo.toml b/usdpl-back/Cargo.toml index ff6e69d..872c61d 100644 --- a/usdpl-back/Cargo.toml +++ b/usdpl-back/Cargo.toml @@ -24,7 +24,7 @@ async-lock = "2.7" prost = "0.11" # websocket framework -ratchet_rs = { version = "0.3", features = [ "deflate" ] } +ratchet_rs = { version = "0.4", features = [ "deflate" ] } # HTTP web framework #warp = { version = "0.3" } diff --git a/usdpl-back/src/rpc/mod.rs b/usdpl-back/src/rpc/mod.rs index 6dddc35..506bfd6 100644 --- a/usdpl-back/src/rpc/mod.rs +++ b/usdpl-back/src/rpc/mod.rs @@ -1,2 +1,5 @@ mod registry; -pub use registry::ServiceRegistry; +pub use registry::{ServiceRegistry, StaticServiceRegistry}; + +mod websocket_stream; +pub use websocket_stream::ws_stream; diff --git a/usdpl-back/src/rpc/registry.rs b/usdpl-back/src/rpc/registry.rs index 0560696..26cd152 100644 --- a/usdpl-back/src/rpc/registry.rs +++ b/usdpl-back/src/rpc/registry.rs @@ -2,11 +2,13 @@ use async_lock::Mutex; use std::collections::HashMap; use std::sync::Arc; -use nrpc::{ServerService, ServiceError}; +use nrpc::{ServerService, ServiceError, ServiceServerStream}; + +pub type StaticServiceRegistry = ServiceRegistry<'static>; #[derive(Default, Clone)] pub struct ServiceRegistry<'a> { - entries: HashMap>>>, + entries: HashMap + Send + 'a>>>>, } impl<'a> ServiceRegistry<'a> { @@ -19,23 +21,22 @@ impl<'a> ServiceRegistry<'a> { format!("{}.{}", package, service) }*/ - pub async fn call_descriptor( - &self, + pub async fn call_descriptor<'b: 'a>( + &mut self, descriptor: &str, method: &str, - data: bytes::Bytes, - ) -> Result { + input: ServiceServerStream<'a, bytes::Bytes>, + ) -> Result, ServiceError> { if let Some(service) = self.entries.get(descriptor) { - let mut output = bytes::BytesMut::new(); let mut service_lock = service.lock_arc().await; - service_lock.call(method, data, &mut output).await?; + let output = service_lock.call(method, input).await?; Ok(output.into()) } else { Err(ServiceError::ServiceNotFound) } } - pub fn register(&mut self, service: S) -> &mut Self { + pub fn register + Send + 'a>(&mut self, service: S) -> &mut Self { let key = service.descriptor().to_owned(); self.entries .insert(key, Arc::new(Mutex::new(Box::new(service)))); diff --git a/usdpl-back/src/rpc/websocket_stream.rs b/usdpl-back/src/rpc/websocket_stream.rs new file mode 100644 index 0000000..ea44c00 --- /dev/null +++ b/usdpl-back/src/rpc/websocket_stream.rs @@ -0,0 +1,34 @@ +use core::marker::Unpin; +use std::sync::Arc; + +use tokio::{net::TcpStream, sync::Mutex}; +use ratchet_rs::{WebSocket, Message, Error as RatchetError, Extension}; + +use nrpc::ServiceError; +use nrpc::_helpers::futures::Stream; +use nrpc::_helpers::bytes::{BytesMut, Bytes}; + +struct WsStreamState{ + ws: Arc>>, + buf: BytesMut, +} + +pub fn ws_stream<'a, T: Extension + Unpin + 'a>(ws: Arc>>) -> impl Stream> + 'a { + nrpc::_helpers::futures::stream::unfold(WsStreamState { ws, buf: BytesMut::new() }, |mut state| async move { + let mut locked_ws = state.ws.lock().await; + if locked_ws.is_closed() || !locked_ws.is_active() { + None + } else { + let result = locked_ws.read(&mut state.buf).await; + drop(locked_ws); + match result { + Ok(Message::Binary) => Some((Ok(state.buf.clone().freeze()), state)), + Ok(_) => Some((Err(ServiceError::Method(Box::new(RatchetError::with_cause( + ratchet_rs::ErrorKind::Protocol, + "Websocket text messages are not accepted", + )))), state)), + Err(e) => Some((Err(ServiceError::Method(Box::new(e))), state)) + } + } + }) +} diff --git a/usdpl-back/src/websockets.rs b/usdpl-back/src/websockets.rs index 9b30695..9658e2c 100644 --- a/usdpl-back/src/websockets.rs +++ b/usdpl-back/src/websockets.rs @@ -1,9 +1,10 @@ -use bytes::BytesMut; use ratchet_rs::deflate::DeflateExtProvider; -use ratchet_rs::{Error as RatchetError, Message, ProtocolRegistry, WebSocketConfig}; +use ratchet_rs::{Error as RatchetError, ProtocolRegistry, WebSocketConfig}; use tokio::net::{TcpListener, TcpStream}; -use crate::rpc::ServiceRegistry; +use nrpc::_helpers::futures::StreamExt; + +use crate::rpc::StaticServiceRegistry; struct MethodDescriptor<'a> { service: &'a str, @@ -12,7 +13,7 @@ struct MethodDescriptor<'a> { /// Handler for communication to and from the front-end pub struct WebsocketServer { - services: ServiceRegistry<'static>, + services: StaticServiceRegistry, port: u16, } @@ -20,18 +21,18 @@ impl WebsocketServer { /// Initialise an instance of the back-end websocket server pub fn new(port_usdpl: u16) -> Self { Self { - services: ServiceRegistry::new(), + services: StaticServiceRegistry::new(), port: port_usdpl, } } /// Get the service registry that the server handles - pub fn registry(&mut self) -> &'_ mut ServiceRegistry<'static> { + pub fn registry(&mut self) -> &'_ mut StaticServiceRegistry { &mut self.services } /// Register a nRPC service for this server to handle - pub fn register(mut self, service: S) -> Self { + pub fn register + Send + 'static>(mut self, service: S) -> Self { self.services.register(service); self } @@ -62,7 +63,7 @@ impl WebsocketServer { } async fn connection_handler( - services: ServiceRegistry<'static>, + mut services: StaticServiceRegistry, stream: TcpStream, ) -> Result<(), RatchetError> { log::debug!("connection_handler invoked!"); @@ -80,12 +81,38 @@ impl WebsocketServer { log::debug!("accepted new connection on uri {}", request_path); - let mut websocket = upgraded.websocket; + let websocket = std::sync::Arc::new(tokio::sync::Mutex::new(upgraded.websocket)); let descriptor = Self::parse_uri_path(request_path) .map_err(|e| RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e))?; - let mut buf = BytesMut::new(); + let input_stream = Box::new(nrpc::_helpers::futures::stream::StreamExt::boxed(crate::rpc::ws_stream(websocket.clone()))); + let output_stream = services + .call_descriptor( + descriptor.service, + descriptor.method, + input_stream, + ) + .await + .map_err(|e| { + RatchetError::with_cause(ratchet_rs::ErrorKind::Protocol, e.to_string()) + })?; + + output_stream.for_each_concurrent(None, |result| async { + match result { + Ok(msg) => { + let mut ws_lock = websocket.lock().await; + if let Err(e) = ws_lock.write_binary(msg).await { + log::error!("websocket error while writing response on uri {}: {}", request_path, e); + } + }, + Err(e) => { + log::error!("service error while writing response on uri {}: {}", request_path, e); + } + } + }).await; + + /*let mut buf = BytesMut::new(); loop { match websocket.read(&mut buf).await? { Message::Text => { @@ -113,7 +140,7 @@ impl WebsocketServer { Message::Pong(_) => {} Message::Close(_) => break, } - } + }*/ log::debug!("ws connection {} closed", request_path); Ok(()) } diff --git a/usdpl-build/src/front/service_generator.rs b/usdpl-build/src/front/service_generator.rs index 00189f4..5391926 100644 --- a/usdpl-build/src/front/service_generator.rs +++ b/usdpl-build/src/front/service_generator.rs @@ -436,8 +436,26 @@ fn generate_wasm_struct_interop( let type_name = type_enum.to_tokens(); let wasm_type_name = type_enum.to_wasm_tokens(); - let into_wasm_streamable = quote::quote!{self.into_wasm_streamable()}; - let from_wasm_streamable = quote::quote!{#type_name::from_wasm_streamable(js)}; + /*let wasm_streamable_impl = if type_enum.is_already_wasm_streamable() { + quote::quote!{} + } else { + let into_wasm_streamable = quote::quote!{self.into_wasm_streamable()}; + let from_wasm_streamable = quote::quote!{#type_name::from_wasm_streamable(js)}; + quote::quote!{ + impl ::usdpl_front::wasm::FromWasmStreamableType for #msg_name { + fn from_wasm_streamable(js: JsValue) -> Result { + #from_wasm_streamable + } + } + + impl ::usdpl_front::wasm::IntoWasmStreamableType for #msg_name { + fn into_wasm_streamable(self) -> JsValue { + #into_wasm_streamable + } + } + } + };*/ + quote::quote! { pub type #msg_name = #type_name; @@ -460,17 +478,7 @@ fn generate_wasm_struct_interop( } } - impl ::usdpl_front::wasm::FromWasmStreamableType for #msg_name { - fn from_wasm_streamable(js: JsValue) -> Result { - #from_wasm_streamable - } - } - - impl ::usdpl_front::wasm::IntoWasmStreamableType for #msg_name { - fn into_wasm_streamable(self) -> JsValue { - #into_wasm_streamable - } - } + // #wasm_streamable_impl #(#gen_nested_types)* @@ -733,15 +741,19 @@ impl ProtobufType { fn to_into_wasm_streamable(&self, field_name: &str, js_map_name: &syn::Ident) -> proc_macro2::TokenStream { //let type_tokens = self.to_tokens(); - //let field_ident = quote::format_ident!("{}", field_name); - quote::quote!{#js_map_name.set(#field_name.into(), self.field_ident);} + let field_ident = quote::format_ident!("{}", field_name); + quote::quote!{#js_map_name.set(&JsValue::from(#field_name), &self.#field_ident.into_wasm_streamable());} } fn to_from_wasm_streamable(&self, field_name: &str, js_map_name: &syn::Ident) -> proc_macro2::TokenStream { let type_tokens = self.to_tokens(); - //let field_ident = quote::format_ident!("{}", field_name); - quote::quote!{#field_name: #type_tokens::from_wasm_streamable(#js_map_name.get(#field_name.into()))?,} + let field_ident = quote::format_ident!("{}", field_name); + quote::quote!{#field_ident: #type_tokens::from_wasm_streamable(#js_map_name.get(&JsValue::from(#field_name)))?,} } + + /*fn is_already_wasm_streamable(&self) -> bool { + !matches!(self, Self::Custom(_)) + }*/ } fn generate_wasm_enum_interop( @@ -870,6 +882,25 @@ fn generate_wasm_enum_interop( self as i32 } } + + impl ::usdpl_front::wasm::FromWasmStreamableType for #enum_name { + fn from_wasm_streamable(js: JsValue) -> Result { + if let Some(float) = js.as_f64() { + Ok(Self::from_wasm(float as i32)) + } else { + Err(::usdpl_front::wasm::WasmStreamableConversionError::UnexpectedType { + expected: ::usdpl_front::wasm::JsType::Number, + got: ::usdpl_front::wasm::JsType::guess(&js), + }) + } + } + } + + impl ::usdpl_front::wasm::IntoWasmStreamableType for #enum_name { + fn into_wasm_streamable(self) -> JsValue { + JsValue::from(self.into_wasm()) + } + } } } @@ -1021,7 +1052,7 @@ impl IServiceGenerator for WasmServiceGenerator { #[wasm_bindgen] pub struct #service_js_name { //#[wasm_bindgen(skip)] - service: super::#service_struct_name, + service: super::#service_struct_name<'static, WebSocketHandler>, } #[wasm_bindgen] diff --git a/usdpl-front/src/wasm/streaming.rs b/usdpl-front/src/wasm/streaming.rs index 9eea695..6c18ff8 100644 --- a/usdpl-front/src/wasm/streaming.rs +++ b/usdpl-front/src/wasm/streaming.rs @@ -64,7 +64,10 @@ impl core::fmt::Display for JsType { } impl JsType { - fn guess(js: &JsValue) -> JsType { + /// Guess the JS type of the parameter. + /// This is not guaranteed to be correct, but is intended to give more information + /// in debug and error messages + pub fn guess(js: &JsValue) -> JsType { if js.as_f64().is_some() { Self::Number } else if js.as_string().is_some() {