Compare commits

..

2 commits

Author SHA1 Message Date
e0cc002148 Handle fds properly
fds aren't guaranteed to arrive along with messages they belong to; so
we can't just assume everything in a WlRawMsg belongs to a parsed
Wayland message.

The only way to tell is by looking at the XML, so we'll have to return
any unused ones to the decoder after receiving them.
2025-03-02 11:14:30 -05:00
d0afdbdda2 Parse serialized message to ask scripts 2025-03-02 08:37:16 -05:00
9 changed files with 144 additions and 41 deletions

26
Cargo.lock generated
View file

@ -95,6 +95,7 @@ dependencies = [
"az", "az",
"bytemuck", "bytemuck",
"half", "half",
"serde",
"typenum", "typenum",
] ]
@ -130,6 +131,12 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "itoa"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.5.0" version = "1.5.0"
@ -266,6 +273,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "ryu"
version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]] [[package]]
name = "sendfd" name = "sendfd"
version = "0.4.3" version = "0.4.3"
@ -296,6 +309,18 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "serde_json"
version = "1.0.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "0.6.8" version = "0.6.8"
@ -619,6 +644,7 @@ dependencies = [
"sendfd", "sendfd",
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json",
"tokio", "tokio",
"toml", "toml",
"tracing", "tracing",

View file

@ -11,11 +11,12 @@ members = ["protogen"]
[dependencies] [dependencies]
byteorder = "1.5.0" byteorder = "1.5.0"
bytes = "1.10.0" bytes = "1.10.0"
fixed = "1.29.0" fixed = { version = "1.29.0", features = [ "serde" ] }
nix = "0.29.0" nix = "0.29.0"
sendfd = { version = "0.4", features = [ "tokio" ] } sendfd = { version = "0.4", features = [ "tokio" ] }
serde = "1.0.218" serde = "1.0.218"
serde_derive = "1.0.218" serde_derive = "1.0.218"
serde_json = "1.0.139"
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]} tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
toml = "0.8.20" toml = "0.8.20"
tracing = "0.1.41" tracing = "0.1.41"

View file

@ -36,6 +36,9 @@ allowed_globals = [
# #
# The first and second arguments to this program will be the interface # The first and second arguments to this program will be the interface
# and request name, respectively. # and request name, respectively.
#
# The number of arguments may change in the future, but the _last_ argument
# is always a JSON-serialized representation of the request's arguments.
ask_cmd = "contrib/ask-bemenu.sh" ask_cmd = "contrib/ask-bemenu.sh"
# A list of requests we'd like to filter # A list of requests we'd like to filter

View file

@ -98,12 +98,30 @@ impl WlMsg {
let parser_fn_name = format_ident!("{}", self.parser_fn_name()); let parser_fn_name = format_ident!("{}", self.parser_fn_name());
// Build all field names and their corresponding Rust type identifiers // Build all field names and their corresponding Rust type identifiers
let (field_names, field_types): (Vec<_>, Vec<_>) = self let (field_names, (field_types, field_attrs)): (Vec<_>, (Vec<_>, Vec<_>)) = self
.args .args
.iter() .iter()
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type())) .map(|(name, tt)| {
(
format_ident!("{name}"),
(
tt.to_rust_type(),
match tt {
// Can't serialize fds!
WlArgType::Fd => quote! { #[serde(skip)] },
_ => quote! {},
},
),
)
})
.unzip(); .unzip();
let num_consumed_fds = self
.args
.iter()
.filter(|(_, tt)| matches!(tt, WlArgType::Fd))
.count();
// Generate code to include in the parser for every field // Generate code to include in the parser for every field
let parser_code: Vec<_> = self let parser_code: Vec<_> = self
.args .args
@ -143,10 +161,12 @@ impl WlMsg {
quote! { quote! {
#[allow(unused)] #[allow(unused)]
#[derive(Serialize)]
pub struct #struct_name<'a> { pub struct #struct_name<'a> {
#[serde(skip)]
_phantom: std::marker::PhantomData<&'a ()>, _phantom: std::marker::PhantomData<&'a ()>,
obj_id: u32, obj_id: u32,
#( pub #field_names: #field_types, )* #( #field_attrs pub #field_names: #field_types, )*
} }
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {} impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
@ -184,6 +204,7 @@ impl WlMsg {
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> { fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload(); let payload = msg.payload();
let mut pos = 0usize; let mut pos = 0usize;
let mut pos_fds = 0usize;
#( #parser_code )* #( #parser_code )*
WaylandProtocolParsingOutcome::Ok(#struct_name { WaylandProtocolParsingOutcome::Ok(#struct_name {
_phantom: std::marker::PhantomData, _phantom: std::marker::PhantomData,
@ -203,6 +224,14 @@ impl WlMsg {
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> { fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
#known_objects_created #known_objects_created
} }
fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
fn num_consumed_fds(&self) -> usize {
#num_consumed_fds
}
} }
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {} unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
@ -376,11 +405,12 @@ impl WlArgType {
}; };
}, },
WlArgType::Fd => quote! { WlArgType::Fd => quote! {
if msg.fds.len() == 0 { if msg.fds.len() < pos_fds + 1 {
return WaylandProtocolParsingOutcome::MalformedMessage; return WaylandProtocolParsingOutcome::MalformedMessage;
} }
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]); let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]);
pos_fds += 1;
}, },
} }
} }

