Compare commits

...

2 commits

Author SHA1 Message Date
e0cc002148 Handle fds properly
fds aren't guaranteed to arrive along with messages they belong to; so
we can't just assume everything in a WlRawMsg belongs to a parsed
Wayland message.

The only way to tell is by looking at the XML, so we'll have to return
any unused ones to the decoder after receiving them.
2025-03-02 11:14:30 -05:00
d0afdbdda2 Parse serialized message to ask scripts 2025-03-02 08:37:16 -05:00
9 changed files with 144 additions and 41 deletions

26
Cargo.lock generated
View file

@ -95,6 +95,7 @@ dependencies = [
"az",
"bytemuck",
"half",
"serde",
"typenum",
]
@ -130,6 +131,12 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "itoa"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "lazy_static"
version = "1.5.0"
@ -266,6 +273,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "ryu"
version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "sendfd"
version = "0.4.3"
@ -296,6 +309,18 @@ dependencies = [
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.8"
@ -619,6 +644,7 @@ dependencies = [
"sendfd",
"serde",
"serde_derive",
"serde_json",
"tokio",
"toml",
"tracing",

View file

@ -11,11 +11,12 @@ members = ["protogen"]
[dependencies]
byteorder = "1.5.0"
bytes = "1.10.0"
fixed = "1.29.0"
fixed = { version = "1.29.0", features = [ "serde" ] }
nix = "0.29.0"
sendfd = { version = "0.4", features = [ "tokio" ] }
serde = "1.0.218"
serde_derive = "1.0.218"
serde_json = "1.0.139"
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
toml = "0.8.20"
tracing = "0.1.41"

View file

@ -36,6 +36,9 @@ allowed_globals = [
#
# The first and second arguments to this program will be the interface
# and request name, respectively.
#
# The number of arguments may change in the future, but the _last_ argument
# is always a JSON-serialized representation of the request's arguments.
ask_cmd = "contrib/ask-bemenu.sh"
# A list of requests we'd like to filter

View file

@ -98,12 +98,30 @@ impl WlMsg {
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
// Build all field names and their corresponding Rust type identifiers
let (field_names, field_types): (Vec<_>, Vec<_>) = self
let (field_names, (field_types, field_attrs)): (Vec<_>, (Vec<_>, Vec<_>)) = self
.args
.iter()
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
.map(|(name, tt)| {
(
format_ident!("{name}"),
(
tt.to_rust_type(),
match tt {
// Can't serialize fds!
WlArgType::Fd => quote! { #[serde(skip)] },
_ => quote! {},
},
),
)
})
.unzip();
let num_consumed_fds = self
.args
.iter()
.filter(|(_, tt)| matches!(tt, WlArgType::Fd))
.count();
// Generate code to include in the parser for every field
let parser_code: Vec<_> = self
.args
@ -143,10 +161,12 @@ impl WlMsg {
quote! {
#[allow(unused)]
#[derive(Serialize)]
pub struct #struct_name<'a> {
#[serde(skip)]
_phantom: std::marker::PhantomData<&'a ()>,
obj_id: u32,
#( pub #field_names: #field_types, )*
#( #field_attrs pub #field_names: #field_types, )*
}
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
@ -184,6 +204,7 @@ impl WlMsg {
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload();
let mut pos = 0usize;
let mut pos_fds = 0usize;
#( #parser_code )*
WaylandProtocolParsingOutcome::Ok(#struct_name {
_phantom: std::marker::PhantomData,
@ -203,6 +224,14 @@ impl WlMsg {
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
#known_objects_created
}
fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
fn num_consumed_fds(&self) -> usize {
#num_consumed_fds
}
}
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
@ -376,11 +405,12 @@ impl WlArgType {
};
},
WlArgType::Fd => quote! {
if msg.fds.len() == 0 {
if msg.fds.len() < pos_fds + 1 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]);
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]);
pos_fds += 1;
},
}
}

View file

@ -1,4 +1,4 @@
use std::os::fd::OwnedFd;
use std::{collections::VecDeque, os::fd::OwnedFd};
use byteorder::{ByteOrder, NativeEndian};
use bytes::{Bytes, BytesMut};
@ -13,11 +13,18 @@ pub struct WlRawMsg {
pub opcode: u16,
// len bytes -- containing the header
msg_buf: Bytes,
pub fds: Box<[OwnedFd]>,
/// All fds we have seen up until decoding this message frame.
///
/// fds aren't guaranteed to be separated between messages; therefore, there
/// is no way for us to tell that all fds here belong to the current message
/// without actually loading the Wayland XML protocols.
///
/// Instead, downstream parsers should return any unused fds back to the decoder
/// with [WlDecoder::return_unused_fds].
pub fds: Vec<OwnedFd>,
}
impl WlRawMsg {
pub fn try_decode(buf: &mut BytesMut, fds: &mut Vec<OwnedFd>) -> Option<WlRawMsg> {
pub fn try_decode(buf: &mut BytesMut, fds: &mut VecDeque<OwnedFd>) -> Option<WlRawMsg> {
let buf_len = buf.len();
// Not even a complete message header
if buf_len < 8 {
@ -36,14 +43,16 @@ impl WlRawMsg {
let msg_buf = buf.split_to(msg_len as usize);
let mut new_fds = Vec::with_capacity(fds.len());
new_fds.append(fds);
while let Some(fd) = fds.pop_front() {
new_fds.push(fd);
}
Some(WlRawMsg {
obj_id,
len: msg_len as u16,
opcode: opcode as u16,
msg_buf: msg_buf.freeze(),
fds: new_fds.into_boxed_slice(),
fds: new_fds,
})
}
@ -52,7 +61,7 @@ impl WlRawMsg {
}
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
(self.msg_buf, self.fds)
(self.msg_buf, self.fds.into_boxed_slice())
}
}
@ -64,14 +73,25 @@ pub enum DecoderOutcome {
pub struct WlDecoder {
buf: BytesMut,
fds: Vec<OwnedFd>,
fds: VecDeque<OwnedFd>,
}
impl WlDecoder {
pub fn new() -> WlDecoder {
WlDecoder {
buf: BytesMut::new(),
fds: Vec::new(),
fds: VecDeque::new(),
}
}
/// Gives back to the decoder every fd in `msg` that was not consumed.
///
/// The first `num_consumed` fds stay attached to `msg`; the rest are moved
/// to the front of this decoder's pending fd queue, keeping their original
/// relative order, so they are seen first by the next decoded message.
///
/// Panics if `num_consumed` is greater than `msg.fds.len()`.
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
    // Walk the leftover fds from last to first: pushing each one onto the
    // front of the queue in that order leaves them front-to-back in the
    // same order they had inside `msg.fds`.
    for leftover in msg.fds.drain(num_consumed..).rev() {
        self.fds.push_front(leftover);
    }
}
@ -86,9 +106,9 @@ impl WlDecoder {
}
}
pub fn decode_after_read(&mut self, buf: &[u8], fds: &mut Vec<OwnedFd>) -> DecoderOutcome {
pub fn decode_after_read(&mut self, buf: &[u8], fds: Vec<OwnedFd>) -> DecoderOutcome {
self.buf.extend_from_slice(&buf);
self.fds.append(fds);
self.fds.extend(fds.into_iter());
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
Some(res) => DecoderOutcome::Decoded(res),

View file

@ -25,6 +25,10 @@ impl<'a> WlMsgReader<'a> {
}
}
/// Hands `msg` and `num_consumed` to the underlying decoder's
/// `return_unused_fds`; this reader only delegates the call.
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
self.decoder.return_unused_fds(msg, num_consumed);
}
pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
return Ok(DecoderOutcome::Decoded(msg));
@ -50,7 +54,7 @@ impl<'a> WlMsgReader<'a> {
return Ok(self
.decoder
.decode_after_read(&tmp_buf[0..read_bytes], &mut fd_vec));
.decode_after_read(&tmp_buf[0..read_bytes], fd_vec));
}
}
}

