wl-mitm/src/state.rs

517 lines
20 KiB
Rust

use std::sync::Arc;
use tracing::{debug, error, info, warn};
use crate::{
codec::WlRawMsg,
config::{Config, WlFilterRequestAction, WlFilterRequestBlockType},
objects::WlObjects,
proto::{
AnyWlParsedMessage, WaylandProtocolParsingOutcome, WlDisplayDeleteIdEvent,
WlKeyboardEnterEvent, WlPointerEnterEvent, WlRegistryBindRequest, WlRegistryGlobalEvent,
WlRegistryGlobalRemoveEvent, WlTouchDownEvent, XdgSurfaceGetToplevelRequest,
XdgToplevelSetAppIdRequest, XdgToplevelSetTitleRequest, XdgWmBaseGetXdgSurfaceRequest,
},
};
/// What to do for a message?
#[derive(Debug)]
pub enum WlMitmVerdict {
/// This message is allowed. Pass it through to the opposite end.
Allowed,
/// This message is filtered
Filtered,
/// This messages is rejected (i.e. filtered, but comes with an error code to return to sender)
Rejected(u32),
/// Terminate this entire session. Something is off.
Terminate,
}
impl WlMitmVerdict {
pub fn is_allowed(&self) -> bool {
matches!(self, WlMitmVerdict::Allowed)
}
}
impl Default for WlMitmVerdict {
fn default() -> Self {
WlMitmVerdict::Terminate
}
}
/// Result returned by [WlMitmState] when handling messages.
/// It's a pair of (num_consumed_fds, verdict).
///
/// We need to return back unused fds from the [WlRawMsg], which
/// is why this has to be returned from here.
#[derive(Default)]
pub struct WlMitmOutcome(pub usize, pub WlMitmVerdict);
impl WlMitmOutcome {
fn set_consumed_fds(&mut self, consumed_fds: usize) {
self.0 = consumed_fds;
}
fn allowed(mut self) -> Self {
self.1 = WlMitmVerdict::Allowed;
self
}
fn filtered(mut self) -> Self {
self.1 = WlMitmVerdict::Filtered;
self
}
fn terminate(mut self) -> Self {
self.1 = WlMitmVerdict::Terminate;
self
}
fn rejected(mut self, error_code: u32) -> Self {
self.1 = WlMitmVerdict::Rejected(error_code);
self
}
}
/// Association between a wl_surface and an xdg_surface, to facilitate
/// lookup for [ToplevelSurfaceInfo] from a wl_surface
struct SurfaceXdgAssociation(u32);
/// Association between an xdg_surface and an xdg_toplevel
struct XdgToplevelAssociation(u32);
/// A struct to track information about an app's top-level surfaces (windows)
/// This gets passed down to ask and notify scripts to produce user-friendly
/// messages.
#[derive(Default, Debug)]
struct ToplevelSurfaceInfo {
pub title: Option<String>,
pub app_id: Option<String>,
}
/// Tracks state for _one_ Wayland connection.
pub struct WlMitmState {
config: Arc<Config>,
objects: WlObjects,
/// The last toplevel object ID (NOT the underlying wl_surface) that was "active"
/// for this connection.
/// This is used to hint the ask and notify scripts about the app's id and name,
/// even though this can never actually be perfect -- we can't track precisely
/// what might have caused the last filtered request to happen!
last_toplevel: Option<u32>,
}
impl WlMitmState {
pub fn new(config: Arc<Config>) -> WlMitmState {
WlMitmState {
config,
objects: WlObjects::new(),
last_toplevel: None,
}
}
/// Handle messages which register new objects with known interfaces or deletes them.
///
/// If there is an error, this function will return false and the connection shall be terminated.
///
/// Note that most _globals_ are instantiated using [WlRegistryBindRequest]. That request
/// is not handled here.
fn handle_created_or_destroyed_objects(
&mut self,
msg: &dyn AnyWlParsedMessage,
from_client: bool,
) -> bool {
if let Some(created_objects) = msg.known_objects_created() {
if let Some(parent_obj) = self.objects.lookup_object(msg.obj_id()) {
for (id, tt) in created_objects.into_iter() {
if let Some(existing_obj_type) = self.objects.lookup_object(id) {
debug!(
parent_obj_id = msg.obj_id(),
obj_type = tt.interface(),
obj_id = id,
existing_obj_type = existing_obj_type.interface(),
is_half_destroyed = self.objects.is_half_destroyed(id),
"Trying to create object via message {}::{} but the object ID is already used!",
parent_obj.interface(),
msg.msg_name()
);
return false;
}
debug!(
parent_obj_id = msg.obj_id(),
obj_type = tt.interface(),
obj_id = id,
"Created object via message {}::{}",
parent_obj.interface(),
msg.msg_name()
);
self.objects.record_object(tt, id);
}
} else {
error!("Parent object ID {} not found!", msg.obj_id());
return false;
}
} else if msg.is_destructor() {
let Some(obj_type) = self.objects.lookup_object(msg.obj_id()) else {
// This shouldn't really happen -- to decode the message we have to have a record of the object
error!("Destructed object ID {} not found!", msg.obj_id());
return false;
};
debug!(
obj_id = msg.obj_id(),
"Object destructed via destructor {}::{}",
obj_type.interface(),
msg.msg_name()
);
self.objects.remove_object(msg.obj_id(), from_client);
if self.last_toplevel.is_some_and(|id| id == msg.obj_id()) {
self.last_toplevel = None;
}
}
true
}
fn prepare_command(
&self,
msg: &dyn AnyWlParsedMessage,
cmd_str: &str,
desc: &str,
) -> tokio::process::Command {
let mut cmd = tokio::process::Command::new(cmd_str);
cmd.arg(msg.object_type().interface());
cmd.arg(msg.msg_name());
cmd.arg(desc);
cmd.env("WL_MITM_MSG_JSON", msg.to_json());
if let Some(last_toplevel) = self.last_toplevel {
if let Some(info) = self
.objects
.get_object_extension::<ToplevelSurfaceInfo>(last_toplevel)
{
if let Some(ref title) = info.title {
cmd.env("WL_MITM_LAST_TOPLEVEL_TITLE", title);
}
if let Some(ref app_id) = info.app_id {
cmd.env("WL_MITM_LAST_TOPLEVEL_APP_ID", app_id);
}
}
}
cmd
}
fn update_last_active_surface(&mut self, surface: u32) {
if let Some(SurfaceXdgAssociation(xdg_surface)) = self.objects.get_object_extension(surface)
{
if let Some(XdgToplevelAssociation(xdg_toplevel)) =
self.objects.get_object_extension(*xdg_surface)
{
self.last_toplevel = Some(*xdg_toplevel);
}
}
}
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
#[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!(
obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed or unknown request"
);
return outcome.terminate();
}
};
outcome.set_consumed_fds(msg.num_consumed_fds());
if self.config.logging.log_all_requests {
debug!(
obj_id = msg.obj_id(),
raw_payload_bytes = ?raw_msg.payload(),
num_fds = raw_msg.fds.len(),
num_consumed_fds = msg.num_consumed_fds(),
"{}::{}",
msg.object_type().interface(),
msg.msg_name(),
)
}
// To get here, the object referred to in raw_msg must exist, but it might already be destroyed by the client
// In that case, the client is broken!
if self.objects.is_half_destroyed(msg.obj_id()) {
error!(
obj_id = msg.obj_id(),
opcode = msg.opcode(),
"Client request detected on object already scheduled for destruction; aborting!"
);
return outcome.terminate();
}
if !self.handle_created_or_destroyed_objects(&*msg, true) {
return outcome.terminate();
}
// The bind request doesn't create interface with a fixed type; handle it separately.
if let Some(msg) = msg.downcast_ref::<WlRegistryBindRequest>() {
// If we have blocked this global, this lookup should return None, thus blocking client attempts
// to bind to a blocked global.
// Note that because we've removed said global from the registry, a client _SHOULD NOT_ be attempting
// to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(obj_type) = self.objects.lookup_global(msg.name) else {
warn!(
interface = msg.name,
version = msg.id_interface_version,
obj_id = msg.id,
"Client binding non-existent or filtered interface"
);
return outcome.terminate();
};
if obj_type.interface() != msg.id_interface_name {
error!(
"Client binding to interface {}, but the interface name {} should correspond to {}",
msg.id_interface_name,
msg.name,
obj_type.interface()
);
return outcome.terminate();
}
info!(
interface = obj_type.interface(),
version = msg.id_interface_version,
obj_id = msg.id,
"Client binding interface"
);
self.objects.record_object(obj_type, msg.id);
} else if let Some(msg) = msg.downcast_ref::<XdgWmBaseGetXdgSurfaceRequest>() {
self.objects
.put_object_extension(msg.surface, SurfaceXdgAssociation(msg.id));
} else if let Some(msg) = msg.downcast_ref::<XdgSurfaceGetToplevelRequest>() {
self.objects
.put_object_extension(msg.obj_id(), XdgToplevelAssociation(msg.id));
self.objects
.put_object_extension(msg.id, ToplevelSurfaceInfo::default());
} else if let Some(msg) = msg.downcast_ref::<XdgToplevelSetAppIdRequest>() {
if let Some(info) = self
.objects
.get_object_extension_mut::<ToplevelSurfaceInfo>(msg.obj_id())
{
info.app_id = Some(msg.app_id.to_string());
}
} else if let Some(msg) = msg.downcast_ref::<XdgToplevelSetTitleRequest>() {
if let Some(info) = self
.objects
.get_object_extension_mut::<ToplevelSurfaceInfo>(msg.obj_id())
{
info.title = Some(msg.title.to_string());
}
}
// Handle requests configured to be filtered
if let Some(filtered_requests) = self
.config
.filter
.requests
.get(msg.object_type().interface())
{
if let Some(filtered) = filtered_requests
.iter()
.find(|f| f.requests.contains(msg.msg_name()))
{
match filtered.action {
WlFilterRequestAction::Ask => {
if let Some(ref ask_cmd) = self.config.exec.ask_cmd {
info!(
ask_cmd = ask_cmd,
"Running ask command for {}::{}",
msg.object_type().interface(),
msg.msg_name()
);
let mut cmd = self.prepare_command(
&*msg,
ask_cmd,
filtered.desc.as_deref().unwrap_or_else(|| ""),
);
if let Ok(status) = cmd.status().await {
if !status.success() {
warn!(
"Blocked {}::{} because of return status {}",
msg.object_type().interface(),
msg.msg_name(),
status
);
return match filtered.block_type {
WlFilterRequestBlockType::Ignore => outcome.filtered(),
WlFilterRequestBlockType::Reject => {
outcome.rejected(filtered.error_code)
}
};
} else {
return outcome.allowed();
}
}
}
warn!(
"Blocked {}::{} because of missing ask_cmd",
msg.object_type().interface(),
msg.msg_name()
);
return match filtered.block_type {
WlFilterRequestBlockType::Ignore => outcome.filtered(),
WlFilterRequestBlockType::Reject => {
outcome.rejected(filtered.error_code)
}
};
}
WlFilterRequestAction::Notify => {
if let Some(ref notify_cmd) = self.config.exec.notify_cmd {
info!(
notify_cmd = notify_cmd,
"Running notify command for {}::{}",
msg.object_type().interface(),
msg.msg_name()
);
let mut cmd = self.prepare_command(
&*msg,
notify_cmd,
filtered.desc.as_deref().unwrap_or_else(|| ""),
);
cmd.spawn().ok();
}
}
WlFilterRequestAction::Block => {
warn!(
"Blocked {}::{}",
msg.object_type().interface(),
msg.msg_name()
);
return match filtered.block_type {
WlFilterRequestBlockType::Ignore => outcome.filtered(),
WlFilterRequestBlockType::Reject => {
outcome.rejected(filtered.error_code)
}
};
}
}
}
}
outcome.allowed()
}
#[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> WlMitmOutcome {
let mut outcome: WlMitmOutcome = Default::default();
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
_ => {
let obj_type = self
.objects
.lookup_object(raw_msg.obj_id)
.map(|t| t.interface());
error!(
obj_id = raw_msg.obj_id,
obj_type = ?obj_type,
opcode = raw_msg.opcode,
num_fds = raw_msg.fds.len(),
"Malformed or unknown event"
);
return outcome.terminate();
}
};
outcome.set_consumed_fds(msg.num_consumed_fds());
if self.config.logging.log_all_events {
debug!(
obj_id = msg.obj_id(),
raw_payload_bytes = ?raw_msg.payload(),
num_fds = raw_msg.fds.len(),
num_consumed_fds = msg.num_consumed_fds(),
"{}::{}",
msg.object_type().interface(),
msg.msg_name(),
)
}
if !self.handle_created_or_destroyed_objects(&*msg, false) {
return outcome.terminate();
}
if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalEvent>() {
// This event is how Wayland servers announce globals -- and they are the entrypoint to
// most extensions! You need at least one global registered for clients to be able to
// access methods from that extension; but those methods _could_ create more objects.
debug!(
interface = msg.interface,
name = msg.name,
version = msg.version,
"got global"
);
let Some(obj_type) = crate::proto::lookup_known_object_type(msg.interface) else {
error!(
interface = msg.interface,
"Unknown interface removed! If required, please include its XML when building wl-mitm!"
);
return outcome.filtered();
};
// To block entire extensions, we just need to filter out their announced global objects.
if !self.config.filter.allowed_globals.contains(msg.interface) {
info!(
interface = msg.interface,
"Removing interface from published globals"
);
return outcome.filtered();
}
// Else, record the global object. These are the only ones we're ever going to allow through.
// We block bind requests on any interface that's not recorded here.
self.objects.record_global(msg.name, obj_type);
} else if let Some(msg) = msg.downcast_ref::<WlRegistryGlobalRemoveEvent>() {
// Remove globals that the server has removed
self.objects.remove_global(msg.name);
} else if let Some(msg) = msg.downcast_ref::<WlDisplayDeleteIdEvent>() {
// Server has acknowledged deletion of an object
self.objects.remove_object(msg.id, false);
} else if let Some(msg) = msg.downcast_ref::<WlPointerEnterEvent>() {
self.update_last_active_surface(msg.surface);
} else if let Some(msg) = msg.downcast_ref::<WlKeyboardEnterEvent>() {
self.update_last_active_surface(msg.surface);
} else if let Some(msg) = msg.downcast_ref::<WlTouchDownEvent>() {
self.update_last_active_surface(msg.surface);
}
outcome.allowed()
}
}