Compare commits
2 commits
a15bb3fa6e
...
e0cc002148
Author | SHA1 | Date | |
---|---|---|---|
e0cc002148 | |||
d0afdbdda2 |
9 changed files with 144 additions and 41 deletions
26
Cargo.lock
generated
26
Cargo.lock
generated
|
@ -95,6 +95,7 @@ dependencies = [
|
||||||
"az",
|
"az",
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"half",
|
"half",
|
||||||
|
"serde",
|
||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -130,6 +131,12 @@ dependencies = [
|
||||||
"hashbrown",
|
"hashbrown",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itoa"
|
||||||
|
version = "1.0.14"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lazy_static"
|
name = "lazy_static"
|
||||||
version = "1.5.0"
|
version = "1.5.0"
|
||||||
|
@ -266,6 +273,12 @@ version = "0.1.24"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ryu"
|
||||||
|
version = "1.0.19"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sendfd"
|
name = "sendfd"
|
||||||
version = "0.4.3"
|
version = "0.4.3"
|
||||||
|
@ -296,6 +309,18 @@ dependencies = [
|
||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_json"
|
||||||
|
version = "1.0.139"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"memchr",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_spanned"
|
name = "serde_spanned"
|
||||||
version = "0.6.8"
|
version = "0.6.8"
|
||||||
|
@ -619,6 +644,7 @@ dependencies = [
|
||||||
"sendfd",
|
"sendfd",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_derive",
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
"tokio",
|
"tokio",
|
||||||
"toml",
|
"toml",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
|
@ -11,11 +11,12 @@ members = ["protogen"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
byteorder = "1.5.0"
|
byteorder = "1.5.0"
|
||||||
bytes = "1.10.0"
|
bytes = "1.10.0"
|
||||||
fixed = "1.29.0"
|
fixed = { version = "1.29.0", features = [ "serde" ] }
|
||||||
nix = "0.29.0"
|
nix = "0.29.0"
|
||||||
sendfd = { version = "0.4", features = [ "tokio" ] }
|
sendfd = { version = "0.4", features = [ "tokio" ] }
|
||||||
serde = "1.0.218"
|
serde = "1.0.218"
|
||||||
serde_derive = "1.0.218"
|
serde_derive = "1.0.218"
|
||||||
|
serde_json = "1.0.139"
|
||||||
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
|
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util", "process" ]}
|
||||||
toml = "0.8.20"
|
toml = "0.8.20"
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
|
|
|
@ -36,6 +36,9 @@ allowed_globals = [
|
||||||
#
|
#
|
||||||
# The first and second arguments to this program will be the interface
|
# The first and second arguments to this program will be the interface
|
||||||
# and request name, respectively.
|
# and request name, respectively.
|
||||||
|
#
|
||||||
|
# The number of arguments may change in the future, but the _last_ argument
|
||||||
|
# is always a JSON-serialized representation of the request's arguments.
|
||||||
ask_cmd = "contrib/ask-bemenu.sh"
|
ask_cmd = "contrib/ask-bemenu.sh"
|
||||||
|
|
||||||
# A list of requests we'd like to filter
|
# A list of requests we'd like to filter
|
||||||
|
|
|
@ -98,12 +98,30 @@ impl WlMsg {
|
||||||
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
|
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
|
||||||
|
|
||||||
// Build all field names and their corresponding Rust type identifiers
|
// Build all field names and their corresponding Rust type identifiers
|
||||||
let (field_names, field_types): (Vec<_>, Vec<_>) = self
|
let (field_names, (field_types, field_attrs)): (Vec<_>, (Vec<_>, Vec<_>)) = self
|
||||||
.args
|
.args
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
.map(|(name, tt)| {
|
||||||
|
(
|
||||||
|
format_ident!("{name}"),
|
||||||
|
(
|
||||||
|
tt.to_rust_type(),
|
||||||
|
match tt {
|
||||||
|
// Can't serialize fds!
|
||||||
|
WlArgType::Fd => quote! { #[serde(skip)] },
|
||||||
|
_ => quote! {},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
|
let num_consumed_fds = self
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, tt)| matches!(tt, WlArgType::Fd))
|
||||||
|
.count();
|
||||||
|
|
||||||
// Generate code to include in the parser for every field
|
// Generate code to include in the parser for every field
|
||||||
let parser_code: Vec<_> = self
|
let parser_code: Vec<_> = self
|
||||||
.args
|
.args
|
||||||
|
@ -143,10 +161,12 @@ impl WlMsg {
|
||||||
|
|
||||||
quote! {
|
quote! {
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
|
#[derive(Serialize)]
|
||||||
pub struct #struct_name<'a> {
|
pub struct #struct_name<'a> {
|
||||||
|
#[serde(skip)]
|
||||||
_phantom: std::marker::PhantomData<&'a ()>,
|
_phantom: std::marker::PhantomData<&'a ()>,
|
||||||
obj_id: u32,
|
obj_id: u32,
|
||||||
#( pub #field_names: #field_types, )*
|
#( #field_attrs pub #field_names: #field_types, )*
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
|
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
|
||||||
|
@ -184,6 +204,7 @@ impl WlMsg {
|
||||||
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
|
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
|
||||||
let payload = msg.payload();
|
let payload = msg.payload();
|
||||||
let mut pos = 0usize;
|
let mut pos = 0usize;
|
||||||
|
let mut pos_fds = 0usize;
|
||||||
#( #parser_code )*
|
#( #parser_code )*
|
||||||
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
|
@ -203,6 +224,14 @@ impl WlMsg {
|
||||||
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
|
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>> {
|
||||||
#known_objects_created
|
#known_objects_created
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_json(&self) -> String {
|
||||||
|
serde_json::to_string(self).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn num_consumed_fds(&self) -> usize {
|
||||||
|
#num_consumed_fds
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
|
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
|
||||||
|
@ -376,11 +405,12 @@ impl WlArgType {
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
WlArgType::Fd => quote! {
|
WlArgType::Fd => quote! {
|
||||||
if msg.fds.len() == 0 {
|
if msg.fds.len() < pos_fds + 1 {
|
||||||
return WaylandProtocolParsingOutcome::MalformedMessage;
|
return WaylandProtocolParsingOutcome::MalformedMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]);
|
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]);
|
||||||
|
pos_fds += 1;
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
40
src/codec.rs
40
src/codec.rs
|
@ -1,4 +1,4 @@
|
||||||
use std::os::fd::OwnedFd;
|
use std::{collections::VecDeque, os::fd::OwnedFd};
|
||||||
|
|
||||||
use byteorder::{ByteOrder, NativeEndian};
|
use byteorder::{ByteOrder, NativeEndian};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
|
@ -13,11 +13,18 @@ pub struct WlRawMsg {
|
||||||
pub opcode: u16,
|
pub opcode: u16,
|
||||||
// len bytes -- containing the header
|
// len bytes -- containing the header
|
||||||
msg_buf: Bytes,
|
msg_buf: Bytes,
|
||||||
pub fds: Box<[OwnedFd]>,
|
/// All fds we have seen up until decoding this message frame
|
||||||
|
/// fds aren't guaranteed to be separated between messages; therefore, there
|
||||||
|
/// is no way for us to tell that all fds here belong to the current message
|
||||||
|
/// without actually loading the Wayland XML protocols.
|
||||||
|
///
|
||||||
|
/// Instead, downstream parsers should return any unused fds back to the decoder
|
||||||
|
/// with [WlDecoder::return_unused_fds].
|
||||||
|
pub fds: Vec<OwnedFd>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WlRawMsg {
|
impl WlRawMsg {
|
||||||
pub fn try_decode(buf: &mut BytesMut, fds: &mut Vec<OwnedFd>) -> Option<WlRawMsg> {
|
pub fn try_decode(buf: &mut BytesMut, fds: &mut VecDeque<OwnedFd>) -> Option<WlRawMsg> {
|
||||||
let buf_len = buf.len();
|
let buf_len = buf.len();
|
||||||
// Not even a complete message header
|
// Not even a complete message header
|
||||||
if buf_len < 8 {
|
if buf_len < 8 {
|
||||||
|
@ -36,14 +43,16 @@ impl WlRawMsg {
|
||||||
let msg_buf = buf.split_to(msg_len as usize);
|
let msg_buf = buf.split_to(msg_len as usize);
|
||||||
|
|
||||||
let mut new_fds = Vec::with_capacity(fds.len());
|
let mut new_fds = Vec::with_capacity(fds.len());
|
||||||
new_fds.append(fds);
|
while let Some(fd) = fds.pop_front() {
|
||||||
|
new_fds.push(fd);
|
||||||
|
}
|
||||||
|
|
||||||
Some(WlRawMsg {
|
Some(WlRawMsg {
|
||||||
obj_id,
|
obj_id,
|
||||||
len: msg_len as u16,
|
len: msg_len as u16,
|
||||||
opcode: opcode as u16,
|
opcode: opcode as u16,
|
||||||
msg_buf: msg_buf.freeze(),
|
msg_buf: msg_buf.freeze(),
|
||||||
fds: new_fds.into_boxed_slice(),
|
fds: new_fds,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +61,7 @@ impl WlRawMsg {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
|
pub fn into_parts(self) -> (Bytes, Box<[OwnedFd]>) {
|
||||||
(self.msg_buf, self.fds)
|
(self.msg_buf, self.fds.into_boxed_slice())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,14 +73,25 @@ pub enum DecoderOutcome {
|
||||||
|
|
||||||
pub struct WlDecoder {
|
pub struct WlDecoder {
|
||||||
buf: BytesMut,
|
buf: BytesMut,
|
||||||
fds: Vec<OwnedFd>,
|
fds: VecDeque<OwnedFd>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WlDecoder {
|
impl WlDecoder {
|
||||||
pub fn new() -> WlDecoder {
|
pub fn new() -> WlDecoder {
|
||||||
WlDecoder {
|
WlDecoder {
|
||||||
buf: BytesMut::new(),
|
buf: BytesMut::new(),
|
||||||
fds: Vec::new(),
|
fds: VecDeque::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
|
||||||
|
let mut unused = msg.fds.split_off(num_consumed);
|
||||||
|
|
||||||
|
// Add all unused vectors, in order, to the _front_ of our queue
|
||||||
|
// This means that we take one item from the _back_ of the unused
|
||||||
|
// chunk at a time and insert that to the _front_, to preserve order.
|
||||||
|
while let Some(fd) = unused.pop() {
|
||||||
|
self.fds.push_front(fd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,9 +106,9 @@ impl WlDecoder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode_after_read(&mut self, buf: &[u8], fds: &mut Vec<OwnedFd>) -> DecoderOutcome {
|
pub fn decode_after_read(&mut self, buf: &[u8], fds: Vec<OwnedFd>) -> DecoderOutcome {
|
||||||
self.buf.extend_from_slice(&buf);
|
self.buf.extend_from_slice(&buf);
|
||||||
self.fds.append(fds);
|
self.fds.extend(fds.into_iter());
|
||||||
|
|
||||||
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
|
match WlRawMsg::try_decode(&mut self.buf, &mut self.fds) {
|
||||||
Some(res) => DecoderOutcome::Decoded(res),
|
Some(res) => DecoderOutcome::Decoded(res),
|
||||||
|
|
|
@ -25,6 +25,10 @@ impl<'a> WlMsgReader<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn return_unused_fds(&mut self, msg: &mut WlRawMsg, num_consumed: usize) {
|
||||||
|
self.decoder.return_unused_fds(msg, num_consumed);
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
|
pub async fn read(&mut self) -> io::Result<DecoderOutcome> {
|
||||||
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
|
if let Some(DecoderOutcome::Decoded(msg)) = self.decoder.decode_buf() {
|
||||||
return Ok(DecoderOutcome::Decoded(msg));
|
return Ok(DecoderOutcome::Decoded(msg));
|
||||||
|
@ -50,7 +54,7 @@ impl<'a> WlMsgReader<'a> {
|
||||||
|
|
||||||
return Ok(self
|
return Ok(self
|
||||||
.decoder
|
.decoder
|
||||||
.decode_after_read(&tmp_buf[0..read_bytes], &mut fd_vec));
|
.decode_after_read(&tmp_buf[0..read_bytes], fd_vec));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
14
src/main.rs
14
src/main.rs
|
@ -90,10 +90,13 @@ pub async fn handle_conn(
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
s2c_msg = upstream_read.read() => {
|
s2c_msg = upstream_read.read() => {
|
||||||
match s2c_msg? {
|
match s2c_msg? {
|
||||||
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
|
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "s2c event");
|
||||||
|
|
||||||
if state.on_s2c_event(&wl_raw_msg).await {
|
let (num_consumed_fds, verdict) = state.on_s2c_event(&wl_raw_msg).await;
|
||||||
|
upstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||||
|
|
||||||
|
if verdict {
|
||||||
downstream_write.queue_write(wl_raw_msg);
|
downstream_write.queue_write(wl_raw_msg);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -103,10 +106,13 @@ pub async fn handle_conn(
|
||||||
},
|
},
|
||||||
c2s_msg = downstream_read.read() => {
|
c2s_msg = downstream_read.read() => {
|
||||||
match c2s_msg? {
|
match c2s_msg? {
|
||||||
codec::DecoderOutcome::Decoded(wl_raw_msg) => {
|
codec::DecoderOutcome::Decoded(mut wl_raw_msg) => {
|
||||||
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, num_fds = wl_raw_msg.fds.len(), "c2s request");
|
||||||
|
|
||||||
if state.on_c2s_request(&wl_raw_msg).await {
|
let (num_consumed_fds, verdict) = state.on_c2s_request(&wl_raw_msg).await;
|
||||||
|
downstream_read.return_unused_fds(&mut wl_raw_msg, num_consumed_fds);
|
||||||
|
|
||||||
|
if verdict {
|
||||||
upstream_write.queue_write(wl_raw_msg);
|
upstream_write.queue_write(wl_raw_msg);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
use std::{collections::HashMap, sync::LazyLock};
|
use std::{collections::HashMap, sync::LazyLock};
|
||||||
|
|
||||||
use byteorder::ByteOrder;
|
use byteorder::ByteOrder;
|
||||||
|
use serde_derive::Serialize;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codec::WlRawMsg,
|
codec::WlRawMsg,
|
||||||
|
@ -114,6 +115,13 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
|
||||||
/// widely-used message with that capability is [WlRegistryBindRequest],
|
/// widely-used message with that capability is [WlRegistryBindRequest],
|
||||||
/// which is already handled separately on its own.
|
/// which is already handled separately on its own.
|
||||||
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
|
fn known_objects_created(&self) -> Option<Vec<(u32, WlObjectType)>>;
|
||||||
|
|
||||||
|
/// Serialize this message into a JSON string, for use with ask scripts
|
||||||
|
fn to_json(&self) -> String;
|
||||||
|
|
||||||
|
/// How many fds have been consumed in parsing this message?
|
||||||
|
/// This is used to return any unused fds to the decoder.
|
||||||
|
fn num_consumed_fds(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A version of [WlParsedMessage] that supports downcasting. By implementing this
|
/// A version of [WlParsedMessage] that supports downcasting. By implementing this
|
||||||
|
|
45
src/state.rs
45
src/state.rs
|
@ -60,8 +60,9 @@ impl WlMitmState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the number of fds consumed while parsing the message as a concrete Wayland type, and a verdict
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
|
pub async fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||||
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
let msg = match crate::proto::decode_request(&self.objects, raw_msg) {
|
||||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||||
WaylandProtocolParsingOutcome::MalformedMessage => {
|
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||||
|
@ -73,9 +74,11 @@ impl WlMitmState {
|
||||||
num_fds = raw_msg.fds.len(),
|
num_fds = raw_msg.fds.len(),
|
||||||
"Malformed request"
|
"Malformed request"
|
||||||
);
|
);
|
||||||
return false;
|
return (0, false);
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
// TODO: due to fds, we can't expect to be able to pass through unknown messages.
|
||||||
|
// Load all sensible Wayland extensions and remove this condition.
|
||||||
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't
|
// Pass through all unknown messages -- they could be from a Wayland protocol we haven't
|
||||||
// been built against!
|
// been built against!
|
||||||
// Note that this won't pass through messages for globals we haven't allowed:
|
// Note that this won't pass through messages for globals we haven't allowed:
|
||||||
|
@ -84,7 +87,7 @@ impl WlMitmState {
|
||||||
// even for globals from protocols we don't know.
|
// even for globals from protocols we don't know.
|
||||||
// It does mean we can't filter against methods that create more objects _from_ that
|
// It does mean we can't filter against methods that create more objects _from_ that
|
||||||
// global, though.
|
// global, though.
|
||||||
return true;
|
return (0, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -98,7 +101,7 @@ impl WlMitmState {
|
||||||
// to bind to it; if it does, it's likely a malicious client!
|
// to bind to it; if it does, it's likely a malicious client!
|
||||||
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
// So, we simply remove these messages from the stream, which will cause the Wayland server to error out.
|
||||||
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
||||||
return false;
|
return (0, false);
|
||||||
};
|
};
|
||||||
|
|
||||||
if interface != msg.id_interface_name {
|
if interface != msg.id_interface_name {
|
||||||
|
@ -106,7 +109,7 @@ impl WlMitmState {
|
||||||
"Client binding to interface {}, but the interface name {} should correspond to {}",
|
"Client binding to interface {}, but the interface name {} should correspond to {}",
|
||||||
msg.id_interface_name, msg.name, interface
|
msg.id_interface_name, msg.name, interface
|
||||||
);
|
);
|
||||||
return false;
|
return (0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
|
@ -141,12 +144,14 @@ impl WlMitmState {
|
||||||
msg.self_object_type().interface(),
|
msg.self_object_type().interface(),
|
||||||
msg.self_msg_name()
|
msg.self_msg_name()
|
||||||
);
|
);
|
||||||
if let Ok(status) = tokio::process::Command::new(ask_cmd)
|
|
||||||
.arg(msg.self_object_type().interface())
|
let mut cmd = tokio::process::Command::new(ask_cmd);
|
||||||
.arg(msg.self_msg_name())
|
cmd.arg(msg.self_object_type().interface());
|
||||||
.status()
|
cmd.arg(msg.self_msg_name());
|
||||||
.await
|
// Note: the _last_ argument is always the JSON representation!
|
||||||
{
|
cmd.arg(msg.to_json());
|
||||||
|
|
||||||
|
if let Ok(status) = cmd.status().await {
|
||||||
if !status.success() {
|
if !status.success() {
|
||||||
warn!(
|
warn!(
|
||||||
"Blocked {}::{} because of return status {}",
|
"Blocked {}::{} because of return status {}",
|
||||||
|
@ -156,7 +161,7 @@ impl WlMitmState {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.success();
|
return (msg.num_consumed_fds(), status.success());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,7 +170,7 @@ impl WlMitmState {
|
||||||
msg.self_object_type().interface(),
|
msg.self_object_type().interface(),
|
||||||
msg.self_msg_name()
|
msg.self_msg_name()
|
||||||
);
|
);
|
||||||
return false;
|
return (msg.num_consumed_fds(), false);
|
||||||
}
|
}
|
||||||
WlFilterRequestAction::Block => {
|
WlFilterRequestAction::Block => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -174,17 +179,17 @@ impl WlMitmState {
|
||||||
msg.self_msg_name()
|
msg.self_msg_name()
|
||||||
);
|
);
|
||||||
// TODO: don't just return false, build an error event
|
// TODO: don't just return false, build an error event
|
||||||
return false;
|
return (msg.num_consumed_fds(), false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
true
|
(msg.num_consumed_fds(), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
|
pub async fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> (usize, bool) {
|
||||||
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
let msg = match crate::proto::decode_event(&self.objects, raw_msg) {
|
||||||
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
WaylandProtocolParsingOutcome::Ok(msg) => msg,
|
||||||
WaylandProtocolParsingOutcome::MalformedMessage => {
|
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||||
|
@ -194,10 +199,10 @@ impl WlMitmState {
|
||||||
num_fds = raw_msg.fds.len(),
|
num_fds = raw_msg.fds.len(),
|
||||||
"Malformed event"
|
"Malformed event"
|
||||||
);
|
);
|
||||||
return false;
|
return (0, false);
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
return true;
|
return (0, true);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -220,7 +225,7 @@ impl WlMitmState {
|
||||||
interface = msg.interface,
|
interface = msg.interface,
|
||||||
"Removing interface from published globals"
|
"Removing interface from published globals"
|
||||||
);
|
);
|
||||||
return false;
|
return (0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Else, record the global object. These are the only ones we're ever going to allow through.
|
// Else, record the global object. These are the only ones we're ever going to allow through.
|
||||||
|
@ -234,6 +239,6 @@ impl WlMitmState {
|
||||||
self.objects.remove_object(msg.id);
|
self.objects.remove_object(msg.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
true
|
(msg.num_consumed_fds(), true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue