diff --git a/Cargo.lock b/Cargo.lock index 00f5318..fa6fbce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -794,6 +794,15 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" +[[package]] +name = "nrpc" +version = "0.6.0" +dependencies = [ + "async-trait", + "bytes", + "prost", +] + [[package]] name = "nrpc" version = "0.6.0" @@ -807,11 +816,9 @@ dependencies = [ [[package]] name = "nrpc-build" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b598ecce0e6d4b2cb367143696174ae24bff5eb4aeb1d8eccffbfeef75fc68e" +version = "0.7.0" dependencies = [ - "nrpc", + "nrpc 0.6.0", "prettyplease 0.2.4", "proc-macro2", "prost-build", @@ -1495,8 +1502,9 @@ dependencies = [ "gettext-ng", "hex", "log", - "nrpc", + "nrpc 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "obfstr", + "prost", "ratchet_rs", "tokio", "usdpl-build", @@ -1535,7 +1543,7 @@ dependencies = [ "gloo-net", "hex", "js-sys", - "nrpc", + "nrpc 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "obfstr", "prost", "usdpl-build", diff --git a/usdpl-back/Cargo.toml b/usdpl-back/Cargo.toml index 26ba550..1037037 100644 --- a/usdpl-back/Cargo.toml +++ b/usdpl-back/Cargo.toml @@ -21,6 +21,7 @@ log = "0.4" # gRPC/protobuf nrpc = "0.6" async-lock = "2.7" +prost = "0.11" # websocket framework ratchet_rs = { version = "0.3", features = [ "deflate" ] } diff --git a/usdpl-back/src/lib.rs b/usdpl-back/src/lib.rs index fc4c052..cb2e8ce 100644 --- a/usdpl-back/src/lib.rs +++ b/usdpl-back/src/lib.rs @@ -43,3 +43,15 @@ pub mod api { pub mod core { pub use usdpl_core::*; } + +/// nrpc re-export +pub mod nrpc { + pub use nrpc::*; +} + +/// nRPC-generated exports +#[allow(missing_docs)] +#[allow(dead_code)] +pub mod services { + include!(concat!(env!("OUT_DIR"), "/mod.rs")); +} diff --git a/usdpl-back/src/websockets.rs b/usdpl-back/src/websockets.rs index 9ad4fc0..9ad22f9 100644 --- a/usdpl-back/src/websockets.rs +++ b/usdpl-back/src/websockets.rs @@ -30,6 +30,12 @@ impl WebsocketServer { &mut self.services } + /// Register a nRPC service for this server to handle + pub fn register(mut self, service: S) -> Self { + self.services.register(service); + self + } + /// Run the web server forever, asynchronously pub async fn run(&self) -> std::io::Result<()> { #[cfg(debug_assertions)] @@ -46,6 +52,15 @@ impl WebsocketServer { Ok(()) } + #[cfg(feature = "blocking")] + /// Run the server forever, blocking the current thread + pub fn run_blocking(self) -> std::io::Result<()> { + let runner = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()?; + runner.block_on(self.run()) + } + async fn connection_handler(services: ServiceRegistry<'static>, stream: TcpStream) -> Result<(), RatchetError> { let upgraded = ratchet_rs::accept_with( stream, diff --git a/usdpl-build/Cargo.toml b/usdpl-build/Cargo.toml index ff252ba..9b1f1ad 100644 --- a/usdpl-build/Cargo.toml +++ b/usdpl-build/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -nrpc-build = "0.6" +nrpc-build = { version = "0.7", path = "../../nRPC/nrpc-build" } prost-build = "0.11" prost-types = "0.11" diff --git a/usdpl-build/src/back/mod.rs b/usdpl-build/src/back/mod.rs index 65d0120..21affc5 100644 --- a/usdpl-build/src/back/mod.rs +++ b/usdpl-build/src/back/mod.rs @@ -1,7 +1,7 @@ pub fn build() { crate::dump_protos_out().unwrap(); nrpc_build::compile_servers( - crate::all_proto_filenames().map(|n| crate::proto_out_path().clone().join(n)), - [crate::proto_out_path()] + crate::all_proto_filenames(crate::proto_builtins_out_path()), + crate::proto_out_paths() ) } diff --git a/usdpl-build/src/front/mod.rs b/usdpl-build/src/front/mod.rs index 698aa9d..37efa4b 100644 --- a/usdpl-build/src/front/mod.rs +++ b/usdpl-build/src/front/mod.rs @@ -11,8 +11,8 @@ pub fn build() { let shared_state = SharedState::new(); crate::dump_protos_out().unwrap(); nrpc_build::Transpiler::new( - crate::all_proto_filenames().map(|n| crate::proto_out_path().clone().join(n)), - [crate::proto_out_path()] + crate::all_proto_filenames(crate::proto_builtins_out_path()), + crate::proto_out_paths() ).unwrap() .generate_client() .with_preprocessor(nrpc_build::AbstractImpl::outer(WasmProtoPreprocessor::with_state(&shared_state))) diff --git a/usdpl-build/src/front/service_generator.rs b/usdpl-build/src/front/service_generator.rs index 173a757..9c53698 100644 --- a/usdpl-build/src/front/service_generator.rs +++ b/usdpl-build/src/front/service_generator.rs @@ -31,7 +31,7 @@ fn generate_service_methods(service: &Service, fds: &FileDescriptorSet) -> proc_ let mut params_to_fields = Vec::with_capacity(input_type.field.len()); for field in &input_type.field { //let param_name = quote::format_ident!("val{}", i.to_string()); - let type_name = translate_type(field, &service.name); + let type_name = ProtobufType::from_field(field, &service.name).to_tokens(); let field_name = quote::format_ident!("{}", field.name.as_ref().expect("Protobuf message field needs a name")); input_params.push(quote::quote!{ #field_name: #type_name, @@ -125,15 +125,6 @@ fn find_field<'a>(want_field: &str, descriptor: &'a DescriptorProto) -> Option<& None } -fn translate_type(field: &FieldDescriptorProto, service: &str) -> proc_macro2::TokenStream { - if let Some(type_name) = &field.type_name { - translate_type_name(type_name, service) - } else { - let number = field.r#type.unwrap(); - translate_type_known(number) - } -} - fn generate_wasm_struct_interop(descriptor: &DescriptorProto, handled_enums: &mut HashSet, handled_types: &mut HashSet, is_response_msg: bool, service: &str) -> proc_macro2::TokenStream { let msg_name = quote::format_ident!("{}{}", service, descriptor.name.as_ref().expect("Protobuf message needs a name")); let super_msg_name = quote::format_ident!("{}", descriptor.name.as_ref().expect("Protobuf message needs a name")); @@ -153,35 +144,107 @@ fn generate_wasm_struct_interop(descriptor: &DescriptorProto, handled_enums: &mu let special_fn_from = quote::format_ident!("{}_convert_from", name.split('.').last().unwrap().to_lowercase()); let special_fn_into = quote::format_ident!("{}_convert_into", name.split('.').last().unwrap().to_lowercase()); let key_field = find_field("key", descriptor).expect("Protobuf map entry has no key field"); - let key_type = translate_type(&key_field, service); + let key_type = ProtobufType::from_field(&key_field, service); let value_field = find_field("value", descriptor).expect("Protobuf map entry has no value field"); - let value_type = translate_type(&value_field, service); + let value_type = ProtobufType::from_field(&value_field, service); + + let key_type_tokens = key_type.to_tokens(); + let value_type_tokens = value_type.to_tokens(); + + let (fn_from, fn_into) = match (key_type, value_type) { + (ProtobufType::String, ProtobufType::String) => ( + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_from(other: ::std::collections::HashMap<#key_type_tokens, #value_type_tokens>) -> #msg_name { + let map = #msg_name::new(); + for (key, val) in other.iter() { + map.set(&key.into(), &val.into()); + } + map + } + }, + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_into(this: #msg_name) -> ::std::collections::HashMap<#key_type_tokens, #value_type_tokens> { + let mut output = ::std::collections::HashMap::<#key_type_tokens, #value_type_tokens>::new(); + this.for_each(&mut |key: ::wasm_bindgen::JsValue, val: ::wasm_bindgen::JsValue| { + if let Some(key) = key.as_string() { + if let Some(val) = val.as_string() { + output.insert(key, val); + } + } + }); + output + } + } + ), + (ProtobufType::String, ProtobufType::Double | ProtobufType::Float | ProtobufType::Int32| ProtobufType::Int64| ProtobufType::Uint32| ProtobufType::Uint64| ProtobufType::Sint32| ProtobufType::Sint64| ProtobufType::Fixed32| ProtobufType::Fixed64| ProtobufType::Sfixed32| ProtobufType::Sfixed64) => ( + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_from(other: ::std::collections::HashMap<#key_type_tokens, #value_type_tokens>) -> #msg_name { + let map = #msg_name::new(); + for (key, val) in other.iter() { + map.set(&key.into(), &(val as f64).into()); + } + map + } + }, + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_into(this: #msg_name) -> ::std::collections::HashMap<#key_type_tokens, #value_type_tokens> { + let mut output = ::std::collections::HashMap::<#key_type_tokens, #value_type_tokens>::new(); + this.for_each(&mut |key: ::wasm_bindgen::JsValue, val: ::wasm_bindgen::JsValue| { + if let Some(key) = key.as_string() { + if let Some(val) = val.as_f64() { + output.insert(key, val as _); + } + } + }); + output + } + } + ), + (ProtobufType::String, ProtobufType::Bool) => ( + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_from(other: ::std::collections::HashMap<#key_type_tokens, #value_type_tokens>) -> #msg_name { + let map = #msg_name::new(); + for (key, val) in other.iter() { + map.set(&key.into(), &(val as f64).into()); + } + map + } + }, + quote::quote!{ + #[inline] + #[allow(dead_code)] + fn #special_fn_into(this: #msg_name) -> ::std::collections::HashMap<#key_type_tokens, #value_type_tokens> { + let mut output = ::std::collections::HashMap::<#key_type_tokens, #value_type_tokens>::new(); + this.for_each(&mut |key: ::wasm_bindgen::JsValue, val: ::wasm_bindgen::JsValue| { + if let Some(key) = key.as_string() { + if let Some(val) = val.as_bool() { + output.insert(key, val); + } + } + }); + output + } + } + ), + (key_type, value_type) => panic!("Unsupported map type map<{:?}, {:?}>", key_type, value_type), + }; + return quote::quote!{ pub type #msg_name = ::js_sys::Map; - #[inline] - #[allow(dead_code)] - fn #special_fn_from(other: ::std::collections::HashMap<#key_type, #value_type>) -> #msg_name { - let map = #msg_name::new(); - for (key, val) in other.iter() { - map.set(&key.into(), &val.into()); - } - map - } + #fn_from - #[inline] - #[allow(dead_code)] - fn #special_fn_into(this: #msg_name) -> ::std::collections::HashMap<#key_type, #value_type> { - let mut output = ::std::collections::HashMap::<#key_type, #value_type>::new(); - this.for_each(&mut |key: ::wasm_bindgen::JsValue, val: ::wasm_bindgen::JsValue| { - if let Some(key) = key.as_string() { - if let Some(val) = val.as_string() { - output.insert(key, val); - } - } - }); - output - } + #fn_into } } } else { @@ -207,7 +270,7 @@ fn generate_wasm_struct_interop(descriptor: &DescriptorProto, handled_enums: &mu if descriptor.field.len() == 1 { let field = &descriptor.field[0]; let field_name = quote::format_ident!("{}", field.name.as_ref().expect("Protobuf message field needs a name")); - let type_name = translate_type(field, service); + let type_name = ProtobufType::from_field(field, service).to_tokens(); gen_fields.push(quote::quote!{ pub #field_name: #type_name, }); @@ -267,7 +330,7 @@ fn generate_wasm_struct_interop(descriptor: &DescriptorProto, handled_enums: &mu } else { for field in &descriptor.field { let field_name = quote::format_ident!("{}", field.name.as_ref().expect("Protobuf message field needs a name")); - let type_name = translate_type(field, service); + let type_name = ProtobufType::from_field(field, service).to_tokens(); gen_fields.push(quote::quote!{ pub #field_name: #type_name, }); @@ -357,51 +420,100 @@ fn generate_wasm_struct_interop(descriptor: &DescriptorProto, handled_enums: &mu } -fn translate_type_name(name: &str, service: &str) -> proc_macro2::TokenStream { - match name { - "double" => quote::quote!{f64}, - "float" => quote::quote!{f32}, - "int32" => quote::quote!{i32}, - "int64" => quote::quote!{i64}, - "uint32" => quote::quote!{u32}, - "uint64" => quote::quote!{u64}, - "sint32" => quote::quote!{i32}, - "sint64" => quote::quote!{i64}, - "fixed32" => quote::quote!{u32}, - "fixed64" => quote::quote!{u64}, - "sfixed32" => quote::quote!{i32}, - "sfixed64" => quote::quote!{i64}, - "bool" => quote::quote!{bool}, - "string" => quote::quote!{String}, - "bytes" => quote::quote!{Vec}, - t => { - let ident = quote::format_ident!("{}{}", service, t.split('.').last().unwrap()); - quote::quote!{#ident} - }, - } +#[derive(Debug)] +enum ProtobufType { + Double, + Float, + Int32, + Int64, + Uint32, + Uint64, + Sint32, + Sint64, + Fixed32, + Fixed64, + Sfixed32, + Sfixed64, + Bool, + String, + Bytes, + Custom(String), } -fn translate_type_known(id: i32) -> proc_macro2::TokenStream { - match id { - //"double" => quote::quote!{f64}, - //"float" => quote::quote!{f32}, - //"int32" => quote::quote!{i32}, - //"int64" => quote::quote!{i64}, - //"uint32" => quote::quote!{u32}, - //"uint64" => quote::quote!{u64}, - //"sint32" => quote::quote!{i32}, - //"sint64" => quote::quote!{i64}, - //"fixed32" => quote::quote!{u32}, - //"fixed64" => quote::quote!{u64}, - //"sfixed32" => quote::quote!{i32}, - //"sfixed64" => quote::quote!{i64}, - //"bool" => quote::quote!{bool}, - 9 => quote::quote!{String}, - //"bytes" => quote::quote!{Vec}, - t => { - let ident = quote::format_ident!("UnknownType{}", t.to_string()); - quote::quote!{#ident} - }, +impl ProtobufType { + fn from_str(type_name: &str, service: &str) -> Self { + match type_name { + "double" => Self::Double, + "float" => Self::Float, + "int32" => Self::Int32, + "int64" => Self::Int64, + "uint32" => Self::Uint32, + "uint64" => Self::Uint64, + "sint32" => Self::Sint32, + "sint64" => Self::Sint64, + "fixed32" => Self::Fixed32, + "fixed64" => Self::Fixed64, + "sfixed32" => Self::Sfixed32, + "sfixed64" => Self::Sfixed64, + "bool" => Self::Bool, + "string" => Self::String, + "bytes" => Self::Bytes, + t => Self::Custom(format!("{}{}", service, t.split('.').last().unwrap())), + } + } + + fn from_id(id: i32) -> Self { + match id { + //"double" => quote::quote!{f64}, + //"float" => quote::quote!{f32}, + //"int32" => quote::quote!{i32}, + //"int64" => quote::quote!{i64}, + //"uint32" => quote::quote!{u32}, + //"uint64" => quote::quote!{u64}, + //"sint32" => quote::quote!{i32}, + //"sint64" => quote::quote!{i64}, + //"fixed32" => quote::quote!{u32}, + //"fixed64" => quote::quote!{u64}, + //"sfixed32" => quote::quote!{i32}, + //"sfixed64" => quote::quote!{i64}, + //"bool" => quote::quote!{bool}, + 9 => Self::String, + //"bytes" => quote::quote!{Vec}, + t => Self::Custom(format!("UnknownType{}", t)), + } + } + + fn from_field(field: &FieldDescriptorProto, service: &str) -> Self { + if let Some(type_name) = &field.type_name { + Self::from_str(type_name, service) + } else { + let number = field.r#type.unwrap(); + Self::from_id(number) + } + } + + fn to_tokens(&self) -> proc_macro2::TokenStream { + match self { + Self::Double => quote::quote!{f64}, + Self::Float => quote::quote!{f32}, + Self::Int32 => quote::quote!{i32}, + Self::Int64 => quote::quote!{i64}, + Self::Uint32 => quote::quote!{u32}, + Self::Uint64 => quote::quote!{u64}, + Self::Sint32 => quote::quote!{i32}, + Self::Sint64 => quote::quote!{i64}, + Self::Fixed32 => quote::quote!{u32}, + Self::Fixed64 => quote::quote!{u64}, + Self::Sfixed32 => quote::quote!{i32}, + Self::Sfixed64 => quote::quote!{i64}, + Self::Bool => quote::quote!{bool}, + Self::String => quote::quote!{String}, + Self::Bytes => quote::quote!{Vec}, + Self::Custom(t) => { + let ident = quote::format_ident!("{}", t); + quote::quote!{#ident} + }, + } } } diff --git a/usdpl-build/src/lib.rs b/usdpl-build/src/lib.rs index 31f742f..d543e54 100644 --- a/usdpl-build/src/lib.rs +++ b/usdpl-build/src/lib.rs @@ -2,4 +2,4 @@ pub mod back; pub mod front; mod proto_files; -pub use proto_files::{dump_protos, dump_protos_out, proto_out_path, all_proto_filenames}; +pub use proto_files::{dump_protos, dump_protos_out, proto_out_paths, all_proto_filenames, proto_builtins_out_path}; diff --git a/usdpl-build/src/proto_files.rs b/usdpl-build/src/proto_files.rs index 818c728..58e4bed 100644 --- a/usdpl-build/src/proto_files.rs +++ b/usdpl-build/src/proto_files.rs @@ -5,6 +5,8 @@ struct IncludedFileStr<'a> { contents: &'a str, } +const ADDITIONAL_PROTOBUFS_ENV_VAR: &'static str = "USDPL_PROTOS_PATH"; + const DEBUG_PROTO: IncludedFileStr<'static> = IncludedFileStr { filename: "debug.proto", contents: include_str!("../protos/debug.proto"), @@ -20,12 +22,40 @@ const ALL_PROTOS: [IncludedFileStr<'static>; 2] = [ TRANSLATIONS_PROTO, ]; -pub fn proto_out_path() -> PathBuf { +pub fn proto_builtins_out_path() -> PathBuf { PathBuf::from(std::env::var("OUT_DIR").expect("Not in a build.rs context (missing $OUT_DIR)")).join("protos") } -pub fn all_proto_filenames() -> impl Iterator { - ALL_PROTOS.iter().map(|x| x.filename) +pub fn proto_out_paths() -> impl Iterator { + std::iter::once(proto_builtins_out_path()) + .map(|x| x.to_str().unwrap().to_owned()) + .chain(custom_protos_dirs().into_iter()) +} + +fn custom_protos_dirs() -> Vec { + let dirs = std::env::var(ADDITIONAL_PROTOBUFS_ENV_VAR).unwrap_or_else(|_| "".to_owned()); + dirs.split(':') + .filter(|x| std::fs::read_dir(x).is_ok()) + .map(|x| x.to_owned()) + .collect() +} + +fn custom_protos_filenames() -> Vec { + let dirs = std::env::var(ADDITIONAL_PROTOBUFS_ENV_VAR).unwrap_or_else(|_| "".to_owned()); + dirs.split(':') + .map(std::fs::read_dir) + .filter(|x| x.is_ok()) + .flat_map(|x| x.unwrap()) + .filter(|x| x.is_ok()) + .map(|x| x.unwrap().path()) + .filter(|x| if let Some(ext) = x.extension() { ext.to_ascii_lowercase() == "proto" && x.is_file() } else { false }) + .filter_map(|x| x.to_str().map(|x| x.to_owned())) + .collect() +} + +pub fn all_proto_filenames(p: impl AsRef + 'static) -> impl Iterator { + //let p = p.as_ref(); + ALL_PROTOS.iter().map(move |x| p.as_ref().join(x.filename).to_str().unwrap().to_owned()).chain(custom_protos_filenames()) } pub fn dump_protos(p: impl AsRef) -> std::io::Result<()> { @@ -38,7 +68,7 @@ pub fn dump_protos(p: impl AsRef) -> std::io::Result<()> { } pub fn dump_protos_out() -> std::io::Result<()> { - let path = proto_out_path(); + let path = proto_builtins_out_path(); std::fs::create_dir_all(&path)?; dump_protos(&path) }