From 965df808908b67ce2dea918295e38c9247698d3d Mon Sep 17 00:00:00 2001 From: "NGnius (Graham)" Date: Sat, 1 Jul 2023 17:09:30 -0400 Subject: [PATCH] Add gRPC streaming support --- Cargo.lock | 109 ++++++++++- nrpc-build/Cargo.toml | 4 +- nrpc-build/src/service_gen.rs | 227 ++++++++++++++++++++--- nrpc-codegen-test/proto/helloworld.proto | 9 + nrpc-codegen-test/src/main.rs | 180 +++++++++++++++++- nrpc/Cargo.toml | 3 +- nrpc/src/lib.rs | 6 +- nrpc/src/service.rs | 30 +-- nrpc/src/stream_utils.rs | 80 ++++++++ 9 files changed, 593 insertions(+), 55 deletions(-) create mode 100644 nrpc/src/stream_utils.rs diff --git a/Cargo.lock b/Cargo.lock index 4657a5e..b5d0f5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,6 +133,95 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-macro" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.22", +] + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "gimli" version = "0.27.3" @@ -328,16 +417,17 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" [[package]] name = "nrpc" -version = "0.7.0" +version = "0.8.0" dependencies = [ "async-trait", "bytes", + "futures", "prost", ] [[package]] name = "nrpc-build" -version = "0.7.0" +version = "0.8.0" dependencies = [ "nrpc", "prettyplease 0.2.9", @@ -425,6 +515,12 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "prettyplease" version = "0.1.25" @@ -622,6 +718,15 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" diff --git a/nrpc-build/Cargo.toml b/nrpc-build/Cargo.toml index 61a8a10..d104aca 100644 --- a/nrpc-build/Cargo.toml +++ b/nrpc-build/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nrpc-build" -version = "0.7.0" +version = "0.8.0" edition = "2021" license = "Apache-2.0" repository = "https://github.com/NGnius/nRPC" @@ -21,4 +21,4 @@ quote = "1.0" syn = "2.0" proc-macro2 = "1.0" -nrpc = { version = "0.7", path = "../nrpc" } +nrpc = { version = "0.8", path = "../nrpc" } diff --git a/nrpc-build/src/service_gen.rs b/nrpc-build/src/service_gen.rs index 92af979..153aedb 100644 --- a/nrpc-build/src/service_gen.rs +++ b/nrpc-build/src/service_gen.rs @@ -47,40 +47,131 @@ impl ProtobufServiceGenerator { } } +fn stream_type(item_type: &syn::Ident) -> proc_macro2::TokenStream { + quote::quote!{ + ::nrpc::ServiceStream<'a, #item_type> + } +} + +/*fn stream_type_static_lifetime(item_type: &syn::Ident) -> proc_macro2::TokenStream { + quote::quote!{ + ::nrpc::ServiceStream<'static, #item_type> + } +}*/ + fn trait_methods_server(descriptors: &Vec) -> proc_macro2::TokenStream { let mut gen_methods = Vec::with_capacity(descriptors.len()); let mut gen_method_match_arms = Vec::with_capacity(descriptors.len()); for descriptor in descriptors { + let input_ty = quote::format_ident!("{}", descriptor.input_type); + let output_ty = quote::format_ident!("{}", descriptor.output_type); + let fn_name = quote::format_ident!("{}", descriptor.name); + let method_name = &descriptor.name; match (descriptor.client_streaming, descriptor.server_streaming) { (false, false) => { // no streaming; 1->1 - let input_ty = quote::format_ident!("{}", descriptor.input_type); - let output_ty = quote::format_ident!("{}", descriptor.output_type); - let fn_name = quote::format_ident!("{}", descriptor.name); - let method_name = &descriptor.name; gen_methods.push( quote! { - async fn #fn_name(&mut self, input: #input_ty) -> Result<#output_ty, Box>; + async fn #fn_name(&mut self, input: #input_ty) -> Result<#output_ty, Box>; } ); gen_method_match_arms.push(quote! { #method_name => { - Ok(self.#fn_name(#input_ty::decode(payload)?).await?.encode(buffer)?) + if let Some(item1_payload) = stream_in.next().await { + let item = #input_ty::decode(item1_payload?)?; + // TODO does it need to be enforced that there are no more items in the stream? + let mut buffer = ::nrpc::_helpers::bytes::BytesMut::new(); + self.#fn_name(item).await?.encode(&mut buffer)?; + Ok(Box::new(::nrpc::OnceStream::once(Ok(buffer.freeze())))) + } else { + Err(::nrpc::ServiceError::StreamLength { want: 1, got: 0 }) + } + } + }); + } + (false, true) => { + // client streaming; 1 -> many + //let stream_out_ty = stream_type_static_lifetime(&output_ty); + let stream_out_ty = stream_type(&output_ty); + gen_methods.push( + quote! { + async fn #fn_name<'a>(&mut self, input: #input_ty) -> Result<#stream_out_ty, Box>; + } + ); + + gen_method_match_arms.push(quote! { + #method_name => { + if let Some(item1_payload) = stream_in.next().await { + let item = #input_ty::decode(item1_payload?)?; + // TODO does it need to be enforced that there are no more items in the stream? + let result = self.#fn_name(item).await?; + Ok(Box::new( + result.map( + |item_result| item_result.and_then(|item| { + let mut buffer = ::nrpc::_helpers::bytes::BytesMut::new(); + item.encode(&mut buffer) + .map(|_| buffer.freeze()) + .map_err(|e| ::nrpc::ServiceError::from(e)) + }) + ) + )) + } else { + Err(::nrpc::ServiceError::StreamLength { want: 1, got: 0 }) + } } }); } (true, false) => { - // client streaming; 1 -> many - todo!("streaming not supported") - } - (false, true) => { // server streaming; many -> 1 - todo!("streaming not supported") + let stream_in_ty = stream_type(&input_ty); + gen_methods.push( + quote! { + async fn #fn_name<'a>(&mut self, input: #stream_in_ty) -> Result<#output_ty, Box>; + } + ); + + gen_method_match_arms.push(quote! { + #method_name => { + let item_stream = stream_in.map(|item_result| item_result.and_then(|item1_payload| { + #input_ty::decode(item1_payload) + .map_err(|e| ::nrpc::ServiceError::from(e)) + })); + let mut buffer = ::nrpc::_helpers::bytes::BytesMut::new(); + self.#fn_name(Box::new(item_stream)).await?.encode(&mut buffer)?; + Ok(Box::new(::nrpc::OnceStream::once(Ok(buffer.freeze())))) + } + }); } (true, true) => { // all streaming; many -> many - todo!("streaming not supported") + let stream_in_ty = stream_type(&input_ty); + let stream_out_ty = stream_type(&output_ty); + gen_methods.push( + quote! { + async fn #fn_name<'a>(&mut self, input: #stream_in_ty) -> Result<#stream_out_ty, Box>; + } + ); + + gen_method_match_arms.push(quote! { + #method_name => { + let item_stream = stream_in.map(|item_result| item_result.and_then(|item1_payload| { + #input_ty::decode(item1_payload) + .map_err(|e| ::nrpc::ServiceError::from(e)) + })); + let result = self.#fn_name(Box::new(item_stream)).await?; + Ok(Box::new( + result.map( + |item_result| item_result.and_then(|item| { + let mut buffer = ::nrpc::_helpers::bytes::BytesMut::new(); + item.encode(&mut buffer) + .map(|_| buffer.freeze()) + .map_err(|e| ::nrpc::ServiceError::from(e)) + }) + ) + )) + } + }); } } } @@ -88,7 +179,18 @@ fn trait_methods_server(descriptors: &Vec) -> proc_macro2:: quote! { #(#gen_methods)* - async fn call(&mut self, method: &str, payload: ::nrpc::_helpers::bytes::Bytes, buffer: &mut ::nrpc::_helpers::bytes::BytesMut) -> Result<(), ::nrpc::ServiceError> { + /*async fn call(&mut self, method: &str, payload: ::nrpc::_helpers::bytes::Bytes, buffer: &mut ::nrpc::_helpers::bytes::BytesMut) -> Result<(), ::nrpc::ServiceError> { + match method { + #(#gen_method_match_arms)* + _ => Err(::nrpc::ServiceError::MethodNotFound) + } + }*/ + + async fn call<'a>( + &mut self, + method: &str, + mut stream_in: ::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, + ) -> Result<::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, ::nrpc::ServiceError> { match method { #(#gen_method_match_arms)* _ => Err(::nrpc::ServiceError::MethodNotFound) @@ -104,36 +206,99 @@ fn struct_methods_client( ) -> proc_macro2::TokenStream { let mut gen_methods = Vec::with_capacity(descriptors.len()); for descriptor in descriptors { + let input_ty = quote::format_ident!("{}", descriptor.input_type); + let output_ty = quote::format_ident!("{}", descriptor.output_type); + let fn_name = quote::format_ident!("{}", descriptor.name); + let method_name = &descriptor.name; match (descriptor.client_streaming, descriptor.server_streaming) { (false, false) => { // no streaming; 1->1 - let input_ty = quote::format_ident!("{}", descriptor.input_type); - let output_ty = quote::format_ident!("{}", descriptor.output_type); - let fn_name = quote::format_ident!("{}", descriptor.name); - let method_name = &descriptor.name; gen_methods.push( quote! { pub async fn #fn_name(&self, input: #input_ty) -> Result<#output_ty, ::nrpc::ServiceError> { let mut in_buf = ::nrpc::_helpers::bytes::BytesMut::new(); input.encode(&mut in_buf)?; - let mut out_buf = ::nrpc::_helpers::bytes::BytesMut::new(); - self.inner.call(#package_name, #service_name, #method_name, in_buf.into(), &mut out_buf).await?; - Ok(#output_ty::decode(out_buf)?) + let in_stream = ::nrpc::OnceStream::once(Ok(in_buf.freeze())); + let mut result_stream = self.inner.call(#package_name, #service_name, #method_name, Box::new( in_stream)).await?; + if let Some(out_result) = result_stream.next().await { + Ok(#output_ty::decode(out_result?)?) + } else { + Err(::nrpc::ServiceError::StreamLength { want: 1, got: 0 }) + } + + } + } + ); + } + (false, true) => { + // client streaming; 1 -> many + let stream_out_ty = stream_type(&output_ty); + gen_methods.push( + quote! { + pub async fn #fn_name<'a>(&self, input: #input_ty) -> Result<#stream_out_ty, ::nrpc::ServiceError> { + let mut in_buf = ::nrpc::_helpers::bytes::BytesMut::new(); + input.encode(&mut in_buf)?; + let in_stream = ::nrpc::OnceStream::once(Ok(in_buf.freeze())); + let result_stream = self.inner.call(#package_name, #service_name, #method_name, Box::new(in_stream)).await?; + let item_stream = result_stream.map(|out_result| + out_result.and_then(|out_buf| #output_ty::decode(out_buf) + .map_err(|e| ::nrpc::ServiceError::from(e)) + ) + ); + Ok(Box::new(item_stream)) } } ); } (true, false) => { - // client streaming; 1 -> many - todo!("streaming not supported") - } - (false, true) => { // server streaming; many -> 1 - todo!("streaming not supported") + let stream_in_ty = stream_type(&input_ty); + gen_methods.push( + quote! { + pub async fn #fn_name<'a>(&self, input: #stream_in_ty) -> Result<#output_ty, ::nrpc::ServiceError> { + let in_stream = input.map(|item_result| { + let mut in_buf = ::nrpc::_helpers::bytes::BytesMut::new(); + item_result.and_then(|item| item.encode(&mut in_buf) + .map(|_| in_buf.freeze()) + .map_err(|e| ::nrpc::ServiceError::from(e)) + ) + }); + let mut result_stream = self.inner.call(#package_name, #service_name, #method_name, Box::new(in_stream)).await?; + if let Some(out_result) = result_stream.next().await { + Ok(#output_ty::decode(out_result?)?) + } else { + Err(::nrpc::ServiceError::StreamLength { want: 1, got: 0 }) + } + + } + } + ); } (true, true) => { // all streaming; many -> many - todo!("streaming not supported") + let stream_in_ty = stream_type(&input_ty); + let stream_out_ty = stream_type(&output_ty); + gen_methods.push( + quote! { + pub async fn #fn_name<'a>(&self, input: #stream_in_ty) -> Result<#stream_out_ty, ::nrpc::ServiceError> { + let in_stream = input.map(|item_result| { + let mut in_buf = ::nrpc::_helpers::bytes::BytesMut::new(); + item_result.and_then(|item| item.encode(&mut in_buf) + .map(|_| in_buf.freeze()) + .map_err(|e| ::nrpc::ServiceError::from(e)) + ) + }); + let result_stream = self.inner.call(#package_name, #service_name, #method_name, Box::new(in_stream)).await?; + let item_stream = result_stream.map(|out_result| + out_result.and_then(|out_buf| #output_ty::decode(out_buf) + .map_err(|e| ::nrpc::ServiceError::from(e)) + ) + ); + Ok(Box::new(item_stream)) + + } + } + ); } } } @@ -174,6 +339,7 @@ impl ServiceGenerator for ProtobufServiceGenerator { use super::*; use ::nrpc::_helpers::async_trait::async_trait; use ::nrpc::_helpers::prost::Message; + use ::nrpc::_helpers::futures::StreamExt; #[async_trait] pub trait #service_trait_name: Send { @@ -198,8 +364,12 @@ impl ServiceGenerator for ProtobufServiceGenerator { #descriptor_str } - async fn call(&mut self, method: &str, payload: ::nrpc::_helpers::bytes::Bytes, buffer: &mut ::nrpc::_helpers::bytes::BytesMut) -> Result<(), ::nrpc::ServiceError> { - self.inner.call(method, payload, buffer).await + async fn call<'a>( + &mut self, + method: &str, + input: ::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, + ) -> Result<::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, ::nrpc::ServiceError> { + self.inner.call(method, input).await } } } @@ -227,6 +397,7 @@ impl ServiceGenerator for ProtobufServiceGenerator { mod #service_mod_name { use super::*; use ::nrpc::_helpers::prost::Message; + use ::nrpc::_helpers::futures::StreamExt; //#[derive(core::any::Any)] pub struct #service_struct_name { diff --git a/nrpc-codegen-test/proto/helloworld.proto b/nrpc-codegen-test/proto/helloworld.proto index d79a6a0..45e8786 100644 --- a/nrpc-codegen-test/proto/helloworld.proto +++ b/nrpc-codegen-test/proto/helloworld.proto @@ -24,6 +24,15 @@ package helloworld; service Greeter { // Sends a greeting rpc SayHello (HelloRequest) returns (HelloReply) {} + + // Sends many -> 1 greeting + rpc SayHelloManyToOne(stream HelloRequest) returns (HelloReply) {} + + // Sends 1 -> many greeting + rpc SayHelloOneToMany(HelloRequest) returns (stream HelloReply) {} + + // // Sends many -> many greeting + rpc SayHelloManyToMany(stream HelloRequest) returns (stream HelloReply) {} } // The request message containing the user's name. diff --git a/nrpc-codegen-test/src/main.rs b/nrpc-codegen-test/src/main.rs index f02528e..3141a94 100644 --- a/nrpc-codegen-test/src/main.rs +++ b/nrpc-codegen-test/src/main.rs @@ -1,6 +1,8 @@ use std::error::Error; +use std::fmt::Write; -use nrpc::ServerService; +use nrpc::_helpers::futures::StreamExt; +use nrpc::{ServerService, ServiceError}; use prost::Message; pub mod generated { @@ -11,28 +13,119 @@ pub use generated::*; #[tokio::main] async fn main() { + // NOTE: This doesn't test network functionality + // it just checks generated code for correctness (compile-time) + // and tests mock client & server traits implementations let req = helloworld::HelloRequest { name: "World".into(), }; let resp = helloworld::HelloReply { message: "Hello World".into(), }; + let original_resp = resp.clone(); // server let mut service_impl = helloworld::GreeterServer::new(GreeterService); + + // server one to one let mut input_buf = bytes::BytesMut::new(); - let mut output_buf = bytes::BytesMut::new(); - req.encode(&mut input_buf).unwrap(); - service_impl - .call("say_hello", input_buf.into(), &mut output_buf) + //let mut output_buf = bytes::BytesMut::new(); + req.clone().encode(&mut input_buf).unwrap(); + let stream_in = nrpc::OnceStream::once(Ok(input_buf.into())); + let mut output_stream = service_impl + .call("say_hello", Box::new(stream_in)) .await .unwrap(); + let output_buf = output_stream.next().await.unwrap().unwrap(); let actual_resp = helloworld::HelloReply::decode(output_buf).unwrap(); assert_eq!(resp, actual_resp); - // client - let mut client_impl = helloworld::GreeterClient::new(ClientHandler); - let resp = client_impl.say_hello(req).await.unwrap(); + // client one to one + let client_impl = helloworld::GreeterClient::new(ClientHandler); + let resp = client_impl.say_hello(req.clone()).await.unwrap(); assert_eq!(resp, actual_resp); + + // server many to one + let resp = helloworld::HelloReply { + message: "Hello World0, World1, World2".into(), + }; + let stream_in = nrpc::VecStream::from_iter([(); 3].iter().enumerate().map(|(i, _)| { + let mut input_buf = bytes::BytesMut::new(); + helloworld::HelloRequest { name: format!("World{}", i) }.encode(&mut input_buf).expect("Protobuf encoding error"); + Ok(input_buf.freeze()) + })); + let mut output_stream = service_impl + .call("say_hello_many_to_one", Box::new(stream_in)) + .await + .unwrap(); + let output_buf = output_stream.next().await.unwrap().unwrap(); + let actual_resp = helloworld::HelloReply::decode(output_buf).unwrap(); + assert_eq!(resp, actual_resp); + + // client many to one + let client_impl = helloworld::GreeterClient::new(ClientHandler); + let stream_in = nrpc::VecStream::from_iter([(); 3].iter().enumerate().map(|(i, _)| + Ok(helloworld::HelloRequest { name: format!("World{}", i) }))); + let resp = client_impl.say_hello_many_to_one(Box::new(stream_in)).await.unwrap(); + assert_eq!(resp, original_resp); + + // server one to many + let resp = vec![ + helloworld::HelloReply { + message: "Hello World".into(), + }, + helloworld::HelloReply { + message: "Hello World".into(), + }, + helloworld::HelloReply { + message: "Hello World".into(), + }, + ]; + let mut input_buf = bytes::BytesMut::new(); + //let mut output_buf = bytes::BytesMut::new(); + req.clone().encode(&mut input_buf).unwrap(); + let stream_in = nrpc::OnceStream::once(Ok(input_buf.into())); + let output_stream = service_impl + .call("say_hello_one_to_many", Box::new(stream_in)) + .await + .unwrap(); + let actual_resp: Vec<_> = output_stream.map(|buf_result| helloworld::HelloReply::decode(buf_result.unwrap()).unwrap()).collect().await; + assert_eq!(resp, actual_resp); + + // client one to many + let client_impl = helloworld::GreeterClient::new(ClientHandler); + let resp: Vec<_> = client_impl.say_hello_one_to_many(req.clone()).await.unwrap().map(|item_result| item_result.unwrap()).collect().await; + assert_eq!(resp, vec![original_resp.clone()]); + + // server many to many + let resp = vec![ + helloworld::HelloReply { + message: "Hello World0".into(), + }, + helloworld::HelloReply { + message: "Hello World1".into(), + }, + helloworld::HelloReply { + message: "Hello World2".into(), + }, + ]; + let stream_in = nrpc::VecStream::from_iter([(); 3].iter().enumerate().map(|(i, _)| { + let mut input_buf = bytes::BytesMut::new(); + helloworld::HelloRequest { name: format!("World{}", i) }.encode(&mut input_buf).expect("Protobuf encoding error"); + Ok(input_buf.freeze()) + })); + let output_stream = service_impl + .call("say_hello_many_to_many", Box::new(stream_in)) + .await + .unwrap(); + let actual_resp: Vec<_> = output_stream.map(|buf_result| helloworld::HelloReply::decode(buf_result.unwrap()).unwrap()).collect().await; + assert_eq!(resp, actual_resp); + + // client many to many + let client_impl = helloworld::GreeterClient::new(ClientHandler); + let stream_in = nrpc::VecStream::from_iter([(); 3].iter().enumerate().map(|(i, _)| + Ok(helloworld::HelloRequest { name: format!("World{}", i) }))); + let resp: Vec<_> = client_impl.say_hello_many_to_many(Box::new(stream_in)).await.unwrap().map(|item_result| item_result.unwrap()).collect().await; + assert_eq!(resp, vec![original_resp.clone(); 3]); } struct GreeterService; @@ -42,20 +135,64 @@ impl helloworld::IGreeter for GreeterService { async fn say_hello( &mut self, input: helloworld::HelloRequest, - ) -> Result> { + ) -> Result> { let result = helloworld::HelloReply { message: format!("Hello {}", input.name), }; println!("{}", result.message); Ok(result) } + + async fn say_hello_one_to_many<'a>( + &mut self, + input: helloworld::HelloRequest, + ) -> Result< + ::nrpc::ServiceStream<'a, helloworld::HelloReply>, + Box, + > { + let result = helloworld::HelloReply { + message: format!("Hello {}", input.name), + }; + println!("{}", result.message); + Ok(Box::new(::nrpc::VecStream::from_iter([(); 3].iter().map(move |_| Ok(result.clone()))))) + } + + async fn say_hello_many_to_one<'a>( + &mut self, + mut input: ::nrpc::ServiceStream<'a, helloworld::HelloRequest>, + ) -> Result>{ + let mut message = "Hello ".to_string(); + while let Some(item_result) = input.next().await { + write!(message, "{}, ", item_result.map_err(|e| Box::new(e) as Box)?.name) + .map_err(|e| Box::new(e) as Box)?; + } + let result = helloworld::HelloReply { message: message.trim_end_matches(", ").to_string(), }; + println!("{}", result.message); + Ok(result) + } + + async fn say_hello_many_to_many<'a>( + &mut self, + input: ::nrpc::ServiceStream<'a, helloworld::HelloRequest>, + ) -> Result< + ::nrpc::ServiceStream<'a, helloworld::HelloReply>, + Box, + >{ + Ok(Box::new(input.map(|item_result| item_result.map(|input| { + let result = helloworld::HelloReply { + message: format!("Hello {}", input.name), + }; + println!("(many to many) {}", result.message); + result + })))) + } } struct ClientHandler; #[async_trait::async_trait] impl nrpc::ClientHandler for ClientHandler { - async fn call( + /*async fn call( &mut self, package: &str, service: &str, @@ -72,5 +209,28 @@ impl nrpc::ClientHandler for ClientHandler { message: "Hello World".into(), } .encode(output)?) + }*/ + + async fn call<'a>( + &self, + package: &str, + service: &str, + method: &str, + input: ::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, + ) -> Result<::nrpc::ServiceStream<'a, ::nrpc::_helpers::bytes::Bytes>, ServiceError> { + println!( + "call {}.{}/{} with data stream", + package, service, method + ); + // This is ok to hardcode ONLY because it's for testing + Ok( + Box::new(input.map(|item_result| { + let mut output = bytes::BytesMut::new(); + item_result.and_then(|_item| helloworld::HelloReply { + message: format!("Hello World"), + }.encode(&mut output).map(|_| output.freeze()).map_err(|e| ServiceError::from(e))) + } + )) + ) } } diff --git a/nrpc/Cargo.toml b/nrpc/Cargo.toml index 7f3c85a..235fe72 100644 --- a/nrpc/Cargo.toml +++ b/nrpc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nrpc" -version = "0.7.0" +version = "0.8.0" edition = "2021" license = "Apache-2.0" repository = "https://github.com/NGnius/nRPC" @@ -11,3 +11,4 @@ description = "Yet another remote procedure call library" prost = "0.11" bytes = "1" async-trait = "0.1" +futures = "0.3" diff --git a/nrpc/src/lib.rs b/nrpc/src/lib.rs index df7a826..746eda1 100644 --- a/nrpc/src/lib.rs +++ b/nrpc/src/lib.rs @@ -1,9 +1,13 @@ mod service; +mod stream_utils; -pub use service::{ClientHandler, ClientService, ServerService, ServiceError}; +pub use service::{ClientHandler, ClientService, ServerService, ServiceError, ServiceStream}; + +pub use stream_utils::{EmptyStream, OnceStream, VecStream}; pub mod _helpers { pub use async_trait; pub use bytes; pub use prost; + pub use futures; } diff --git a/nrpc/src/service.rs b/nrpc/src/service.rs index 624997d..59fbc48 100644 --- a/nrpc/src/service.rs +++ b/nrpc/src/service.rs @@ -1,25 +1,28 @@ +use futures::Stream; +use core::marker::Unpin; + +pub type ServiceStream<'a, T> = Box> + Unpin + Send + 'a>; + #[async_trait::async_trait] pub trait ServerService { fn descriptor(&self) -> &'static str; - async fn call( + async fn call<'a>( &mut self, method: &str, - input: bytes::Bytes, - output: &mut bytes::BytesMut, - ) -> Result<(), ServiceError>; + input: ServiceStream<'a, bytes::Bytes>, + ) -> Result, ServiceError>; } #[async_trait::async_trait] pub trait ClientHandler { - async fn call( + async fn call<'a>( &self, package: &str, service: &str, method: &str, - input: bytes::Bytes, - output: &mut bytes::BytesMut, - ) -> Result<(), ServiceError>; + input: ServiceStream<'a, bytes::Bytes>, + ) -> Result, ServiceError>; } pub trait ClientService { @@ -32,7 +35,11 @@ pub enum ServiceError { Decode(prost::DecodeError), MethodNotFound, ServiceNotFound, - Method(Box), + Method(Box), + StreamLength { + want: u64, + got: u64, + } } impl std::fmt::Display for ServiceError { @@ -43,6 +50,7 @@ impl std::fmt::Display for ServiceError { Self::MethodNotFound => write!(f, "Method not found error"), Self::ServiceNotFound => write!(f, "Service not found error"), Self::Method(e) => write!(f, "Method error: {}", e), + Self::StreamLength{ want, got } => write!(f, "Stream length error: wanted {}, got {}", want, got), } } } @@ -59,8 +67,8 @@ impl std::convert::From for ServiceError { } } -impl std::convert::From> for ServiceError { - fn from(value: Box) -> Self { +impl std::convert::From> for ServiceError { + fn from(value: Box) -> Self { Self::Method(value) } } diff --git a/nrpc/src/stream_utils.rs b/nrpc/src/stream_utils.rs new file mode 100644 index 0000000..7e05fe5 --- /dev/null +++ b/nrpc/src/stream_utils.rs @@ -0,0 +1,80 @@ +use futures::Stream; + +use core::{pin::Pin, task::{Context, Poll}}; +use core::marker::{PhantomData, Unpin}; + +#[derive(Default, Clone, Copy)] +pub struct EmptyStream { + _idc: PhantomData, +} + +impl Stream for EmptyStream { + type Item = T; + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(None) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(0)) + } +} + +#[derive(Clone)] +pub struct OnceStream { + item: Option, +} + +impl OnceStream { + pub fn once(item: T) -> Self { + Self { item: Some(item) } + } +} + +impl Stream for OnceStream { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(self.item.take()) + } + + fn size_hint(&self) -> (usize, Option) { + if self.item.is_some() { + (1, Some(1)) + } else { + (0, Some(0)) + } + } +} + +#[derive(Clone)] +pub struct VecStream { + items: std::collections::VecDeque, +} + +impl VecStream { + pub fn from_iter(iter: impl Iterator) -> Self { + Self { items: iter.collect() } + } +} + +impl Stream for VecStream { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_> + ) -> Poll> { + Poll::Ready(self.items.pop_front()) + } + + fn size_hint(&self) -> (usize, Option) { + (self.items.len(), Some(self.items.len())) + } +}