Add ID checks to packet handling to reduce replay attack surface

This commit is contained in:
NGnius (Graham) 2022-12-05 17:46:12 -05:00
parent fbaef000b5
commit 8387c8024e

View file

@ -1,4 +1,5 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use warp::Filter; use warp::Filter;
@ -7,6 +8,9 @@ use usdpl_core::{socket, RemoteCallResponse};
use super::{Callable, MutCallable, AsyncCallable, WrappedCallable}; use super::{Callable, MutCallable, AsyncCallable, WrappedCallable};
static LAST_ID: AtomicU64 = AtomicU64::new(0);
const MAX_ID_DIFFERENCE: u64 = 5;
//type WrappedCallable = Arc<Mutex<Box<dyn Callable>>>; // thread-safe, cloneable Callable //type WrappedCallable = Arc<Mutex<Box<dyn Callable>>>; // thread-safe, cloneable Callable
#[cfg(feature = "encrypt")] #[cfg(feature = "encrypt")]
@ -88,7 +92,22 @@ impl Instance {
) -> socket::Packet { ) -> socket::Packet {
match packet { match packet {
socket::Packet::Call(call) => { socket::Packet::Call(call) => {
log::info!("Got USDPL call {} (`{}`, params: {})", call.id, call.function, call.parameters.len()); log::debug!("Got USDPL call {} (`{}`, params: {})", call.id, call.function, call.parameters.len());
let last_id = LAST_ID.load(Ordering::SeqCst);
if call.id == 0 {
log::info!("Call ID is 0, assuming new connection (resetting last id)");
LAST_ID.store(0, Ordering::SeqCst);
} else if call.id > last_id && call.id - last_id < MAX_ID_DIFFERENCE {
LAST_ID.store(call.id, Ordering::SeqCst);
} else {
#[cfg(not(debug_assertions))]
{
log::error!("Got USDPL call with strange ID! got:{} last id:{} (rejecting packet)", call.id, last_id);
return socket::Packet::Invalid
}
#[cfg(debug_assertions)]
log::warn!("Got USDPL call with strange ID! got:{} last id:{} (in release mode this packet will be rejected)", call.id, last_id);
}
//let handlers = CALLS.lock().expect("Failed to acquire CALLS lock"); //let handlers = CALLS.lock().expect("Failed to acquire CALLS lock");
if let Some(target) = handlers.get(&call.function) { if let Some(target) = handlers.get(&call.function) {
let result = target.call(call.parameters).await; let result = target.call(call.parameters).await;