Compare commits
No commits in common. "e0cc002148ff8295b7e4edbf00f9ed3882e63f8c" and "a15bb3fa6e3cc26811ba0cbe8b7059fef4710bb0" have entirely different histories.
e0cc002148
...
a15bb3fa6e
9 changed files with 41 additions and 144 deletions
26
Cargo.lock
generated
26
Cargo.lock
generated
|
@ -95,7 +95,6 @@ dependencies = [
|
|||
"az",
|
||||
"bytemuck",
|
||||
"half",
|
||||
"serde",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
|
@ -131,12 +130,6 @@ dependencies = [
|
|||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
|
@ -273,12 +266,6 @@ version = "0.1.24"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
|
||||
|
||||
[[package]]
|
||||
name = "sendfd"
|
||||
version = "0.4.3"
|
||||
|
@ -309,18 +296,6 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.139"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_spanned"
|
||||
version = "0.6.8"
|
||||
|
@ -644,7 +619,6 @@ dependencies = [
|
|||
"sendfd",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"toml",
|
||||
"tracing",
|
||||
|
|
|
@ -11,12 +11,11 @@ members = ["protogen"]
|
|||
[dependencies]
|
||||
byteorder = "1.5.0"
|
||||
bytes = "1.10.0"
|
||||
fixed = { version = "1.29.0", features = [ "serde" ] }
|
||||
fixed = "1.29.0"
|
||||
nix = "0.29.0"
|
||||
sendfd = { version = "0.4", features = [ "tokio" ] }
|
||||
serde = "1.0.218"
|
||||
serde_derive = "1.0.218"
|
||||
serde_json = "1.0.139"
|
||||
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
|
||||
toml = "0.8.20"
|
||||
tracing = "0.1.41"
|
||||
|
|
|
@ -36,9 +36,6 @@ allowed_globals = [
|
|||
#
|
||||
# The first and second arguments to this program will be the interface
|
||||
# and request name, respectively.
|
||||
#
|
||||
# The number of arguments may change in the future, but the _last_ argument
|
||||
# is always a JSON-serialized representation of the request's arguments.
|
||||
ask_cmd = "contrib/ask-bemenu.sh"
|
||||
|
||||
# A list of requests we'd like to filter
|
||||
|
|
|
@ -98,30 +98,12 @@ impl WlMsg {
|
|||
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
|
||||
|
||||
// Build all field names and their corresponding Rust type identifiers
|
||||
let (field_names, (field_types, field_attrs)): (Vec<_>, (Vec<_>, Vec<_>)) = self
|
||||
let (field_names, field_types): (Vec<_>, Vec<_>) = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|(name, tt)| {
|
||||
(
|
||||
format_ident!("{name}"),
|
||||
(
|
||||
tt.to_rust_type(),
|
||||
match tt {
|
||||
// Can't serialize fds!
|
||||
WlArgType::Fd => quote! { #[serde(skip)] },
|
||||
_ => quote! {},
|
||||
},
|
||||
),
|
||||
)
|
||||
})
|
||||
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
||||
.unzip();
|
||||
|
||||
let num_consumed_fds = self
|
||||
.args
|
||||
.iter()
|
||||
.filter(|(_, tt)| matches!(tt, WlArgType::Fd))
|
||||
.count();
|
||||
|
||||
// Generate code to include in the parser for every field
|
||||
let parser_code: Vec<_> = self
|
||||
.args
|
||||
|
@ -161,12 +143,10 @@ impl WlMsg {
|
|||
|
||||
quote! {
|
||||
#[allow(unused)]
|
||||
#[derive(Serialize)]
|
||||
pub struct #struct_name<'a> {
|
||||
#[serde(skip)]
|
||||
_phantom: std::marker::PhantomData<&'a ()>,
|
||||
obj_id: u32,
|
||||
#( #field_attrs pub #field_names: #field_types, )*
|
||||
#( pub #field_names: #field_types, )*
|
||||
}
|
||||
|
||||
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
|
||||
|
@ -204,7 +184,6 @@ impl WlMsg {
|
|||
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
|
||||
let payload = msg.payload();
|
||||
let mut pos = 0usize;
|
||||
let mut pos_fds = 0usize;
|
||||
#( #parser_code )*
|
||||
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
||||
_phantom: std::marker::PhantomData,
|
||||
|
@ -224,14 +203,6 @@ impl WlMsg {
|
|||
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
|
||||
#known_objects_created
|
||||
}
|
||||
|
||||
fn to_json(&self) -> String {
|
||||
serde_json::to_string(self).unwrap()
|
||||
}
|
||||
|
||||
fn num_consumed_fds(&self) -> usize {
|
||||
#num_consumed_fds
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
|
||||
|
@ -405,12 +376,11 @@ impl WlArgType {
|
|||
};
|
||||
},
|
||||
WlArgType::Fd => quote! {
|
||||
if msg.fds.len() < pos_fds + 1 {
|
||||
if msg.fds.len() == 0 {
|
||||
return WaylandProtocolParsingOutcome::MalformedMessage;
|
||||
}
|
||||
|
||||
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]);
|
||||
pos_fds += 1;
|
||||
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
40
src/codec.rs
40
src/codec.rs
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::VecDeque, os::fd::OwnedFd};
|
||||
use std::os::fd::OwnedFd;
|
||||
|
||||
use byteorder::{ByteOrder, NativeEndian};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
|
@ -13,18 +13,11 @@ pub struct WlRawMsg {
|
|||
pub opcode: u16,
|
||||
// len bytes -- containing the header
|
||||
msg_buf: Bytes,
|
||||
/// All fds we have seen up until decoding this message frame
|
||||
/// fds aren't guaranteed to be separated between messages; therefore, there
|
||||
/// is no way for us to tell that all fds here belong to the current message
|
||||
/// without actually loading the Wayland XML protocols.
|
||||
///
|
||||
/// Instead, downstream parsers should return any unused fds back to the decoder
|
||||
/// with [WlDecoder::return_unused_fds].
|
||||
pub fds: Vec<OwnedFd>,
|
||||
pub fds: Box<[OwnedFd]>,
|
||||
}
|
||||
|
||||
impl WlRawMsg {
|
||||
pub fn try_decode(buf: &mut BytesMut, fds: &mut VecDeque<OwnedFd>) -> Option<WlRawMsg> {
|
||||
pub fn try_decode(buf: &mut BytesMut, fds: &mut Vec<OwnedFd>) -> Option<WlRawMsg> {
|
||||
let buf_len = buf.len();
|
||||
// Not even a complete message header
|
||||
if buf_len < 8 {
|
||||
|
@ -43,16 +36,14 @@ impl WlRawMsg {
|
|||
let msg_buf = buf.split_to(msg_len as usize);
|
||||
|
||||
let mut new_fds = Vec::with_capacity(fds.len());
|
||||
while let Some(fd) = fds.pop_front() {
|
||||
new_fds.push(fd);
|
||||
}
|
||||
new_fds.append(fds);
|
||||
|
||||
Some(WlRawMsg {
|
||||
obj_id,
|
||||
len: msg_len as u16,
|
||||
opcode: opcode as u16,
|
||||
msg_buf: msg_buf.freeze(),
|
||||
fds: new_fds,
|
||||
fds: new_fds.into_boxed_slice(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -61,7 +52,7 @@ impl WlRawMsg {
|
|||
}
|
||||
|
||||
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
|
||||
(self.msg_buf, self.fds.into_boxed_slice())
|
||||
(self.msg_buf, self.fds)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -73,25 +64,14 @@ pub enum DecoderOutcome {
|
|||
|
||||
pub struct WlDecoder {
|
||||
buf: BytesMut,
|
||||
fds: VecDeque<OwnedFd>,
|
||||
fds: Vec<OwnedFd>,
|
||||
}
|
||||
|
||||
impl WlDecoder {
|
||||
pub fn new() -> WlDecoder {
|
||||
WlDecoder {
|
||||
buf: BytesMut::new(),
|
||||
fds: VecDeque::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
|
||||
let mut unused = msg.fds.split_off(num_consumed);
|
||||
|
||||
// Add all unused vectors, in order, to the _front_ of our queue
|
||||
// This means that we take one item from the _back_ of the unused
|
||||
// chunk at a time and insert that to the _front_, to preserve order.
|
||||
while let Some(fd) = unused.pop() {
|
||||
self.fds.push_front(fd);
|
||||
fds: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -106,9 +86,9 @@ impl WlDecoder {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn decode_after_read(&mut self, buf: &[u8], fds: Vec<OwnedFd>) -> DecoderOutcome {
|
||||
pub fn decode_after_read(&mut self, buf: &[u8], fds: &mut Vec<OwnedFd>) -> DecoderOutcome {
|
||||
self.buf.extend_from_slice(&buf);
|
||||
self.fds.extend(fds.into_iter());
|
||||
self.fds.append(fds);
|
||||
|
||||
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
|
||||
Some(res) => DecoderOutcome::Decoded(res),
|
||||
|
|
|
@ -25,10 +25,6 @@ impl<'a> WlMsgReader<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
|
||||
self.decoder.return_unused_fds(msg, num_consumed);
|
||||
}
|
||||
|
||||
pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
|
||||
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
|
||||
return Ok(DecoderOutcome::Decoded(msg));
|
||||
|
@ -54,7 +50,7 @@ impl<'a> WlMsgReader<'a> {
|
|||
|
||||
return Ok(self
|
||||
.decoder
|
||||
.decode_after_read(&tmp_buf[0..read_bytes], fd_vec));
|
||||
.decode_after_read(&tmp_buf[0..read_bytes], &mut fd_vec));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
14
src/main.rs
14
src/main.rs
|
@ -90,13 +90,10 @@ pub async fn handle_conn(
|
|||
tokio::select! {
|
||||
s2c_msg = upstream_read.read() => {
|
||||
match s2c_msg? {
|
||||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
|
||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
||||
|
||||
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
||||
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||
|
||||
if verdict {
|
||||
if state.on_s2c_event(&wl_raw_msg).await {
|
||||
downstream_write.queue_write(wl_raw_msg);
|
||||
}
|
||||
},
|
||||
|
@ -106,13 +103,10 @@ pub async fn handle_conn(
|
|||
},
|
||||
c2s_msg = downstream_read.read() => {
|
||||
match c2s_msg? {
|
||||
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
|
||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
||||
|
||||
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
||||
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||
|
||||
if verdict {
|
||||
if state.on_c2s_request(&wl_raw_msg).await {
|
||||
upstream_write.queue_write(wl_raw_msg);
|
||||
}
|
||||
},
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
use std::{collections::HashMap, sync::LazyLock};
|
||||
|
||||
use byteorder::ByteOrder;
|
||||
use serde_derive::Serialize;
|
||||
|
||||
use crate::{
|
||||
codec::WlRawMsg,
|
||||
|
@ -115,13 +114,6 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
|
|||
/// widely-used message with that capability is [WlRegistryBindRequest],
|
||||
/// which is already handled separately on its own.
|
||||
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
|
||||
|
||||
/// Serialize this message into a JSON string, for use with ask scripts
|
||||
fn to_json(&self) -> String;
|
||||
|
||||
/// How many fds have been consumed in parsing this message?
|
||||
/// This is used to return any unused fds to the decoder.
|
||||
fn num_consumed_fds(&self) -> usize;
|
||||
}
|
||||
|
||||
/// A version of [WlParsedMessage] that supports downcasting. By implementing this
|
||||
|
|
45
src/state.rs
45
src/state.rs
|
@ -60,9 +60,8 @@ impl WlMitmState {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||
|
@ -74,11 +73,9 @@ impl WlMitmState {
|
|||
num_fds = raw_msg.fds.len(),
|
||||
"Malformed request"
|
||||
);
|
||||
return (0, false);
|
||||
return false;
|
||||
}
|
||||
_ => {
|
||||
// TODO: due to fds, we can't expect to be able to pass through unknown messages.
|
||||
// Load all sensible Wayland extensions and remove this condition.
|
||||
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't
|
||||
// been built against!
|
||||
// Note that this won't pass through messages for globals we haven't allowed:
|
||||
|
@ -87,7 +84,7 @@ impl WlMitmState {
|
|||
// even for globals from protocols we don't know.
|
||||
// It does mean we can't filter against methods that create more objects _from_ that
|
||||
// global, though.
|
||||
return (0, true);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -101,7 +98,7 @@ impl WlMitmState {
|
|||
// to bind to it; if it does, it's likely a malicious client!
|
||||
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
||||
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
||||
return (0, false);
|
||||
return false;
|
||||
};
|
||||
|
||||
if interface != msg.id_interface_name {
|
||||
|
@ -109,7 +106,7 @@ impl WlMitmState {
|
|||
"Client binding to interface {}, but the interface name {} should correspond to {}",
|
||||
msg.id_interface_name, msg.name, interface
|
||||
);
|
||||
return (0, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
info!(
|
||||
|
@ -144,14 +141,12 @@ impl WlMitmState {
|
|||
msg.self_object_type().interface(),
|
||||
msg.self_msg_name()
|
||||
);
|
||||
|
||||
let mut cmd = tokio::process::Command::new(ask_cmd);
|
||||
cmd.arg(msg.self_object_type().interface());
|
||||
cmd.arg(msg.self_msg_name());
|
||||
// Note: the _last_ argument is always the JSON representation!
|
||||
cmd.arg(msg.to_json());
|
||||
|
||||
if let Ok(status) = cmd.status().await {
|
||||
if let Ok(status) = tokio::process::Command::new(ask_cmd)
|
||||
.arg(msg.self_object_type().interface())
|
||||
.arg(msg.self_msg_name())
|
||||
.status()
|
||||
.await
|
||||
{
|
||||
if !status.success() {
|
||||
warn!(
|
||||
"Blocked {}::{} because of return status {}",
|
||||
|
@ -161,7 +156,7 @@ impl WlMitmState {
|
|||
);
|
||||
}
|
||||
|
||||
return (msg.num_consumed_fds(), status.success());
|
||||
return status.success();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -170,7 +165,7 @@ impl WlMitmState {
|
|||
msg.self_object_type().interface(),
|
||||
msg.self_msg_name()
|
||||
);
|
||||
return (msg.num_consumed_fds(), false);
|
||||
return false;
|
||||
}
|
||||
WlFilterRequestAction::Block => {
|
||||
warn!(
|
||||
|
@ -179,17 +174,17 @@ impl WlMitmState {
|
|||
msg.self_msg_name()
|
||||
);
|
||||
// TODO: don't just return false, build an error event
|
||||
return (msg.num_consumed_fds(), false);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(msg.num_consumed_fds(), true)
|
||||
true
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||
|
@ -199,10 +194,10 @@ impl WlMitmState {
|
|||
num_fds = raw_msg.fds.len(),
|
||||
"Malformed event"
|
||||
);
|
||||
return (0, false);
|
||||
return false;
|
||||
}
|
||||
_ => {
|
||||
return (0, true);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -225,7 +220,7 @@ impl WlMitmState {
|
|||
interface = msg.interface,
|
||||
"Removing interface from published globals"
|
||||
);
|
||||
return (0, false);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Else, record the global object. These are the only ones we're ever going to allow through.
|
||||
|
@ -239,6 +234,6 @@ impl WlMitmState {
|
|||
self.objects.remove_object(msg.id);
|
||||
}
|
||||
|
||||
(msg.num_consumed_fds(), true)
|
||||
true
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue