Compare commits

..

No commits in common. "e0cc002148ff8295b7e4edbf00f9ed3882e63f8c" and "a15bb3fa6e3cc26811ba0cbe8b7059fef4710bb0" have entirely different histories.

9 changed files with 41 additions and 144 deletions

26
Cargo.lock generated
View file

@ -95,7 +95,6 @@ dependencies = [
"az", "az",
"bytemuck", "bytemuck",
"half", "half",
"serde",
"typenum", "typenum",
] ]
@ -131,12 +130,6 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "itoa"
version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.5.0" version = "1.5.0"
@ -273,12 +266,6 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "ryu"
version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]] [[package]]
name = "sendfd" name = "sendfd"
version = "0.4.3" version = "0.4.3"
@ -309,18 +296,6 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "serde_json"
version = "1.0.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]] [[package]]
name = "serde_spanned" name = "serde_spanned"
version = "0.6.8" version = "0.6.8"
@ -644,7 +619,6 @@ dependencies = [
"sendfd", "sendfd",
"serde", "serde",
"serde_derive", "serde_derive",
"serde_json",
"tokio", "tokio",
"toml", "toml",
"tracing", "tracing",

View file

@ -11,12 +11,11 @@ members = ["protogen"]
[dependencies] [dependencies]
byteorder = "1.5.0" byteorder = "1.5.0"
bytes = "1.10.0" bytes = "1.10.0"
fixed = { version = "1.29.0", features = [ "serde" ] } fixed = "1.29.0"
nix = "0.29.0" nix = "0.29.0"
sendfd = { version = "0.4", features = [ "tokio" ] } sendfd = { version = "0.4", features = [ "tokio" ] }
serde = "1.0.218" serde = "1.0.218"
serde_derive = "1.0.218" serde_derive = "1.0.218"
serde_json = "1.0.139"
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]} tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
toml = "0.8.20" toml = "0.8.20"
tracing = "0.1.41" tracing = "0.1.41"

View file

@ -36,9 +36,6 @@ allowed_globals = [
# #
# The first and second arguments to this program will be the interface # The first and second arguments to this program will be the interface
# and request name, respectively. # and request name, respectively.
#
# The number of arguments may change in the future, but the _last_ argument
# is always a JSON-serialized representation of the request's arguments.
ask_cmd = "contrib/ask-bemenu.sh" ask_cmd = "contrib/ask-bemenu.sh"
# A list of requests we'd like to filter # A list of requests we'd like to filter

View file

@ -98,30 +98,12 @@ impl WlMsg {
let parser_fn_name = format_ident!("{}", self.parser_fn_name()); let parser_fn_name = format_ident!("{}", self.parser_fn_name());
// Build all field names and their corresponding Rust type identifiers // Build all field names and their corresponding Rust type identifiers
let (field_names, (field_types, field_attrs)): (Vec<_>, (Vec<_>, Vec<_>)) = self let (field_names, field_types): (Vec<_>, Vec<_>) = self
.args .args
.iter() .iter()
.map(|(name, tt)| { .map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
(
format_ident!("{name}"),
(
tt.to_rust_type(),
match tt {
// Can't serialize fds!
WlArgType::Fd => quote! { #[serde(skip)] },
_ => quote! {},
},
),
)
})
.unzip(); .unzip();
let num_consumed_fds = self
.args
.iter()
.filter(|(_, tt)| matches!(tt, WlArgType::Fd))
.count();
// Generate code to include in the parser for every field // Generate code to include in the parser for every field
let parser_code: Vec<_> = self let parser_code: Vec<_> = self
.args .args
@ -161,12 +143,10 @@ impl WlMsg {
quote! { quote! {
#[allow(unused)] #[allow(unused)]
#[derive(Serialize)]
pub struct #struct_name<'a> { pub struct #struct_name<'a> {
#[serde(skip)]
_phantom: std::marker::PhantomData<&'a ()>, _phantom: std::marker::PhantomData<&'a ()>,
obj_id: u32, obj_id: u32,
#( #field_attrs pub #field_names: #field_types, )* #( pub #field_names: #field_types, )*
} }
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {} impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
@ -204,7 +184,6 @@ impl WlMsg {
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> { fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload(); let payload = msg.payload();
let mut pos = 0usize; let mut pos = 0usize;
let mut pos_fds = 0usize;
#( #parser_code )* #( #parser_code )*
WaylandProtocolParsingOutcome::Ok(#struct_name { WaylandProtocolParsingOutcome::Ok(#struct_name {
_phantom: std::marker::PhantomData, _phantom: std::marker::PhantomData,
@ -224,14 +203,6 @@ impl WlMsg {
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> { fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
#known_objects_created #known_objects_created
} }
fn to_json(&self) -> String {
serde_json::to_string(self).unwrap()
}
fn num_consumed_fds(&self) -> usize {
#num_consumed_fds
}
} }
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {} unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
@ -405,12 +376,11 @@ impl WlArgType {
}; };
}, },
WlArgType::Fd => quote! { WlArgType::Fd => quote! {
if msg.fds.len() < pos_fds + 1 { if msg.fds.len() == 0 {
return WaylandProtocolParsingOutcome::MalformedMessage; return WaylandProtocolParsingOutcome::MalformedMessage;
} }
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]); let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]);
pos_fds += 1;
}, },
} }
} }

View file

@ -1,4 +1,4 @@
use std::{collections::VecDeque, os::fd::OwnedFd}; use std::os::fd::OwnedFd;
use byteorder::{ByteOrder, NativeEndian}; use byteorder::{ByteOrder, NativeEndian};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
@ -13,18 +13,11 @@ pub struct WlRawMsg {
pub opcode: u16, pub opcode: u16,
// len bytes -- containing the header // len bytes -- containing the header
msg_buf: Bytes, msg_buf: Bytes,
/// All fds we have seen up until decoding this message frame pub fds: Box<[OwnedFd]>,
/// fds aren't guaranteed to be separated between messages; therefore, there
/// is no way for us to tell that all fds here belong to the current message
/// without actually loading the Wayland XML protocols.
///
/// Instead, downstream parsers should return any unused fds back to the decoder
/// with [WlDecoder::return_unused_fds].
pub fds: Vec<OwnedFd>,
} }
impl WlRawMsg { impl WlRawMsg {
pub fn try_decode(buf: &mut BytesMut, fds: &mut VecDeque<OwnedFd>) -> Option<WlRawMsg> { pub fn try_decode(buf: &mut BytesMut, fds: &mut Vec<OwnedFd>) -> Option<WlRawMsg> {
let buf_len = buf.len(); let buf_len = buf.len();
// Not even a complete message header // Not even a complete message header
if buf_len < 8 { if buf_len < 8 {
@ -43,16 +36,14 @@ impl WlRawMsg {
let msg_buf = buf.split_to(msg_len as usize); let msg_buf = buf.split_to(msg_len as usize);
let mut new_fds = Vec::with_capacity(fds.len()); let mut new_fds = Vec::with_capacity(fds.len());
while let Some(fd) = fds.pop_front() { new_fds.append(fds);
new_fds.push(fd);
}
Some(WlRawMsg { Some(WlRawMsg {
obj_id, obj_id,
len: msg_len as u16, len: msg_len as u16,
opcode: opcode as u16, opcode: opcode as u16,
msg_buf: msg_buf.freeze(), msg_buf: msg_buf.freeze(),
fds: new_fds, fds: new_fds.into_boxed_slice(),
}) })
} }
@ -61,7 +52,7 @@ impl WlRawMsg {
} }
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) { pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
(self.msg_buf, self.fds.into_boxed_slice()) (self.msg_buf, self.fds)
} }
} }
@ -73,25 +64,14 @@ pub enum DecoderOutcome {
pub struct WlDecoder { pub struct WlDecoder {
buf: BytesMut, buf: BytesMut,
fds: VecDeque<OwnedFd>, fds: Vec<OwnedFd>,
} }
impl WlDecoder { impl WlDecoder {
pub fn new() -> WlDecoder { pub fn new() -> WlDecoder {
WlDecoder { WlDecoder {
buf: BytesMut::new(), buf: BytesMut::new(),
fds: VecDeque::new(), fds: Vec::new(),
}
}
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
let mut unused = msg.fds.split_off(num_consumed);
// Add all unused vectors, in order, to the _front_ of our queue
// This means that we take one item from the _back_ of the unused
// chunk at a time and insert that to the _front_, to preserve order.
while let Some(fd) = unused.pop() {
self.fds.push_front(fd);
} }
} }
@ -106,9 +86,9 @@ impl WlDecoder {
} }
} }
pub fn decode_after_read(&mut self, buf: &[u8], fds: Vec<OwnedFd>) -> DecoderOutcome { pub fn decode_after_read(&mut self, buf: &[u8], fds: &mut Vec<OwnedFd>) -> DecoderOutcome {
self.buf.extend_from_slice(&buf); self.buf.extend_from_slice(&buf);
self.fds.extend(fds.into_iter()); self.fds.append(fds);
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) { match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
Some(res) => DecoderOutcome::Decoded(res), Some(res) => DecoderOutcome::Decoded(res),

View file

@ -25,10 +25,6 @@ impl<'a> WlMsgReader<'a> {
} }
} }
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
self.decoder.return_unused_fds(msg, num_consumed);
}
pub async fn read(&mut self) -> io::Result<DecoderOutcome> { pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() { if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
return Ok(DecoderOutcome::Decoded(msg)); return Ok(DecoderOutcome::Decoded(msg));
@ -54,7 +50,7 @@ impl<'a> WlMsgReader<'a> {
return Ok(self return Ok(self
.decoder .decoder
.decode_after_read(&tmp_buf[0..read_bytes], fd_vec)); .decode_after_read(&tmp_buf[0..read_bytes], &mut fd_vec));
} }
} }
} }

View file

@ -90,13 +90,10 @@ pub async fn handle_conn(
tokio::select! { tokio::select! {
s2c_msg = upstream_read.read() => { s2c_msg = upstream_read.read() => {
match s2c_msg? { match s2c_msg? {
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => { codec::DecoderOutcome::Decoded(wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await; if state.on_s2c_event(&wl_raw_msg).await {
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
downstream_write.queue_write(wl_raw_msg); downstream_write.queue_write(wl_raw_msg);
} }
}, },
@ -106,13 +103,10 @@ pub async fn handle_conn(
}, },
c2s_msg = downstream_read.read() => { c2s_msg = downstream_read.read() => {
match c2s_msg? { match c2s_msg? {
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => { codec::DecoderOutcome::Decoded(wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request"); debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await; if state.on_c2s_request(&wl_raw_msg).await {
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
if verdict {
upstream_write.queue_write(wl_raw_msg); upstream_write.queue_write(wl_raw_msg);
} }
}, },

View file

@ -3,7 +3,6 @@
use std::{collections::HashMap, sync::LazyLock}; use std::{collections::HashMap, sync::LazyLock};
use byteorder::ByteOrder; use byteorder::ByteOrder;
use serde_derive::Serialize;
use crate::{ use crate::{
codec::WlRawMsg, codec::WlRawMsg,
@ -115,13 +114,6 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
/// widely-used message with that capability is [WlRegistryBindRequest], /// widely-used message with that capability is [WlRegistryBindRequest],
/// which is already handled separately on its own. /// which is already handled separately on its own.
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>; fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
/// Serialize this message into a JSON string, for use with ask scripts
fn to_json(&self) -> String;
/// How many fds have been consumed in parsing this message?
/// This is used to return any unused fds to the decoder.
fn num_consumed_fds(&self) -> usize;
} }
/// A version of [WlParsedMessage] that supports downcasting. By implementing this /// A version of [WlParsedMessage] that supports downcasting. By implementing this

View file

@ -60,9 +60,8 @@ impl WlMitmState {
} }
} }
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) { pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
let msg = match crate::proto::decode_request(&self.objects, raw_msg) { let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => { WaylandProtocolParsingOutcome::MalformedMessage => {
@ -74,11 +73,9 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed request" "Malformed request"
); );
return (0, false); return false;
} }
_ => { _ => {
// TODO: due to fds, we can't expect to be able to pass through unknown messages.
// Load all sensible Wayland extensions and remove this condition.
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't // Pass through all unknown messages -- they could be from a Wayland protocol we haven't
// been built against! // been built against!
// Note that this won't pass through messages for globals we haven't allowed: // Note that this won't pass through messages for globals we haven't allowed:
@ -87,7 +84,7 @@ impl WlMitmState {
// even for globals from protocols we don't know. // even for globals from protocols we don't know.
// It does mean we can't filter against methods that create more objects _from_ that // It does mean we can't filter against methods that create more objects _from_ that
// global, though. // global, though.
return (0, true); return true;
} }
}; };
@ -101,7 +98,7 @@ impl WlMitmState {
// to bind to it; if it does, it's likely a malicious client! // to bind to it; if it does, it's likely a malicious client!
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out. // So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
let Some(interface) = self.objects.lookup_global(msg.name) else { let Some(interface) = self.objects.lookup_global(msg.name) else {
return (0, false); return false;
}; };
if interface != msg.id_interface_name { if interface != msg.id_interface_name {
@ -109,7 +106,7 @@ impl WlMitmState {
"Client binding to interface {}, but the interface name {} should correspond to {}", "Client binding to interface {}, but the interface name {} should correspond to {}",
msg.id_interface_name, msg.name, interface msg.id_interface_name, msg.name, interface
); );
return (0, false); return false;
} }
info!( info!(
@ -144,14 +141,12 @@ impl WlMitmState {
msg.self_object_type().interface(), msg.self_object_type().interface(),
msg.self_msg_name() msg.self_msg_name()
); );
if let Ok(status) = tokio::process::Command::new(ask_cmd)
let mut cmd = tokio::process::Command::new(ask_cmd); .arg(msg.self_object_type().interface())
cmd.arg(msg.self_object_type().interface()); .arg(msg.self_msg_name())
cmd.arg(msg.self_msg_name()); .status()
// Note: the _last_ argument is always the JSON representation! .await
cmd.arg(msg.to_json()); {
if let Ok(status) = cmd.status().await {
if !status.success() { if !status.success() {
warn!( warn!(
"Blocked {}::{} because of return status {}", "Blocked {}::{} because of return status {}",
@ -161,7 +156,7 @@ impl WlMitmState {
); );
} }
return (msg.num_consumed_fds(), status.success()); return status.success();
} }
} }
@ -170,7 +165,7 @@ impl WlMitmState {
msg.self_object_type().interface(), msg.self_object_type().interface(),
msg.self_msg_name() msg.self_msg_name()
); );
return (msg.num_consumed_fds(), false); return false;
} }
WlFilterRequestAction::Block => { WlFilterRequestAction::Block => {
warn!( warn!(
@ -179,17 +174,17 @@ impl WlMitmState {
msg.self_msg_name() msg.self_msg_name()
); );
// TODO: don't just return false, build an error event // TODO: don't just return false, build an error event
return (msg.num_consumed_fds(), false); return false;
} }
} }
} }
} }
(msg.num_consumed_fds(), true) true
} }
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) { pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
let msg = match crate::proto::decode_event(&self.objects, raw_msg) { let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
WaylandProtocolParsingOutcome::Ok(msg) => msg, WaylandProtocolParsingOutcome::Ok(msg) => msg,
WaylandProtocolParsingOutcome::MalformedMessage => { WaylandProtocolParsingOutcome::MalformedMessage => {
@ -199,10 +194,10 @@ impl WlMitmState {
num_fds = raw_msg.fds.len(), num_fds = raw_msg.fds.len(),
"Malformed event" "Malformed event"
); );
return (0, false); return false;
} }
_ => { _ => {
return (0, true); return true;
} }
}; };
@ -225,7 +220,7 @@ impl WlMitmState {
interface = msg.interface, interface = msg.interface,
"Removing interface from published globals" "Removing interface from published globals"
); );
return (0, false); return false;
} }
// Else, record the global object. These are the only ones we're ever going to allow through. // Else, record the global object. These are the only ones we're ever going to allow through.
@ -239,6 +234,6 @@ impl WlMitmState {
self.objects.remove_object(msg.id); self.objects.remove_object(msg.id);
} }
(msg.num_consumed_fds(), true) true
} }
} }