View file

@ -90,10 +90,13 @@ pub async fn handle_conn(
tokio::select! {
s2c_msg = upstream_read.read() => {
match s2c_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
if state.on_s2c_event(&wl_raw_msg).await {
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
downstream_write.queue_write(wl_raw_msg);
}
},
@ -103,10 +106,13 @@ pub async fn handle_conn(
},
c2s_msg = downstream_read.read() => {
match c2s_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
if state.on_c2s_request(&wl_raw_msg).await {
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
upstream_write.queue_write(wl_raw_msg);
}
},

View file

@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::LazyLock};
use byteorder::ByteOrder;
use serde_derive::Serialize;
use crate::{
codec::WlRawMsg,
@ -114,6 +115,13 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
/// widely-used message with that capability is [WlRegistryBindRequest],
/// which is already handled separately on its own.
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
/// Serialize this message into a JSON string, for use with ask scripts
fn to_json(&self) -> String;
/// How many fds have been consumed in parsing this message?
/// This is used to return any unused fds to the decoder.
fn num_consumed_fds(&self) -> usize;
}
/// A version of [WlParsedMessage] that supports downcasting. By implementing this

View file

@ -60,8 +60,9 @@ impl WlMitmState {
}
}
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
/// (`true` if the message should be forwarded, `false` if it should be dropped).
#[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => {
@ -73,9 +74,11 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(),
"Malformed request"
);
return false;
return (0, false);
}
_ => {
// TODO: due to fds, we can't expect to be able to pass through unknown messages.
// Load all sensible Wayland extensions and remove this condition.
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't
// been built against!
// Note that this won't pass through messages for globals we haven't allowed:
@ -84,7 +87,7 @@ impl WlMitmState {
// even for globals from protocols we don't know.
// It does mean we can't filter against methods that create more objects _from_ that
// global, though.
return true;
return (0, true);
}
};
@ -98,7 +101,7 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(interface) = self.objects.lookup_global(msg.name) else {
return false;
return (0, false);
};
if interface != msg.id_interface_name {
@ -106,7 +109,7 @@ impl WlMitmState {
"Client binding to interface {}, but the interface name {} should correspond to {}",
msg.id_interface_name, msg.name, interface
);
return false;
return (0, false);
}
info!(
@ -141,12 +144,14 @@ impl WlMitmState {
msg.self_object_type().interface(),
msg.self_msg_name()
);
if let Ok(status) = tokio::process::Command::new(ask_cmd)
.arg(msg.self_object_type().interface())
.arg(msg.self_msg_name())
.status()
.await
{
let mut cmd = tokio::process::Command::new(ask_cmd);
cmd.arg(msg.self_object_type().interface());
cmd.arg(msg.self_msg_name());
// Note: the _last_ argument is always the JSON representation!
cmd.arg(msg.to_json());
if let Ok(status) = cmd.status().await {
if !status.success() {
warn!(
"Blocked {}::{} because of return status {}",
@ -156,7 +161,7 @@ impl WlMitmState {
);
}
return status.success();
return (msg.num_consumed_fds(), status.success());
}
}
@ -165,7 +170,7 @@ impl WlMitmState {
msg.self_object_type().interface(),
msg.self_msg_name()
);
return false;
return (msg.num_consumed_fds(), false);
}
WlFilterRequestAction::Block => {
warn!(
@ -174,17 +179,17 @@ impl WlMitmState {
msg.self_msg_name()
);
// TODO: don't just return false, build an error event
return false;
return (msg.num_consumed_fds(), false);
}
}
}
}
true
(msg.num_consumed_fds(), true)
}
#[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => {
@ -194,10 +199,10 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(),
"Malformed event"
);
return false;
return (0, false);
}
_ => {
return true;
return (0, true);
}
};
@ -220,7 +225,7 @@ impl WlMitmState {
interface = msg.interface,
"Removing interface from published globals"
);
return false;
return (0, false);
}
// Else, record the global object. These are the only ones we're ever going to allow through.
@ -234,6 +239,6 @@ impl WlMitmState {
self.objects.remove_object(msg.id);
}
true
(msg.num_consumed_fds(), true)
}
}