Compare commits

..

No commits in common. "f78628a3fd7eb8c37c82b227fa162c948297faee" and "65c5567c8875dd0ce0529ded1a7e302211460d98" have entirely different histories.

3 changed files with 26 additions and 123 deletions

View file

@ -1,7 +1,7 @@
[socket] [socket]
# Which socket to listen on? If relative, # Which socket to listen on? If relative,
# defaults to being relative to $XDG_RUNTIME_DIR # defaults to being relative to $XDG_RUNTIME_DIR
listen = "wayland-10" listen = "wayland-2"
# Which Wayland socket to use as upstream? # Which Wayland socket to use as upstream?
# If missing, defaults to $WAYLAND_DISPLAY # If missing, defaults to $WAYLAND_DISPLAY
# upstream = "wayland-1" # upstream = "wayland-1"
@ -16,7 +16,6 @@ allowed_globals = [
"wl_compositor", "wl_compositor",
"wl_shm", "wl_shm",
"wl_data_device_manager", "wl_data_device_manager",
"wl_output", # each output is also a global
"wl_seat", "wl_seat",
# Window management # Window management
"xdg_wm_base", "xdg_wm_base",

View file

@ -10,7 +10,7 @@ use std::{io, path::Path, sync::Arc};
use config::Config; use config::Config;
use io_util::{WlMsgReader, WlMsgWriter}; use io_util::{WlMsgReader, WlMsgWriter};
use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict}; use state::WlMitmState;
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tracing::{Instrument, Level, debug, error, info, span}; use tracing::{Instrument, Level, debug, error, info, span};
@ -93,15 +93,11 @@ pub async fn handle_conn(
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => { codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await; let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds); upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
match verdict { if verdict {
WlMitmVerdict::Allowed => {
downstream_write.queue_write(wl_raw_msg); downstream_write.queue_write(wl_raw_msg);
},
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
_ => {}
} }
}, },
codec::DecoderOutcome::Incomplete => continue, codec::DecoderOutcome::Incomplete => continue,
@ -113,15 +109,11 @@ pub async fn handle_conn(
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => { codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await; let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds); downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
match verdict { if verdict {
WlMitmVerdict::Allowed => {
upstream_write.queue_write(wl_raw_msg); upstream_write.queue_write(wl_raw_msg);
},
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
_ => {}
} }
}, },
codec::DecoderOutcome::Incomplete => continue, codec::DecoderOutcome::Incomplete => continue,

View file

