Compare commits

..

2 commits

Author SHA1 Message Date
cbb4a68c9f Handle global remove event 2025-02-28 19:15:40 -05:00
75f5bf8f73 Retire the match_decoded!{} macro 2025-02-28 19:06:06 -05:00
3 changed files with 103 additions and 83 deletions

View file

@ -75,4 +75,8 @@ impl WlObjects {
pub fn lookup_global(&self, name: u32) -> Option<&str> {
self.global_names.get(&name).map(|s| s.as_str())
}
pub fn remove_global(&mut self, name: u32) {
self.global_names.remove(&name);
}
}

View file

@ -25,18 +25,6 @@ macro_rules! bubble_malformed {
}};
}
macro_rules! match_decoded {
(match $decoded:ident {$($t:ty => $act:block$(,)?)+}) => {
if let crate::proto::WaylandProtocolParsingOutcome::Ok($decoded) = $decoded {
$(
if let Some($decoded) = $decoded.downcast_ref::<$t>() {
$act
}
)+
}
};
}
#[derive(PartialEq, Eq)]
pub enum WlMsgType {
Request,

View file

@ -7,8 +7,9 @@ use crate::{
config::Config,
objects::WlObjects,
proto::{
WL_REGISTRY, WlDisplayDeleteIdEvent, WlDisplayGetRegistryRequest, WlRegistryBindRequest,
WlRegistryGlobalEvent,
WL_REGISTRY, WaylandProtocolParsingOutcome, WlDisplayDeleteIdEvent,
WlDisplayGetRegistryRequest, WlRegistryBindRequest, WlRegistryGlobalEvent,
WlRegistryGlobalRemoveEvent,
},
};
@ -27,42 +28,60 @@ impl WlMitmState {
#[tracing::instrument(skip_all)]
pub fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
let msg = crate::proto::decode_request(&self.objects, raw_msg);
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
error!(
obj_id = raw_msg.obj_id,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed request"
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => {
// Kill all malformed messages
// Note that they are different from messages whose object / message types are unknown
error!(
obj_id = raw_msg.obj_id,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed request"
);
return false;
}
_ => {
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't
// been built against!
// Note that this won't pass through messages for globals we haven't allowed:
// to use a global, a client must first _bind_ that global, and _that_ message is intercepted
// below. There, we match based on the textual representation of the interface, so it works
// even for globals from protocols we don't know.
// It does mean we can't filter against methods that create more objects _from_ that
// global, though.
return true;
}
};
if let Some(msg) = msg.downcast_ref::<WlDisplayGetRegistryRequest>() {
self.objects.record_object(WL_REGISTRY, msg.registry);
} else if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
// If we have blocked this global, this lookup should return None, thus blocking client attempts
// to bind to a blocked global.
// Note that because we've removed said global from the registry, a client _SHOULD NOT_ be attempting
// to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(interface) = self.objects.lookup_global(msg.name) else {
return false;
};
if interface != msg.id_interface_name {
error!(
"Client binding to interface {}, but the interface name {} should correspond to {}",
msg.id_interface_name, msg.name, interface
);
return false;
}
info!(
interface = interface,
obj_id = msg.id,
"Client binding interface"
);
return false;
}
match_decoded! {
match msg {
WlDisplayGetRegistryRequest => {
self.objects.record_object(WL_REGISTRY, msg.registry);
}
WlRegistryBindRequest => {
let Some(interface) = self.objects.lookup_global(msg.name) else {
return false;
};
if interface != msg.id_interface_name {
error!("Client binding to interface {}, but the interface name {} should correspond to {}", msg.id_interface_name, msg.name, interface);
return false;
}
info!(
interface = interface,
obj_id = msg.id,
"Client binding interface"
);
if let Some(t) = crate::proto::lookup_known_object_type(interface) {
self.objects.record_object(t, msg.id);
}
}
if let Some(t) = crate::proto::lookup_known_object_type(interface) {
self.objects.record_object(t, msg.id);
}
}
@ -71,42 +90,51 @@ impl WlMitmState {
#[tracing::instrument(skip_all)]
pub fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
let msg = crate::proto::decode_event(&self.objects, raw_msg);
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
error!(
obj_id = raw_msg.obj_id,
opcode = raw_msg.opcode,
"Malformed event"
);
return false;
}
match_decoded! {
match msg {
WlRegistryGlobalEvent => {
debug!(
interface = msg.interface,
name = msg.name,
version = msg.version,
"got global"
);
self.objects.record_global(msg.name, msg.interface);
if !self.config.filter.allowed_globals.contains(msg.interface) {
info!(
interface = msg.interface,
"Removing interface from published globals"
);
return false;
}
}
WlDisplayDeleteIdEvent => {
// When an object is acknowledged to be deleted, remove it from our
// internal cache of all registered objects
self.objects.remove_object(msg.id);
}
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => {
error!(
obj_id = raw_msg.obj_id,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed event"
);
return false;
}
_ => {
return true;
}
};
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
// This event is how Wayland servers announce globals -- and they are the entrypoint to
// most extensions! You need at least one global registered for clients to be able to
// access methods from that extension; but those methods _could_ create more objects.
debug!(
interface = msg.interface,
name = msg.name,
version = msg.version,
"got global"
);
// To block entire extensions, we just need to filter out their announced global objects.
if !self.config.filter.allowed_globals.contains(msg.interface) {
info!(
interface = msg.interface,
"Removing interface from published globals"
);
return false;
}
// Else, record the global object. These are the only ones we're ever going to allow through.
// We block bind requests on any interface that's not recorded here.
self.objects.record_global(msg.name, msg.interface);
} else if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalRemoveEvent>() {
// Remove globals that the server has removed
self.objects.remove_global(msg.name);
} else if let Some(msg) = msg.downcast_ref::<WlDisplayDeleteIdEvent>() {
// Server has acknowledged deletion of an object
self.objects.remove_object(msg.id);
}
true