Add dry_run mode for debugging purposes

This commit is contained in:
Peter Cai 2025-03-09 12:48:56 -04:00
parent 49ed447639
commit 8ea8261f38
4 changed files with 44 additions and 5 deletions

View file

@ -59,6 +59,11 @@ allowed_globals = [
"zwlr_data_control_manager_v1"
]
# When set to true, do not actually filter anything -- only emit a
# warning when a filter would have been triggered.
# Defaults to false
# dry_run = false
# A list of requests we'd like to filter
[[filter.requests]]
# The interface name in question

View file

@ -63,6 +63,8 @@ pub struct WlFilter {
pub allowed_globals: HashSet<String>,
#[serde(deserialize_with = "deserialize_filter_requests")]
pub requests: HashMap<String, Vec<WlFilterRequest>>,
#[serde(default)]
pub dry_run: bool,
}
#[derive(Deserialize)]

View file

@ -14,7 +14,7 @@ use io_util::{WlMsgReader, WlMsgWriter};
use proto::{WL_DISPLAY_OBJECT_ID, WlConstructableMessage, WlDisplayErrorEvent};
use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict};
use tokio::net::{UnixListener, UnixStream};
use tracing::{Instrument, Level, debug, error, info, span};
use tracing::{Instrument, Level, debug, error, info, span, warn};
#[tokio::main]
async fn main() {
@ -79,6 +79,7 @@ macro_rules! control_flow {
}
struct ConnDuplex<'a> {
config: Arc<Config>,
upstream_read: WlMsgReader<'a>,
upstream_write: WlMsgWriter<'a>,
downstream_read: WlMsgReader<'a>,
@ -88,6 +89,7 @@ struct ConnDuplex<'a> {
impl<'a> ConnDuplex<'a> {
pub fn new(
config: Arc<Config>,
state: WlMitmState,
upstream_conn: &'a mut UnixStream,
downstream_conn: &'a mut UnixStream,
@ -102,6 +104,7 @@ impl<'a> ConnDuplex<'a> {
let downstream_write = WlMsgWriter::new(downstream_write);
Self {
config,
upstream_read,
upstream_write,
downstream_read,
@ -123,11 +126,19 @@ impl<'a> ConnDuplex<'a> {
"s2c event"
);
let WlMitmOutcome(num_consumed_fds, verdict) =
let WlMitmOutcome(num_consumed_fds, mut verdict) =
self.state.on_s2c_event(&wl_raw_msg).await;
self.upstream_read
.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if !verdict.is_allowed() && self.config.filter.dry_run {
warn!(
verdict = ?verdict,
"Last event would have been filtered! (see prior logs for reason)"
);
verdict = WlMitmVerdict::Allowed;
}
match verdict {
WlMitmVerdict::Allowed => {
self.downstream_write.queue_write(wl_raw_msg);
@ -161,11 +172,19 @@ impl<'a> ConnDuplex<'a> {
"c2s request"
);
let WlMitmOutcome(num_consumed_fds, verdict) =
let WlMitmOutcome(num_consumed_fds, mut verdict) =
self.state.on_c2s_request(&wl_raw_msg).await;
self.downstream_read
.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if !verdict.is_allowed() && self.config.filter.dry_run {
warn!(
verdict = ?verdict,
"Last request would have been filtered! (see prior logs for reason)"
);
verdict = WlMitmVerdict::Allowed;
}
match verdict {
WlMitmVerdict::Allowed => {
self.upstream_write.queue_write(wl_raw_msg);
@ -222,9 +241,9 @@ pub async fn handle_conn(
mut downstream_conn: UnixStream,
) -> io::Result<()> {
let mut upstream_conn = UnixStream::connect(src_path).await?;
let state = WlMitmState::new(config);
let state = WlMitmState::new(config.clone());
let duplex = ConnDuplex::new(state, &mut upstream_conn, &mut downstream_conn);
let duplex = ConnDuplex::new(config, state, &mut upstream_conn, &mut downstream_conn);
duplex.run_to_completion().await
}

View file

@ -16,6 +16,7 @@ use crate::{
};
/// What to do for a message?
#[derive(Debug)]
pub enum WlMitmVerdict {
/// This message is allowed. Pass it through to the opposite end.
Allowed,
@ -27,6 +28,12 @@ pub enum WlMitmVerdict {
Terminate,
}
impl WlMitmVerdict {
pub fn is_allowed(&self) -> bool {
matches!(self, WlMitmVerdict::Allowed)
}
}
impl Default for WlMitmVerdict {
fn default() -> Self {
WlMitmVerdict::Terminate
@ -259,6 +266,12 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
warn!(
interface = msg.name,
version = msg.id_interface_version,
obj_id = msg.id,
"Client binding non-existent or filtered interface"
);
return outcome.terminate();
};