View file

@ -1,4 +1,4 @@
use std::os::fd::OwnedFd; use std::{collections::VecDeque, os::fd::OwnedFd};
use byteorder::{ByteOrder, NativeEndian}; use byteorder::{ByteOrder, NativeEndian};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
@ -13,11 +13,18 @@ pub struct WlRawMsg {
pub opcode: u16, pub opcode: u16,
// len bytes -- containing the header // len bytes -- containing the header
msg_buf: Bytes, msg_buf: Bytes,
pub fds: Box<[OwnedFd]>, /// All fds we have seen up until decoding this message frame
/// fds aren't guaranteed to be separated between messages; therefore, there
/// is no way for us to tell that all fds here belong to the current message
/// without actually loading the Wayland XML protocols.
///
/// Instead, downstream parsers should return any unused fds back to the decoder
/// with [WlDecoder::return_unused_fds].
pub fds: Vec<OwnedFd>,
} }
impl WlRawMsg { impl WlRawMsg {
pub fn try_decode(buf: &mut BytesMut, fds: &mut Vec<OwnedFd>) -> Option<WlRawMsg> { pub fn try_decode(buf: &mut BytesMut, fds: &mut VecDeque<OwnedFd>) -> Option<WlRawMsg> {
let buf_len = buf.len(); let buf_len = buf.len();
// Not even a complete message header // Not even a complete message header
if buf_len < 8 { if buf_len < 8 {
@ -36,14 +43,16 @@ impl WlRawMsg {
let msg_buf = buf.split_to(msg_len as usize); let msg_buf = buf.split_to(msg_len as usize);
let mut new_fds = Vec::with_capacity(fds.len()); let mut new_fds = Vec::with_capacity(fds.len());
new_fds.append(fds); while let Some(fd) = fds.pop_front() {
new_fds.push(fd);
}
Some(WlRawMsg { Some(WlRawMsg {
obj_id, obj_id,
len: msg_len as u16, len: msg_len as u16,
opcode: opcode as u16, opcode: opcode as u16,
msg_buf: msg_buf.freeze(), msg_buf: msg_buf.freeze(),
fds: new_fds.into_boxed_slice(), fds: new_fds,
}) })
} }
@ -52,7 +61,7 @@ impl WlRawMsg {
} }
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) { pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
(self.msg_buf, self.fds) (self.msg_buf, self.fds.into_boxed_slice())
} }
} }
@ -64,14 +73,25 @@ pub enum DecoderOutcome {
pub struct WlDecoder { pub struct WlDecoder {
buf: BytesMut, buf: BytesMut,
fds: Vec<OwnedFd>, fds: VecDeque<OwnedFd>,
} }
impl WlDecoder { impl WlDecoder {
pub fn new() -> WlDecoder { pub fn new() -> WlDecoder {
WlDecoder { WlDecoder {
buf: BytesMut::new(), buf: BytesMut::new(),
fds: Vec::new(), fds: VecDeque::new(),
}
}
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
let mut unused = msg.fds.split_off(num_consumed);
// Add all unused vectors, in order, to the _front_ of our queue
// This means that we take one item from the _back_ of the unused
// chunk at a time and insert that to the _front_, to preserve order.
while let Some(fd) = unused.pop() {
self.fds.push_front(fd);
} }
} }
@ -86,9 +106,9 @@ impl WlDecoder {
} }
} }
pub fn decode_after_read(&mut self, buf: &[u8], fds: &mut Vec<OwnedFd>) -> DecoderOutcome { pub fn decode_after_read(&mut self, buf: &[u8], fds: Vec<OwnedFd>) -> DecoderOutcome {
self.buf.extend_from_slice(&buf); self.buf.extend_from_slice(&buf);
self.fds.append(fds); self.fds.extend(fds.into_iter());
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) { match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
Some(res) => DecoderOutcome::Decoded(res), Some(res) => DecoderOutcome::Decoded(res),

View file

@ -25,6 +25,10 @@ impl<'a> WlMsgReader<'a> {
} }
} }
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
self.decoder.return_unused_fds(msg, num_consumed);
}
pub async fn read(&mut self) -> io::Result<DecoderOutcome> { pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() { if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
return Ok(DecoderOutcome::Decoded(msg)); return Ok(DecoderOutcome::Decoded(msg));
@ -50,7 +54,7 @@ impl<'a> WlMsgReader<'a> {
return Ok(self return Ok(self
.decoder .decoder
.decode_after_read(&tmp_buf[0..read_bytes], &mut fd_vec)); .decode_after_read(&tmp_buf[0..read_bytes], fd_vec));
} }
} }
} }

View file

@ -90,10 +90,13 @@ pub async fn handle_conn(
tokio::select! { tokio::select! {
s2c_msg = upstream_read.read() => { s2c_msg = upstream_read.read() => {
match s2c_msg? { match s2c_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => { codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
if state.on_s2c_event(&wl_raw_msg).await { let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
downstream_write.queue_write(wl_raw_msg); downstream_write.queue_write(wl_raw_msg);
} }
}, },
@ -103,10 +106,13 @@ pub async fn handle_conn(
}, },
c2s_msg = downstream_read.read() => { c2s_msg = downstream_read.read() => {
match c2s_msg? { match c2s_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => { codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
if state.on_c2s_request(&wl_raw_msg).await { let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
upstream_write.queue_write(wl_raw_msg); upstream_write.queue_write(wl_raw_msg);
} }
}, },

View file

@ -3,6 +3,7 @@
use std::{collections::HashMap, sync::LazyLock}; use std::{collections::HashMap, sync::LazyLock};
use byteorder::ByteOrder; use byteorder::ByteOrder;
use serde_derive::Serialize;
use crate::{ use crate::{
codec::WlRawMsg, codec::WlRawMsg,
@ -114,6 +115,13 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
/// widely-used message with that capability is [WlRegistryBindRequest], /// widely-used message with that capability is [WlRegistryBindRequest],
/// which is already handled separately on its own. /// which is already handled separately on its own.
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>; fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
/// Serialize this message into a JSON string, for use with ask scripts
fn to_json(&self) -> String;
/// How many fds have been consumed in parsing this message?
/// This is used to return any unused fds to the decoder.
fn num_consumed_fds(&self) -> usize;
} }
/// A version of [WlParsedMessage] that supports downcasting. By implementing this /// A version of [WlParsedMessage] that supports downcasting. By implementing this

View file

@ -60,8 +60,9 @@ impl WlMitmState {
} }
} }
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool { pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let msg = match crate::proto::decode_request(&self.objects, raw_msg) { let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => { WaylandProtocolParsingOutcome::MalformedMessage => {
@ -73,9 +74,11 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed request" "Malformed request"
); );
return false; return (0, false);
} }
_ => { _ => {
// TODO: due to fds, we can't expect to be able to pass through unknown messages.
// Load all sensible Wayland extensions and remove this condition.
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't // Pass through all unknown messages -- they could be from a Wayland protocol we haven't
// been built against! // been built against!
// Note that this won't pass through messages for globals we haven't allowed: // Note that this won't pass through messages for globals we haven't allowed:
@ -84,7 +87,7 @@ impl WlMitmState {
// even for globals from protocols we don't know. // even for globals from protocols we don't know.
// It does mean we can't filter against methods that create more objects _from_ that // It does mean we can't filter against methods that create more objects _from_ that
// global, though. // global, though.
return true; return (0, true);
} }
}; };
@ -98,7 +101,7 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client! // to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out. // So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(interface) = self.objects.lookup_global(msg.name) else { let Some(interface) = self.objects.lookup_global(msg.name) else {
return false; return (0, false);
}; };
if interface != msg.id_interface_name { if interface != msg.id_interface_name {
@ -106,7 +109,7 @@ impl WlMitmState {
"Client binding to interface {}, but the interface name {} should correspond to {}", "Client binding to interface {}, but the interface name {} should correspond to {}",
msg.id_interface_name, msg.name, interface msg.id_interface_name, msg.name, interface
); );
return false; return (0, false);
} }
info!( info!(
@ -141,12 +144,14 @@ impl WlMitmState {
msg.self_object_type().interface(), msg.self_object_type().interface(),
msg.self_msg_name() msg.self_msg_name()
); );
if let Ok(status) = tokio::process::Command::new(ask_cmd)
.arg(msg.self_object_type().interface()) let mut cmd = tokio::process::Command::new(ask_cmd);
.arg(msg.self_msg_name()) cmd.arg(msg.self_object_type().interface());
.status() cmd.arg(msg.self_msg_name());
.await // Note: the _last_ argument is always the JSON representation!
{ cmd.arg(msg.to_json());
if let Ok(status) = cmd.status().await {
if !status.success() { if !status.success() {
warn!( warn!(
"Blocked {}::{} because of return status {}", "Blocked {}::{} because of return status {}",
@ -156,7 +161,7 @@ impl WlMitmState {
); );
} }
return status.success(); return (msg.num_consumed_fds(), status.success());
} }
} }
@ -165,7 +170,7 @@ impl WlMitmState {
msg.self_object_type().interface(), msg.self_object_type().interface(),
msg.self_msg_name() msg.self_msg_name()
); );
return false; return (msg.num_consumed_fds(), false);
} }
WlFilterRequestAction::Block => { WlFilterRequestAction::Block => {
warn!( warn!(
@ -174,17 +179,17 @@ impl WlMitmState {
msg.self_msg_name() msg.self_msg_name()
); );
// TODO: don't just return false, build an error event // TODO: don't just return false, build an error event
return false; return (msg.num_consumed_fds(), false);
} }
} }
} }
} }
true (msg.num_consumed_fds(), true)
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool { pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
let msg = match crate::proto::decode_event(&self.objects, raw_msg) { let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => { WaylandProtocolParsingOutcome::MalformedMessage => {
@ -194,10 +199,10 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed event" "Malformed event"
); );
return false; return (0, false);
} }
_ => { _ => {
return true; return (0, true);
} }
}; };
@ -220,7 +225,7 @@ impl WlMitmState {
interface = msg.interface, interface = msg.interface,
"Removing interface from published globals" "Removing interface from published globals"
); );
return false; return (0, false);
} }
// Else, record the global object. These are the only ones we're ever going to allow through. // Else, record the global object. These are the only ones we're ever going to allow through.
@ -234,6 +239,6 @@ impl WlMitmState {
self.objects.remove_object(msg.id); self.objects.remove_object(msg.id);
} }
true (msg.num_consumed_fds(), true)
} }
} }