@ -12,52 +12,6 @@ use crate::{
}, },
}; };
/// What to do for a message?
pub enum WlMitmVerdict {
/// This message is allowed. Pass it through to the opposite end.
Allowed,
/// This message is filtered.
/// TODO: We should probably construct a proper error response
Filtered,
/// Terminate this entire session. Something is off.
Terminate,
}
impl Default for WlMitmVerdict {
fn default() -> Self {
WlMitmVerdict::Terminate
}
}
/// Result returned by [WlMitmState] when handling messages.
/// It's a pair of (num_consumed_fds, verdict).
///
/// We need to return back unused fds from the [WlRawMsg], which
/// is why this has to be returned from here.
#[derive(Default)]
pub struct WlMitmOutcome(pub usize, pub WlMitmVerdict);
impl WlMitmOutcome {
fn set_consumed_fds(&mut self, consumed_fds: usize) {
self.0 = consumed_fds;
}
fn allowed(mut self) -> Self {
self.1 = WlMitmVerdict::Allowed;
self
}
fn filtered(mut self) -> Self {
self.1 = WlMitmVerdict::Filtered;
self
}
fn terminate(mut self) -> Self {
self.1 = WlMitmVerdict::Terminate;
self
}
}
pub struct WlMitmState { pub struct WlMitmState {
config: Arc<Config>, config: Arc<Config>,
objects: WlObjects, objects: WlObjects,
@ -73,28 +27,12 @@ impl WlMitmState {
/// Handle messages which register new objects with known interfaces or deletes them. /// Handle messages which register new objects with known interfaces or deletes them.
/// ///
/// If there is an error, this function will return false and the connection shall be terminated.
///
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request /// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
/// is not handled here. /// is not handled here.
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) -> bool { fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) {
if let Some(created_objects) = msg.known_objects_created() { if let Some(created_objects) = msg.known_objects_created() {
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) { if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
for (id, tt) in created_objects.into_iter() { for (id, tt) in created_objects.into_iter() {
if let Some(existing_obj_type) = self.objects.lookup_object(id) {
debug!(
parent_obj_id = msg.obj_id(),
obj_type = tt.interface(),
obj_id = id,
existing_obj_type = existing_obj_type.interface(),
"Trying to create object via message {}::{} but the object ID is already used!",
parent_obj.interface(),
msg.self_msg_name()
);
return false;
}
debug!( debug!(
parent_obj_id = msg.obj_id(), parent_obj_id = msg.obj_id(),
obj_type = tt.interface(), obj_type = tt.interface(),
@ -106,8 +44,7 @@ impl WlMitmState {
self.objects.record_object(tt, id); self.objects.record_object(tt, id);
} }
} else { } else {
error!("Parent object ID {} not found!", msg.obj_id()); error!("Parent object ID {} not found, ignoring", msg.obj_id());
return false;
} }
} else if msg.is_destructor() { } else if msg.is_destructor() {
if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) { if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) {
@ -121,38 +58,25 @@ impl WlMitmState {
self.objects.remove_object(msg.obj_id()); self.objects.remove_object(msg.obj_id());
} }
true
} }
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict /// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome { pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_request(&self.objects, raw_msg) { let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => { _ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!( error!(
obj_id = raw_msg.obj_id, obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode, opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed or unknown request" "Malformed or unknown request"
); );
return outcome.terminate(); return (0, false);
} }
}; };
outcome.set_consumed_fds(msg.num_consumed_fds()); self.handle_created_or_destroyed_objects(&*msg);
if !self.handle_created_or_destroyed_objects(&*msg) {
return outcome.terminate();
}
// The bind request doesn't create interface with a fixed type; handle it separately. // The bind request doesn't create interface with a fixed type; handle it separately.
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() { if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
@ -162,7 +86,7 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client! // to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out. // So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(obj_type) = self.objects.lookup_global(msg.name) else { let Some(obj_type) = self.objects.lookup_global(msg.name) else {
return outcome.terminate(); return (0, false);
}; };
if obj_type.interface() != msg.id_interface_name { if obj_type.interface() != msg.id_interface_name {
@ -172,7 +96,7 @@ impl WlMitmState {
msg.name, msg.name,
obj_type.interface() obj_type.interface()
); );
return outcome.terminate(); return (0, false);
} }
info!( info!(
@ -220,10 +144,9 @@ impl WlMitmState {
msg.self_msg_name(), msg.self_msg_name(),
status status
); );
return outcome.filtered();
} else {
return outcome.allowed();
} }
return (msg.num_consumed_fds(), status.success());
} }
} }
@ -232,7 +155,7 @@ impl WlMitmState {
msg.self_object_type().interface(), msg.self_object_type().interface(),
msg.self_msg_name() msg.self_msg_name()
); );
return outcome.filtered(); return (msg.num_consumed_fds(), false);
} }
WlFilterRequestAction::Block => { WlFilterRequestAction::Block => {
warn!( warn!(
@ -241,42 +164,31 @@ impl WlMitmState {
msg.self_msg_name() msg.self_msg_name()
); );
// TODO: don't just return false, build an error event // TODO: don't just return false, build an error event
return outcome.filtered(); return (msg.num_consumed_fds(), false);
} }
} }
} }
} }
outcome.allowed() (msg.num_consumed_fds(), true)
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome { pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_event(&self.objects, raw_msg) { let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => { _ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!( error!(
obj_id = raw_msg.obj_id, obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode, opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed or unknown event" "Malformed or unknown event"
); );
return outcome.terminate(); return (0, false);
} }
}; };
outcome.set_consumed_fds(msg.num_consumed_fds()); self.handle_created_or_destroyed_objects(&*msg);
if !self.handle_created_or_destroyed_objects(&*msg) {
return outcome.terminate();
}
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() { if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
// This event is how Wayland servers announce globals -- and they are the entrypoint to // This event is how Wayland servers announce globals -- and they are the entrypoint to
@ -295,7 +207,7 @@ impl WlMitmState {
"Unknown interface removed! If required, please include its XML when building wl-mitm!" "Unknown interface removed! If required, please include its XML when building wl-mitm!"
); );
return outcome.filtered(); return (0, false);
}; };
// To block entire extensions, we just need to filter out their announced global objects. // To block entire extensions, we just need to filter out their announced global objects.
@ -304,7 +216,7 @@ impl WlMitmState {
interface = msg.interface, interface = msg.interface,
"Removing interface from published globals" "Removing interface from published globals"
); );
return outcome.filtered(); return (0, false);
} }
// Else, record the global object. These are the only ones we're ever going to allow through. // Else, record the global object. These are the only ones we're ever going to allow through.
@ -318,6 +230,6 @@ impl WlMitmState {
self.objects.remove_object(msg.id); self.objects.remove_object(msg.id);
} }
outcome.allowed() (msg.num_consumed_fds(), true)
} }
} }