Compare commits
4 commits
65c5567c88
...
f78628a3fd
Author | SHA1 | Date | |
---|---|---|---|
f78628a3fd | |||
aea28e4457 | |||
49947ba052 | |||
6d772971ff |
3 changed files with 123 additions and 26 deletions
|
@ -1,7 +1,7 @@
|
||||||
[socket]
|
[socket]
|
||||||
# Which socket to listen on? If relative,
|
# Which socket to listen on? If relative,
|
||||||
# defaults to being relative to $XDG_RUNTIME_DIR
|
# defaults to being relative to $XDG_RUNTIME_DIR
|
||||||
listen = "wayland-2"
|
listen = "wayland-10"
|
||||||
# Which Wayland socket to use as upstream?
|
# Which Wayland socket to use as upstream?
|
||||||
# If missing, defaults to $WAYLAND_DISPLAY
|
# If missing, defaults to $WAYLAND_DISPLAY
|
||||||
# upstream = "wayland-1"
|
# upstream = "wayland-1"
|
||||||
|
@ -16,6 +16,7 @@ allowed_globals = [
|
||||||
"wl_compositor",
|
"wl_compositor",
|
||||||
"wl_shm",
|
"wl_shm",
|
||||||
"wl_data_device_manager",
|
"wl_data_device_manager",
|
||||||
|
"wl_output", # each output is also a global
|
||||||
"wl_seat",
|
"wl_seat",
|
||||||
# Window management
|
# Window management
|
||||||
"xdg_wm_base",
|
"xdg_wm_base",
|
||||||
|
|
22
src/main.rs
22
src/main.rs
|
@ -10,7 +10,7 @@ use std::{io, path::Path, sync::Arc};
|
||||||
|
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use io_util::{WlMsgReader, WlMsgWriter};
|
use io_util::{WlMsgReader, WlMsgWriter};
|
||||||
use state::WlMitmState;
|
use state::{WlMitmOutcome, WlMitmState, WlMitmVerdict};
|
||||||
use tokio::net::{UnixListener, UnixStream};
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
use tracing::{Instrument, Level, debug, error, info, span};
|
use tracing::{Instrument, Level, debug, error, info, span};
|
||||||
|
|
||||||
|
@ -93,11 +93,15 @@ pub async fn handle_conn(
|
||||||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
||||||
|
|
||||||
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
||||||
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||||
|
|
||||||
if verdict {
|
match verdict {
|
||||||
downstream_write.queue_write(wl_raw_msg);
|
WlMitmVerdict::Allowed => {
|
||||||
|
downstream_write.queue_write(wl_raw_msg);
|
||||||
|
},
|
||||||
|
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
codec::DecoderOutcome::Incomplete => continue,
|
codec::DecoderOutcome::Incomplete => continue,
|
||||||
|
@ -109,11 +113,15 @@ pub async fn handle_conn(
|
||||||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
||||||
|
|
||||||
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
let WlMitmOutcome(num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
||||||
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||||
|
|
||||||
if verdict {
|
match verdict {
|
||||||
upstream_write.queue_write(wl_raw_msg);
|
WlMitmVerdict::Allowed => {
|
||||||
|
upstream_write.queue_write(wl_raw_msg);
|
||||||
|
},
|
||||||
|
WlMitmVerdict::Terminate => break Err(io::Error::new(io::ErrorKind::ConnectionAborted, "aborting connection")),
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
codec::DecoderOutcome::Incomplete => continue,
|
codec::DecoderOutcome::Incomplete => continue,
|
||||||
|
|
124
src/state.rs
124
src/state.rs
|
@ -12,6 +12,52 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// What to do for a message?
|
||||||
|
pub enum WlMitmVerdict {
|
||||||
|
/// This message is allowed. Pass it through to the opposite end.
|
||||||
|
Allowed,
|
||||||
|
/// This message is filtered.
|
||||||
|
/// TODO: We should probably construct a proper error response
|
||||||
|
Filtered,
|
||||||
|
/// Terminate this entire session. Something is off.
|
||||||
|
Terminate,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WlMitmVerdict {
|
||||||
|
fn default() -> Self {
|
||||||
|
WlMitmVerdict::Terminate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result returned by [WlMitmState] when handling messages.
|
||||||
|
/// It's a pair of (num_consumed_fds, verdict).
|
||||||
|
///
|
||||||
|
/// We need to return back unused fds from the [WlRawMsg], which
|
||||||
|
/// is why this has to be returned from here.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct WlMitmOutcome(pub usize, pub WlMitmVerdict);
|
||||||
|
|
||||||
|
impl WlMitmOutcome {
|
||||||
|
fn set_consumed_fds(&mut self, consumed_fds: usize) {
|
||||||
|
self.0 = consumed_fds;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allowed(mut self) -> Self {
|
||||||
|
self.1 = WlMitmVerdict::Allowed;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn filtered(mut self) -> Self {
|
||||||
|
self.1 = WlMitmVerdict::Filtered;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn terminate(mut self) -> Self {
|
||||||
|
self.1 = WlMitmVerdict::Terminate;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct WlMitmState {
|
pub struct WlMitmState {
|
||||||
config: Arc<Config>,
|
config: Arc<Config>,
|
||||||
objects: WlObjects,
|
objects: WlObjects,
|
||||||
|
@ -27,12 +73,28 @@ impl WlMitmState {
|
||||||
|
|
||||||
/// Handle messages which register new objects with known interfaces or deletes them.
|
/// Handle messages which register new objects with known interfaces or deletes them.
|
||||||
///
|
///
|
||||||
|
/// If there is an error, this function will return false and the connection shall be terminated.
|
||||||
|
///
|
||||||
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
|
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
|
||||||
/// is not handled here.
|
/// is not handled here.
|
||||||
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) {
|
fn handle_created_or_destroyed_objects(&mut self, msg: &dyn AnyWlParsedMessage<'_>) -> bool {
|
||||||
if let Some(created_objects) = msg.known_objects_created() {
|
if let Some(created_objects) = msg.known_objects_created() {
|
||||||
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
|
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
|
||||||
for (id, tt) in created_objects.into_iter() {
|
for (id, tt) in created_objects.into_iter() {
|
||||||
|
if let Some(existing_obj_type) = self.objects.lookup_object(id) {
|
||||||
|
debug!(
|
||||||
|
parent_obj_id = msg.obj_id(),
|
||||||
|
obj_type = tt.interface(),
|
||||||
|
obj_id = id,
|
||||||
|
existing_obj_type = existing_obj_type.interface(),
|
||||||
|
"Trying to create object via message {}::{} but the object ID is already used!",
|
||||||
|
parent_obj.interface(),
|
||||||
|
msg.self_msg_name()
|
||||||
|
);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
parent_obj_id = msg.obj_id(),
|
parent_obj_id = msg.obj_id(),
|
||||||
obj_type = tt.interface(),
|
obj_type = tt.interface(),
|
||||||
|
@ -44,7 +106,8 @@ impl WlMitmState {
|
||||||
self.objects.record_object(tt, id);
|
self.objects.record_object(tt, id);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
error!("Parent object ID {} not found, ignoring", msg.obj_id());
|
error!("Parent object ID {} not found!", msg.obj_id());
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
} else if msg.is_destructor() {
|
} else if msg.is_destructor() {
|
||||||
if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) {
|
if let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) {
|
||||||
|
@ -58,25 +121,38 @@ impl WlMitmState {
|
||||||
|
|
||||||
self.objects.remove_object(msg.obj_id());
|
self.objects.remove_object(msg.obj_id());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
|
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
|
||||||
|
let mut outcome: WlMitmOutcome = Default::default();
|
||||||
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
||||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||||
_ => {
|
_ => {
|
||||||
|
let obj_type = self
|
||||||
|
.objects
|
||||||
|
.lookup_object(raw_msg.obj_id)
|
||||||
|
.map(|t| t.interface());
|
||||||
|
|
||||||
error!(
|
error!(
|
||||||
obj_id = raw_msg.obj_id,
|
obj_id = raw_msg.obj_id,
|
||||||
|
obj_type = ?obj_type,
|
||||||
opcode = raw_msg.opcode,
|
opcode = raw_msg.opcode,
|
||||||
num_fds = raw_msg.fds.len(),
|
num_fds = raw_msg.fds.len(),
|
||||||
"Malformed or unknown request"
|
"Malformed or unknown request"
|
||||||
);
|
);
|
||||||
return (0, false);
|
return outcome.terminate();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.handle_created_or_destroyed_objects(&*msg);
|
outcome.set_consumed_fds(msg.num_consumed_fds());
|
||||||
|
|
||||||
|
if !self.handle_created_or_destroyed_objects(&*msg) {
|
||||||
|
return outcome.terminate();
|
||||||
|
}
|
||||||
|
|
||||||
// The bind request doesn't create interface with a fixed type; handle it separately.
|
// The bind request doesn't create interface with a fixed type; handle it separately.
|
||||||
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
|
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
|
||||||
|
@ -86,7 +162,7 @@ impl WlMitmState {
|
||||||
// to bind to it; if it does, it's likely a malicious client!
|
// to bind to it; if it does, it's likely a malicious client!
|
||||||
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
||||||
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
|
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
|
||||||
return (0, false);
|
return outcome.terminate();
|
||||||
};
|
};
|
||||||
|
|
||||||
if obj_type.interface() != msg.id_interface_name {
|
if obj_type.interface() != msg.id_interface_name {
|
||||||
|
@ -96,7 +172,7 @@ impl WlMitmState {
|
||||||
msg.name,
|
msg.name,
|
||||||
obj_type.interface()
|
obj_type.interface()
|
||||||
);
|
);
|
||||||
return (0, false);
|
return outcome.terminate();
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
@ -144,9 +220,10 @@ impl WlMitmState {
|
||||||
msg.self_msg_name(),
|
msg.self_msg_name(),
|
||||||
status
|
status
|
||||||
);
|
);
|
||||||
|
return outcome.filtered();
|
||||||
|
} else {
|
||||||
|
return outcome.allowed();
|
||||||
}
|
}
|
||||||
|
|
||||||
return (msg.num_consumed_fds(), status.success());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,7 +232,7 @@ impl WlMitmState {
|
||||||
msg.self_object_type().interface(),
|
msg.self_object_type().interface(),
|
||||||
msg.self_msg_name()
|
msg.self_msg_name()
|
||||||
);
|
);
|
||||||
return (msg.num_consumed_fds(), false);
|
return outcome.filtered();
|
||||||
}
|
}
|
||||||
WlFilterRequestAction::Block => {
|
WlFilterRequestAction::Block => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -164,31 +241,42 @@ impl WlMitmState {
|
||||||
msg.self_msg_name()
|
msg.self_msg_name()
|
||||||
);
|
);
|
||||||
// TODO: don't just return false, build an error event
|
// TODO: don't just return false, build an error event
|
||||||
return (msg.num_consumed_fds(), false);
|
return outcome.filtered();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(msg.num_consumed_fds(), true)
|
outcome.allowed()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
|
||||||
|
let mut outcome: WlMitmOutcome = Default::default();
|
||||||
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
||||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||||
_ => {
|
_ => {
|
||||||
|
let obj_type = self
|
||||||
|
.objects
|
||||||
|
.lookup_object(raw_msg.obj_id)
|
||||||
|
.map(|t| t.interface());
|
||||||
|
|
||||||
error!(
|
error!(
|
||||||
obj_id = raw_msg.obj_id,
|
obj_id = raw_msg.obj_id,
|
||||||
|
obj_type = ?obj_type,
|
||||||
opcode = raw_msg.opcode,
|
opcode = raw_msg.opcode,
|
||||||
num_fds = raw_msg.fds.len(),
|
num_fds = raw_msg.fds.len(),
|
||||||
"Malformed or unknown event"
|
"Malformed or unknown event"
|
||||||
);
|
);
|
||||||
return (0, false);
|
return outcome.terminate();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
self.handle_created_or_destroyed_objects(&*msg);
|
outcome.set_consumed_fds(msg.num_consumed_fds());
|
||||||
|
|
||||||
|
if !self.handle_created_or_destroyed_objects(&*msg) {
|
||||||
|
return outcome.terminate();
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
|
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
|
||||||
// This event is how Wayland servers announce globals -- and they are the entrypoint to
|
// This event is how Wayland servers announce globals -- and they are the entrypoint to
|
||||||
|
@ -207,7 +295,7 @@ impl WlMitmState {
|
||||||
"Unknown interface removed! If required, please include its XML when building wl-mitm!"
|
"Unknown interface removed! If required, please include its XML when building wl-mitm!"
|
||||||
);
|
);
|
||||||
|
|
||||||
return (0, false);
|
return outcome.filtered();
|
||||||
};
|
};
|
||||||
|
|
||||||
// To block entire extensions, we just need to filter out their announced global objects.
|
// To block entire extensions, we just need to filter out their announced global objects.
|
||||||
|
@ -216,7 +304,7 @@ impl WlMitmState {
|
||||||
interface = msg.interface,
|
interface = msg.interface,
|
||||||
"Removing interface from published globals"
|
"Removing interface from published globals"
|
||||||
);
|
);
|
||||||
return (0, false);
|
return outcome.filtered();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Else, record the global object. These are the only ones we're ever going to allow through.
|
// Else, record the global object. These are the only ones we're ever going to allow through.
|
||||||
|
@ -230,6 +318,6 @@ impl WlMitmState {
|
||||||
self.objects.remove_object(msg.id);
|
self.objects.remove_object(msg.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
(msg.num_consumed_fds(), true)
|
outcome.allowed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue