Compare commits
4 commits
65c5567c88
...
f78628a3fd
Author | SHA1 | Date | |
---|---|---|---|
f78628a3fd | |||
aea28e4457 | |||
49947ba052 | |||
6d772971ff |
3 changed files with 123 additions and 26 deletions
|
@ -1,7 +1,7 @@
|
|||
[socket]
|
||||
# Which socket to listen on? If relative,
|
||||
# defaults to being relative to $XDG_RUNTIME_DIR
|
||||
listen = "wayland-2"
|
||||
listen = "wayland-10"
|
||||
# Which Wayland socket to use as upstream?
|
||||
# If missing, defaults to $WAYLAND_DISPLAY
|
||||
# upstream = "wayland-1"
|
||||
|
@ -16,6 +16,7 @@ allowed_globals = [
|
|||
"wl_compositor",
|
||||
"wl_shm",
|
||||
"wl_data_device_manager",
|
||||
"wl_output", # each output is also a global
|
||||
"wl_seat",
|
||||
# Window management
|
||||
"xdg_wm_base",
|
||||
|
|
22
src/main.rs
22
src/main.rs
|
@ -10,7 +10,7 @@ use std::{io, path::Path, sync::Arc};
|
|||
|
||||
use config::Config;
|
||||
use io_util::{WlMsgReader, WlMsgWriter};
|
||||
use state::WlMitmState;
|
||||
use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict};
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
use tracing::{Instrument, Level, debug, error, info, span};
|
||||
|
||||
|
@ -93,11 +93,15 @@ pub async fn handle_conn(
|
|||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
||||
|
||||
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
||||
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
||||
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||
|
||||
if verdict {
|
||||
downstream_write.queue_write(wl_raw_msg);
|
||||
match verdict {
|
||||
WlMitmVerdict::Allowed => {
|
||||
downstream_write.queue_write(wl_raw_msg);
|
||||
},
|
||||
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
|
||||
_ => {}
|
||||
}
|
||||
},
|
||||
codec::DecoderOutcome::Incomplete => continue,
|
||||
|
@ -109,11 +113,15 @@ pub async fn handle_conn(
|
|||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
||||
|
||||
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
||||
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
||||
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||
|
||||
if verdict {
|
||||
upstream_write.queue_write(wl_raw_msg);
|
||||
match verdict {
|
||||
WlMitmVerdict::Allowed => {
|
||||
upstream_write.queue_write(wl_raw_msg);
|
||||
},
|
||||
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
|
||||
_ => {}
|
||||
}
|
||||
},
|
||||
codec::DecoderOutcome::Incomplete => continue,
|
||||
|
|
124
src/state.rs
124
src/state.rs
|
@ -12,6 +12,52 @@ use crate::{
|
|||
},
|
||||
};
|
||||
|
||||
/// What to do for a message?
|
||||
pub enum WlMitmVerdict {
|
||||
/// This message is allowed. Pass it through to the opposite end.
|
||||
Allowed,
|
||||
/// This message is filtered.
|
||||
/// TODO: We should probably construct a proper error response
|
||||
Filtered,
|
||||
/// Terminate this entire session. Something is off.
|
||||
Terminate,
|
||||
}
|
||||
|
||||
impl Default for WlMitmVerdict {
|
||||
fn default() -> Self {
|
||||
WlMitmVerdict::Terminate
|
||||
}
|
||||
}
|
||||
|
||||
/// Result returned by [WlMitmState] when handling messages.
|
||||
/// It's a pair of (num_consumed_fds, verdict).
|
||||
///
|
||||
/// We need to return back unused fds from the [WlRawMsg], which
|
||||
/// is why this has to be returned from here.
|
||||
#[derive(Default)]
|
||||
pub struct WlMitmOutcome(pub usize, pub WlMitmVerdict);
|
||||
|
||||
impl WlMitmOutcome {
|
||||
fn set_consumed_fds(&mut self, consumed_fds: usize) {
|
||||
self.0 = consumed_fds;
|
||||
}
|
||||
|
||||
fn allowed(mut self) -> Self {
|
||||
self.1 = WlMitmVerdict::Allowed;
|
||||
self
|
||||
}
|
||||
|
||||
fn filtered(mut self) -> Self {
|
||||
self.1 = WlMitmVerdict::Filtered;
|
||||
self
|
||||
}
|
||||
|
||||
fn terminate(mut self) -> Self {
|
||||
self.1 = WlMitmVerdict::Terminate;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WlMitmState {
|
||||
config: Arc<Config>,
|
||||
objects: WlObjects,
|
||||
|
@ -27,12 +73,28 @@ impl WlMitmState {
|
|||
|
||||
/// Handle messages which register new objects with known interfaces or deletes them.
|
||||
///
|
||||
/// If there is an error, this function will return false and the connection shall be terminated.
|
||||
///
|
||||
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
|
||||
/// is not handled here.
|
||||
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) {
|
||||
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) -> bool {
|
||||
if let Some(created_objects) = msg.known_objects_created() {
|
||||
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
|
||||
for (id, tt) in created_objects.into_iter() {
|
||||
if let Some(existing_obj_type) = self.objects.lookup_object(id) {
|
||||
debug!(
|
||||
parent_obj_id = msg.obj_id(),
|
||||
obj_type = tt.interface(),
|
||||
obj_id = id,
|
||||
existing_obj_type = existing_obj_type.interface(),
|
||||
"Trying to create object via message {}::{} but the object ID is already used!",
|
||||
parent_obj.interface(),
|
||||
msg.self_msg_name()
|
||||
);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
debug!(
|
||||
parent_obj_id = msg.obj_id(),
|
||||
obj_type = tt.interface(),
|
||||
|
@ -44,7 +106,8 @@ impl WlMitmState {
|
|||
self.objects.record_object(tt, id);
|
||||
}
|
||||
} else {
|
||||
error!("Parent object ID {} not found, ignoring", msg.obj_id());
|
||||
error!("Parent object ID {} not found!", msg.obj_id());
|
||||
return false;
|
||||
}
|
||||
} else if msg.is_destructor() {
|
||||
if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) {
|
||||
|
@ -58,25 +121,38 @@ impl WlMitmState {
|
|||
|
||||
self.objects.remove_object(msg.obj_id());
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
|
||||
let mut outcome: WlMitmOutcome = Default::default();
|
||||
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||
_ => {
|
||||
let obj_type = self
|
||||
.objects
|
||||
.lookup_object(raw_msg.obj_id)
|
||||
.map(|t| t.interface());
|
||||
|
||||
error!(
|
||||
obj_id = raw_msg.obj_id,
|
||||
obj_type = ?obj_type,
|
||||
opcode = raw_msg.opcode,
|
||||
num_fds = raw_msg.fds.len(),
|
||||
"Malformed or unknown request"
|
||||
);
|
||||
return (0, false);
|
||||
return outcome.terminate();
|
||||
}
|
||||
};
|
||||
|
||||
self.handle_created_or_destroyed_objects(&*msg);
|
||||
outcome.set_consumed_fds(msg.num_consumed_fds());
|
||||
|
||||
if !self.handle_created_or_destroyed_objects(&*msg) {
|
||||
return outcome.terminate();
|
||||
}
|
||||
|
||||
// The bind request doesn't create interface with a fixed type; handle it separately.
|
||||
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
|
||||
|
@ -86,7 +162,7 @@ impl WlMitmState {
|
|||
// to bind to it; if it does, it's likely a malicious client!
|
||||
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
||||
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
|
||||
return (0, false);
|
||||
return outcome.terminate();
|
||||
};
|
||||
|
||||
if obj_type.interface() != msg.id_interface_name {
|
||||
|
@ -96,7 +172,7 @@ impl WlMitmState {
|
|||
msg.name,
|
||||
obj_type.interface()
|
||||
);
|
||||
return (0, false);
|
||||
return outcome.terminate();
|
||||
}
|
||||
|
||||
info!(
|
||||
|
@ -144,9 +220,10 @@ impl WlMitmState {
|
|||
msg.self_msg_name(),
|
||||
status
|
||||
);
|
||||
return outcome.filtered();
|
||||
} else {
|
||||
return outcome.allowed();
|
||||
}
|
||||
|
||||
return (msg.num_consumed_fds(), status.success());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -155,7 +232,7 @@ impl WlMitmState {
|
|||
msg.self_object_type().interface(),
|
||||
msg.self_msg_name()
|
||||
);
|
||||
return (msg.num_consumed_fds(), false);
|
||||
return outcome.filtered();
|
||||
}
|
||||
WlFilterRequestAction::Block => {
|
||||
warn!(
|
||||
|
@ -164,31 +241,42 @@ impl WlMitmState {
|
|||
msg.self_msg_name()
|
||||
);
|
||||
// TODO: don't just return false, build an error event
|
||||
return (msg.num_consumed_fds(), false);
|
||||
return outcome.filtered();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(msg.num_consumed_fds(), true)
|
||||
outcome.allowed()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
|
||||
let mut outcome: WlMitmOutcome = Default::default();
|
||||
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||
_ => {
|
||||
let obj_type = self
|
||||
.objects
|
||||
.lookup_object(raw_msg.obj_id)
|
||||
.map(|t| t.interface());
|
||||
|
||||
error!(
|
||||
obj_id = raw_msg.obj_id,
|
||||
obj_type = ?obj_type,
|
||||
opcode = raw_msg.opcode,
|
||||
num_fds = raw_msg.fds.len(),
|
||||
"Malformed or unknown event"
|
||||
);
|
||||
return (0, false);
|
||||
return outcome.terminate();
|
||||
}
|
||||
};
|
||||
|
||||
self.handle_created_or_destroyed_objects(&*msg);
|
||||
outcome.set_consumed_fds(msg.num_consumed_fds());
|
||||
|
||||
if !self.handle_created_or_destroyed_objects(&*msg) {
|
||||
return outcome.terminate();
|
||||
}
|
||||
|
||||
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
|
||||
// This event is how Wayland servers announce globals -- and they are the entrypoint to
|
||||
|
@ -207,7 +295,7 @@ impl WlMitmState {
|
|||
"Unknown interface removed! If required, please include its XML when building wl-mitm!"
|
||||
);
|
||||
|
||||
return (0, false);
|
||||
return outcome.filtered();
|
||||
};
|
||||
|
||||
// To block entire extensions, we just need to filter out their announced global objects.
|
||||
|
@ -216,7 +304,7 @@ impl WlMitmState {
|
|||
interface = msg.interface,
|
||||
"Removing interface from published globals"
|
||||
);
|
||||
return (0, false);
|
||||
return outcome.filtered();
|
||||
}
|
||||
|
||||
// Else, record the global object. These are the only ones we're ever going to allow through.
|
||||
|
@ -230,6 +318,6 @@ impl WlMitmState {
|
|||
self.objects.remove_object(msg.id);
|
||||
}
|
||||
|
||||
(msg.num_consumed_fds(), true)
|
||||
outcome.allowed()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue