From 8ea8261f382deacddc9c322509e7f778a698db2d Mon Sep 17 00:00:00 2001 From: Peter Cai Date: Sun, 9 Mar 2025 12:48:56 -0400 Subject: [PATCH] Add dry_run mode for debugging purposes --- config.toml | 5 +++++ src/config.rs | 2 ++ src/main.rs | 29 ++++++++++++++++++++++++----- src/state.rs | 13 +++++++++++++ 4 files changed, 44 insertions(+), 5 deletions(-) diff --git a/config.toml b/config.toml index 2a80560..c0ba75a 100644 --- a/config.toml +++ b/config.toml @@ -59,6 +59,11 @@ allowed_globals = [ "zwlr_data_control_manager_v1" ] +# When set to true, do not actually filter anything -- only emit a +# warning when a filter would have been triggered. +# Defaults to false +# dry_run = false + # A list of requests we'd like to filter [[filter.requests]] # The interface name in question diff --git a/src/config.rs b/src/config.rs index 22b6bdc..eb684c9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -63,6 +63,8 @@ pub struct WlFilter { pub allowed_globals: HashSet, #[serde(deserialize_with = "deserialize_filter_requests")] pub requests: HashMap>, + #[serde(default)] + pub dry_run: bool, } #[derive(Deserialize)] diff --git a/src/main.rs b/src/main.rs index 9b1942f..300e6fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ use io_util::{WlMsgReader, WlMsgWriter}; use proto::{WL_DISPLAY_OBJECT_ID, WlConstructableMessage, WlDisplayErrorEvent}; use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict}; use tokio::net::{UnixListener, UnixStream}; -use tracing::{Instrument, Level, debug, error, info, span}; +use tracing::{Instrument, Level, debug, error, info, span, warn}; #[tokio::main] async fn main() { @@ -79,6 +79,7 @@ macro_rules! control_flow { } struct ConnDuplex<'a> { + config: Arc, upstream_read: WlMsgReader<'a>, upstream_write: WlMsgWriter<'a>, downstream_read: WlMsgReader<'a>, @@ -88,6 +89,7 @@ struct ConnDuplex<'a> { impl<'a> ConnDuplex<'a> { pub fn new( + config: Arc, state: WlMitmState, upstream_conn: &'a mut UnixStream, downstream_conn: &'a mut UnixStream, @@ -102,6 +104,7 @@ impl<'a> ConnDuplex<'a> { let downstream_write = WlMsgWriter::new(downstream_write); Self { + config, upstream_read, upstream_write, downstream_read, @@ -123,11 +126,19 @@ impl<'a> ConnDuplex<'a> { "s2c event" ); - let WlMitmOutcome(num_consumed_fds, verdict) = + let WlMitmOutcome(num_consumed_fds, mut verdict) = self.state.on_s2c_event(&wl_raw_msg).await; self.upstream_read .return_unused_fds(&mut wl_raw_msg, num_consumed_fds); + if !verdict.is_allowed() && self.config.filter.dry_run { + warn!( + verdict = ?verdict, + "Last event would have been filtered! (see prior logs for reason)" + ); + verdict = WlMitmVerdict::Allowed; + } + match verdict { WlMitmVerdict::Allowed => { self.downstream_write.queue_write(wl_raw_msg); @@ -161,11 +172,19 @@ impl<'a> ConnDuplex<'a> { "c2s request" ); - let WlMitmOutcome(num_consumed_fds, verdict) = + let WlMitmOutcome(num_consumed_fds, mut verdict) = self.state.on_c2s_request(&wl_raw_msg).await; self.downstream_read .return_unused_fds(&mut wl_raw_msg, num_consumed_fds); + if !verdict.is_allowed() && self.config.filter.dry_run { + warn!( + verdict = ?verdict, + "Last request would have been filtered! (see prior logs for reason)" + ); + verdict = WlMitmVerdict::Allowed; + } + match verdict { WlMitmVerdict::Allowed => { self.upstream_write.queue_write(wl_raw_msg); @@ -222,9 +241,9 @@ pub async fn handle_conn( mut downstream_conn: UnixStream, ) -> io::Result<()> { let mut upstream_conn = UnixStream::connect(src_path).await?; - let state = WlMitmState::new(config); + let state = WlMitmState::new(config.clone()); - let duplex = ConnDuplex::new(state, &mut upstream_conn, &mut downstream_conn); + let duplex = ConnDuplex::new(config, state, &mut upstream_conn, &mut downstream_conn); duplex.run_to_completion().await } diff --git a/src/state.rs b/src/state.rs index 394f1c3..f00a7ce 100644 --- a/src/state.rs +++ b/src/state.rs @@ -16,6 +16,7 @@ use crate::{ }; /// What to do for a message? +#[derive(Debug)] pub enum WlMitmVerdict { /// This message is allowed. Pass it through to the opposite end. Allowed, @@ -27,6 +28,12 @@ pub enum WlMitmVerdict { Terminate, } +impl WlMitmVerdict { + pub fn is_allowed(&self) -> bool { + matches!(self, WlMitmVerdict::Allowed) + } +} + impl Default for WlMitmVerdict { fn default() -> Self { WlMitmVerdict::Terminate @@ -259,6 +266,12 @@ impl WlMitmState { // to bind to it; if it does, it's likely a malicious client! // So, we simply remove these messages from the stream, which will cause the Wayland server to error out. let Some(obj_type) = self.objects.lookup_global(msg.name) else { + warn!( + interface = msg.name, + version = msg.id_interface_version, + obj_id = msg.id, + "Client binding non-existent or filtered interface" + ); return outcome.terminate(); };