Compare commits

...

4 commits

Author SHA1 Message Date
f78628a3fd Filter out unexpected object creation / destruction 2025-03-03 21:15:42 -05:00
aea28e4457 Allow terminating connections when we receive unexpected messages 2025-03-03 21:01:35 -05:00
49947ba052 wl_output is required by xwayland-satellite
well, actually most programs... so let's add it to the default / example
config.

Also change the Wayland socket name to wayland-10 to avoid conflicts.
2025-03-03 20:18:55 -05:00
6d772971ff Log object type when possible for errors 2025-03-03 19:59:58 -05:00
3 changed files with 123 additions and 26 deletions

View file

@ -1,7 +1,7 @@
[socket]
# Which socket to listen on? If relative,
# defaults to being relative to $XDG_RUNTIME_DIR
listen = "wayland-2"
listen = "wayland-10"
# Which Wayland socket to use as upstream?
# If missing, defaults to $WAYLAND_DISPLAY
# upstream = "wayland-1"
@ -16,6 +16,7 @@ allowed_globals = [
"wl_compositor",
"wl_shm",
"wl_data_device_manager",
"wl_output", # each output is also a global
"wl_seat",
# Window management
"xdg_wm_base",

View file

@ -10,7 +10,7 @@ use std::{io, path::Path, sync::Arc};
use config::Config;
use io_util::{WlMsgReader, WlMsgWriter};
use state::WlMitmState;
use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict};
use tokio::net::{UnixListener, UnixStream};
use tracing::{Instrument, Level, debug, error, info, span};
@ -93,11 +93,15 @@ pub async fn handle_conn(
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
downstream_write.queue_write(wl_raw_msg);
match verdict {
WlMitmVerdict::Allowed => {
downstream_write.queue_write(wl_raw_msg);
},
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
_ => {}
}
},
codec::DecoderOutcome::Incomplete => continue,
@ -109,11 +113,15 @@ pub async fn handle_conn(
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
upstream_write.queue_write(wl_raw_msg);
match verdict {
WlMitmVerdict::Allowed => {
upstream_write.queue_write(wl_raw_msg);
},
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
_ => {}
}
},
codec::DecoderOutcome::Incomplete => continue,

View file

@ -12,6 +12,52 @@ use crate::{
},
};
/// What to do for a message?
pub enum WlMitmVerdict {
/// This message is allowed. Pass it through to the opposite end.
Allowed,
/// This message is filtered.
/// TODO: We should probably construct a proper error response
Filtered,
/// Terminate this entire session. Something is off.
Terminate,
}
impl Default for WlMitmVerdict {
fn default() -> Self {
WlMitmVerdict::Terminate
}
}
/// Result returned by [WlMitmState] when handling messages.
/// It's a pair of (num_consumed_fds, verdict).
///
/// We need to return back unused fds from the [WlRawMsg], which
/// is why this has to be returned from here.
#[derive(Default)]
pub struct WlMitmOutcome(pub usize, pub WlMitmVerdict);
impl WlMitmOutcome {
fn set_consumed_fds(&mut self, consumed_fds: usize) {
self.0 = consumed_fds;
}
fn allowed(mut self) -> Self {
self.1 = WlMitmVerdict::Allowed;
self
}
fn filtered(mut self) -> Self {
self.1 = WlMitmVerdict::Filtered;
self
}
fn terminate(mut self) -> Self {
self.1 = WlMitmVerdict::Terminate;
self
}
}
pub struct WlMitmState {
config: Arc<Config>,
objects: WlObjects,
@ -27,12 +73,28 @@ impl WlMitmState {
/// Handle messages which register new objects with known interfaces or deletes them.
///
/// If there is an error, this function will return false and the connection shall be terminated.
///
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
/// is not handled here.
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) {
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) -> bool {
if let Some(created_objects) = msg.known_objects_created() {
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
for (id, tt) in created_objects.into_iter() {
if let Some(existing_obj_type) = self.objects.lookup_object(id) {
debug!(
parent_obj_id = msg.obj_id(),
obj_type = tt.interface(),
obj_id = id,
existing_obj_type = existing_obj_type.interface(),
"Trying to create object via message {}::{} but the object ID is already used!",
parent_obj.interface(),
msg.self_msg_name()
);
return false;
}
debug!(
parent_obj_id = msg.obj_id(),
obj_type = tt.interface(),
@ -44,7 +106,8 @@ impl WlMitmState {
self.objects.record_object(tt, id);
}
} else {
error!("Parent object ID {} not found, ignoring", msg.obj_id());
error!("Parent object ID {} not found!", msg.obj_id());
return false;
}
} else if msg.is_destructor() {
if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) {
@ -58,25 +121,38 @@ impl WlMitmState {
self.objects.remove_object(msg.obj_id());
}
true
}
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
#[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!(
obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed or unknown request"
);
return (0, false);
return outcome.terminate();
}
};
self.handle_created_or_destroyed_objects(&*msg);
outcome.set_consumed_fds(msg.num_consumed_fds());
if !self.handle_created_or_destroyed_objects(&*msg) {
return outcome.terminate();
}
// The bind request doesn't create interface with a fixed type; handle it separately.
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
@ -86,7 +162,7 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
return (0, false);
return outcome.terminate();
};
if obj_type.interface() != msg.id_interface_name {
@ -96,7 +172,7 @@ impl WlMitmState {
msg.name,
obj_type.interface()
);
return (0, false);
return outcome.terminate();
}
info!(
@ -144,9 +220,10 @@ impl WlMitmState {
msg.self_msg_name(),
status
);
return outcome.filtered();
} else {
return outcome.allowed();
}
return (msg.num_consumed_fds(), status.success());
}
}
@ -155,7 +232,7 @@ impl WlMitmState {
msg.self_object_type().interface(),
msg.self_msg_name()
);
return (msg.num_consumed_fds(), false);
return outcome.filtered();
}
WlFilterRequestAction::Block => {
warn!(
@ -164,31 +241,42 @@ impl WlMitmState {
msg.self_msg_name()
);
// TODO: don't just return false, build an error event
return (msg.num_consumed_fds(), false);
return outcome.filtered();
}
}
}
}
(msg.num_consumed_fds(), true)
outcome.allowed()
}
#[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!(
obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed or unknown event"
);
return (0, false);
return outcome.terminate();
}
};
self.handle_created_or_destroyed_objects(&*msg);
outcome.set_consumed_fds(msg.num_consumed_fds());
if !self.handle_created_or_destroyed_objects(&*msg) {
return outcome.terminate();
}
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
// This event is how Wayland servers announce globals -- and they are the entrypoint to
@ -207,7 +295,7 @@ impl WlMitmState {
"Unknown interface removed! If required, please include its XML when building wl-mitm!"
);
return (0, false);
return outcome.filtered();
};
// To block entire extensions, we just need to filter out their announced global objects.
@ -216,7 +304,7 @@ impl WlMitmState {
interface = msg.interface,
"Removing interface from published globals"
);
return (0, false);
return outcome.filtered();
}
// Else, record the global object. These are the only ones we're ever going to allow through.
@ -230,6 +318,6 @@ impl WlMitmState {
self.objects.remove_object(msg.id);
}
(msg.num_consumed_fds(), true)
outcome.allowed()
}
}