Compare commits

..

No commits in common. "3b22de398cc6ec86ed07273d98c122accb790954" and "1860437c72188ee6e053575b4b0b83cbe4cc5c50" have entirely different histories.

9 changed files with 87 additions and 542 deletions

246
Cargo.lock generated
View file

@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo. # This file is automatically @generated by Cargo.
# It is not intended for manual editing. # It is not intended for manual editing.
version = 4 version = 3
[[package]] [[package]]
name = "addr2line" name = "addr2line"
@ -62,52 +62,18 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]] [[package]]
name = "gimli" name = "gimli"
version = "0.31.1" version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
[[package]]
name = "indexmap"
version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [
"equivalent",
"hashbrown",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.170" version = "0.2.170"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828"
[[package]]
name = "log"
version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.7.4" version = "2.7.4"
@ -146,16 +112,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]] [[package]]
name = "object" name = "object"
version = "0.36.7" version = "0.36.7"
@ -165,18 +121,6 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "once_cell"
version = "1.20.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]] [[package]]
name = "pin-project-lite" name = "pin-project-lite"
version = "0.2.16" version = "0.2.16"
@ -217,50 +161,6 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "serde"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_spanned"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1"
dependencies = [
"serde",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "smallvec"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd"
[[package]] [[package]]
name = "socket2" name = "socket2"
version = "0.5.8" version = "0.5.8"
@ -282,16 +182,6 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "thread_local"
version = "1.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.43.0" version = "1.43.0"
@ -299,7 +189,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes",
"libc", "libc",
"mio", "mio",
"pin-project-lite", "pin-project-lite",
@ -319,137 +208,18 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "toml"
version = "0.8.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148"
dependencies = [
"serde",
"serde_spanned",
"toml_datetime",
"toml_edit",
]
[[package]]
name = "toml_datetime"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
dependencies = [
"serde",
]
[[package]]
name = "toml_edit"
version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
"indexmap",
"serde",
"serde_spanned",
"toml_datetime",
"winnow",
]
[[package]]
name = "tracing"
version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"nu-ansi-term",
"sharded-slab",
"smallvec",
"thread_local",
"tracing-core",
"tracing-log",
]
[[package]] [[package]]
name = "unicode-ident" name = "unicode-ident"
version = "1.0.17" version = "1.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe"
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.11.0+wasi-snapshot-preview1" version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.52.0" version = "0.52.0"
@ -523,15 +293,6 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "wl-mitm" name = "wl-mitm"
version = "0.1.0" version = "0.1.0"
@ -540,10 +301,5 @@ dependencies = [
"bytes", "bytes",
"nix", "nix",
"sendfd", "sendfd",
"serde",
"serde_derive",
"tokio", "tokio",
"toml",
"tracing",
"tracing-subscriber",
] ]

View file

@ -1,7 +1,7 @@
[package] [package]
name = "wl-mitm" name = "wl-mitm"
version = "0.1.0" version = "0.1.0"
edition = "2024" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -10,9 +10,4 @@ byteorder = "1.5.0"
bytes = "1.10.0" bytes = "1.10.0"
nix = "0.29.0" nix = "0.29.0"
sendfd = { version = "0.4", features = [ "tokio" ] } sendfd = { version = "0.4", features = [ "tokio" ] }
serde = "1.0.218" tokio = { version = "1.43.0", features = [ "net", "rt", "rt-multi-thread", "macros" ]}
serde_derive = "1.0.218"
tokio = { version = "1.43.0", features = [ "fs", "net", "rt", "rt-multi-thread", "macros", "io-util" ]}
toml = "0.8.20"
tracing = "0.1.41"
tracing-subscriber = "0.3.19"

View file

@ -1,25 +0,0 @@
[socket]
# Which socket to listen on? If relative,
# defaults to being relative to $XDG_RUNTIME_DIR
listen = "wayland-2"
# Which Wayland socket to use as upstream?
# If missing, defaults to $WAYLAND_DISPLAY
# upstream = "wayland-1"
[filter]
# A list of Wayland global singleton objects that's allowed
# Each of them generally correspond to an implemented protocol
# For a list of these, see <https://wayland.app>
# Note that not every object is exposed and created as a global.
allowed_globals = [
# Base wl protocols
"wl_compositor",
"wl_shm",
"wl_data_device_manager",
"wl_seat",
# Window management
"xdg_wm_base",
"zxdg_decoration_manager_v1",
# Linux DMA-BUF
"zwp_linux_dmabuf_v1",
]

View file

@ -3,7 +3,6 @@ use std::os::fd::OwnedFd;
use byteorder::{ByteOrder, NativeEndian}; use byteorder::{ByteOrder, NativeEndian};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
#[allow(unused)]
pub struct WlRawMsg { pub struct WlRawMsg {
// 4 bytes // 4 bytes
pub obj_id: u32, pub obj_id: u32,

View file

@ -1,56 +0,0 @@
use std::{
collections::HashSet,
path::{Path, PathBuf},
};
use serde_derive::Deserialize;
#[derive(Deserialize)]
pub struct Config {
pub socket: WlSockets,
pub filter: WlFilter,
}
fn default_upstream_socket() -> String {
std::env::var("WAYLAND_DISPLAY").unwrap_or_else(|_| "wayland-1".to_string())
}
#[derive(Deserialize)]
pub struct WlSockets {
listen: String,
#[serde(default = "default_upstream_socket")]
upstream: String,
}
impl WlSockets {
pub fn upstream_socket_path(&self) -> PathBuf {
let p = Path::new(&self.upstream);
if p.is_absolute() {
p.into()
} else {
Path::new(
&std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/run/user/1000".to_string()),
)
.join(p)
.into()
}
}
pub fn listen_socket_path(&self) -> PathBuf {
let p = Path::new(&self.listen);
if p.is_absolute() {
p.into()
} else {
Path::new(
&std::env::var("XDG_RUNTIME_DIR").unwrap_or_else(|_| "/run/user/1000".to_string()),
)
.join(p)
.into()
}
}
}
#[derive(Deserialize)]
pub struct WlFilter {
pub allowed_globals: HashSet<String>,
}

View file

@ -1,12 +1,9 @@
use std::{ use std::{
future::poll_fn,
io, io,
ops::Deref, ops::Deref,
os::fd::{FromRawFd, OwnedFd}, os::fd::{FromRawFd, OwnedFd},
task::{Context, Poll},
}; };
use bytes::Bytes;
use sendfd::{RecvWithFd, SendWithFd}; use sendfd::{RecvWithFd, SendWithFd};
use tokio::net::unix::{ReadHalf, WriteHalf}; use tokio::net::unix::{ReadHalf, WriteHalf};
@ -57,97 +54,35 @@ impl<'a> WlMsgReader<'a> {
pub struct WlMsgWriter<'a> { pub struct WlMsgWriter<'a> {
egress: WriteHalf<'a>, egress: WriteHalf<'a>,
write_queue: Vec<WlRawMsg>,
cur_write_buf: Option<Bytes>,
cur_write_buf_pos: usize,
cur_write_fds: Option<Box<[OwnedFd]>>,
} }
impl<'a> WlMsgWriter<'a> { impl<'a> WlMsgWriter<'a> {
pub fn new(egress: WriteHalf<'a>) -> Self { pub fn new(egress: WriteHalf<'a>) -> Self {
WlMsgWriter { WlMsgWriter { egress }
egress,
write_queue: Vec::new(),
cur_write_buf: None,
cur_write_buf_pos: 0,
cur_write_fds: None,
}
} }
/// Can we possibly write anything? pub async fn write(&mut self, msg: WlRawMsg) -> io::Result<()> {
fn can_write(&self) -> bool { let (buf, fds) = msg.into_parts();
self.cur_write_buf.is_some() || !self.write_queue.is_empty()
}
/// Try to write __something__ into the underlying stream. let mut written = 0;
/// This does not care about registering interests, so it may return ready with a WOULDBLOCK
fn try_poll_write(&mut self) -> Poll<io::Result<()>> {
// If we don't have a partially written buffer, try remove one from the write queue
if self.cur_write_buf.is_none() && !self.write_queue.is_empty() {
// Don't use pop(), wl messages need to be in order!!
let (buf, fds) = self.write_queue.remove(0).into_parts();
self.cur_write_buf = Some(buf); while written < buf.len() {
self.cur_write_buf_pos = 0; self.egress.writable().await?;
self.cur_write_fds = Some(fds);
}
if let Some(buf) = self.cur_write_buf.take() { let res = if written == 0 {
let send_res = if let Some(fds) = self.cur_write_fds.take() {
self.egress self.egress
.send_with_fd(&buf[self.cur_write_buf_pos..], unsafe { .send_with_fd(&buf, unsafe { std::mem::transmute(fds.deref()) })
std::mem::transmute(fds.deref())
})
} else { } else {
self.egress self.egress.send_with_fd(&buf[written..], &[])
.send_with_fd(&buf[self.cur_write_buf_pos..], &[])
}; };
if let Ok(written) = send_res { match res {
// Partial send :( Ok(new_written) => written += new_written,
// At least fds are always guaranteed to be sent in full Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue,
if self.cur_write_buf_pos + written < buf.len() { Err(e) => return Err(e),
self.cur_write_buf = Some(buf);
self.cur_write_buf_pos += written;
}
}
// Caller is supposed to handle WOULDBLOCK
Poll::Ready(send_res.map(|_| ()))
} else {
Poll::Pending
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// If we can't write anything, return pending immediately
if !self.can_write() {
return Poll::Pending;
}
while self.egress.as_ref().poll_write_ready(cx).is_ready() {
match self.try_poll_write() {
Poll::Ready(Err(e)) if e.kind() == io::ErrorKind::WouldBlock => continue,
Poll::Ready(res) => return Poll::Ready(res),
Poll::Pending => return Poll::Pending,
} }
} }
Poll::Pending Ok(())
}
/// Queue a message up for writing, but doesn't do anything right away.
pub fn queue_write(&mut self, msg: WlRawMsg) {
self.write_queue.push(msg);
}
/// Try to make progress by flushing some of the queued up messages into the stream.
/// When this resolves, note that we might have only partially written. In that
/// case the buffer is saved internally in this structure.
///
/// The returned future will block forever (never resolve) if there is no
/// message to be written. This behavior makes it play nicely with select!{}
pub async fn dequeue_write(&mut self) -> io::Result<()> {
poll_fn(|cx| self.poll_write(cx)).await
} }
} }

View file

@ -1,69 +1,46 @@
mod codec; mod codec;
mod io_util; mod io_util;
mod objects; mod objects;
#[macro_use]
mod proto; mod proto;
mod config;
mod state; mod state;
use std::{io, path::Path, sync::Arc}; use std::{io, path::Path};
use config::Config;
use io_util::{WlMsgReader, WlMsgWriter}; use io_util::{WlMsgReader, WlMsgWriter};
use state::WlMitmState; use state::WlMitmState;
use tokio::net::{UnixListener, UnixStream}; use tokio::net::{UnixListener, UnixStream};
use tracing::{Instrument, Level, debug, error, info, span};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
tracing_subscriber::fmt::init();
let mut conf_file = "config.toml";
let args: Vec<_> = std::env::args().collect(); let args: Vec<_> = std::env::args().collect();
if args.len() >= 2 { if args.len() < 3 {
conf_file = &args[2]; println!("Usage: {} <wl_display> <wl_display_proxied>", args[0]);
}
let conf_str = tokio::fs::read_to_string(conf_file)
.await
.expect("Can't read config file");
let config: Arc<Config> =
Arc::new(toml::from_str(&conf_str).expect("Can't decode config file"));
let src = config.socket.upstream_socket_path();
let proxied = config.socket.listen_socket_path();
if src == proxied {
error!("downstream and upstream sockets should not be the same");
return; return;
} }
if proxied.exists() { let xdg_rt = std::env::var("XDG_RUNTIME_DIR").expect("XDG_RUNTIME_DIR not set");
tokio::fs::remove_file(&proxied)
.await let src = format!("{}/{}", xdg_rt, args[1]);
.expect("Cannot unlink existing socket"); let proxied = format!("{}/{}", xdg_rt, args[2]);
if src == proxied {
println!("downstream and upstream sockets should not be the same");
return;
} }
let listener = UnixListener::bind(&proxied).expect("Failed to bind to target socket"); if Path::exists(proxied.as_ref()) {
std::fs::remove_file(&proxied).expect("Cannot unlink existing socket");
}
info!(path = ?proxied, "Listening on socket"); let listener = UnixListener::bind(proxied).expect("Failed to bind to target socket");
let mut conn_id = 0;
while let Ok((conn, addr)) = listener.accept().await { while let Ok((conn, addr)) = listener.accept().await {
info!(conn_id = conn_id, "Accepted new client {:?}", addr); println!("Accepted new client {:?}", addr);
let span = span!(Level::INFO, "conn", conn_id = conn_id); tokio::spawn(handle_conn(src.clone(), conn));
tokio::spawn(handle_conn(config.clone(), src.clone(), conn).instrument(span));
conn_id += 1;
} }
} }
#[tracing::instrument(skip_all)] pub async fn handle_conn(src_path: String, mut downstream_conn: UnixStream) -> io::Result<()> {
pub async fn handle_conn(
config: Arc<Config>,
src_path: impl AsRef<Path>,
mut downstream_conn: UnixStream,
) -> io::Result<()> {
let mut upstream_conn = UnixStream::connect(src_path).await?; let mut upstream_conn = UnixStream::connect(src_path).await?;
let (upstream_read, upstream_write) = upstream_conn.split(); let (upstream_read, upstream_write) = upstream_conn.split();
@ -75,17 +52,17 @@ pub async fn handle_conn(
let mut upstream_write = WlMsgWriter::new(upstream_write); let mut upstream_write = WlMsgWriter::new(upstream_write);
let mut downstream_write = WlMsgWriter::new(downstream_write); let mut downstream_write = WlMsgWriter::new(downstream_write);
let mut state = WlMitmState::new(config); let mut state = WlMitmState::new();
loop { loop {
tokio::select! { tokio::select! {
s2c_msg = upstream_read.read() => { s2c_msg = upstream_read.read() => {
match s2c_msg? { match s2c_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => { codec::DecoderOutcome::Decoded(wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, "s2c event"); println!("s2c, obj_id = {}, opcode = {}", wl_raw_msg.obj_id, wl_raw_msg.opcode);
if state.on_s2c_event(&wl_raw_msg) { if state.on_s2c_msg(&wl_raw_msg) {
downstream_write.queue_write(wl_raw_msg); downstream_write.write(wl_raw_msg).await?;
} }
}, },
codec::DecoderOutcome::Incomplete => continue, codec::DecoderOutcome::Incomplete => continue,
@ -95,19 +72,16 @@ pub async fn handle_conn(
c2s_msg = downstream_read.read() => { c2s_msg = downstream_read.read() => {
match c2s_msg? { match c2s_msg? {
codec::DecoderOutcome::Decoded(wl_raw_msg) => { codec::DecoderOutcome::Decoded(wl_raw_msg) => {
debug!(obj_id = wl_raw_msg.obj_id, opcode = wl_raw_msg.opcode, "c2s request"); println!("c2s, obj_id = {}, opcode = {}", wl_raw_msg.obj_id, wl_raw_msg.opcode);
if state.on_c2s_request(&wl_raw_msg) { if state.on_c2s_msg(&wl_raw_msg) {
upstream_write.queue_write(wl_raw_msg); upstream_write.write(wl_raw_msg).await?;
} }
}, },
codec::DecoderOutcome::Incomplete => continue, codec::DecoderOutcome::Incomplete => continue,
codec::DecoderOutcome::Eof => break Ok(()), codec::DecoderOutcome::Eof => break Ok(()),
} }
} }
// Try to write of we have any queued up. These don't do anything if no message is queued.
res = upstream_write.dequeue_write() => res?,
res = downstream_write.dequeue_write() => res?,
} }
} }
} }

View file

@ -9,28 +9,6 @@ use crate::{
objects::{WlObjectType, WlObjects}, objects::{WlObjectType, WlObjects},
}; };
macro_rules! reject_malformed {
($e:expr) => {
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = $e {
return false;
} else if let crate::proto::WaylandProtocolParsingOutcome::Ok(e) = $e {
Some(e)
} else {
None
}
};
}
macro_rules! decode_and_match_msg {
($objects:expr, match $msg:ident {$($t:ty => $act:block$(,)?)+}) => {
$(
if let Some($msg) = reject_malformed!(<$t>::try_from_msg(&$objects, $msg)) {
$act
}
)+
};
}
pub enum WaylandProtocolParsingOutcome<T> { pub enum WaylandProtocolParsingOutcome<T> {
Ok(T), Ok(T),
MalformedMessage, MalformedMessage,
@ -122,8 +100,7 @@ impl<'a> WlRegistryGlobalEvent<'a> {
} }
let version = NativeEndian::read_u32(&payload[payload.len() - 4..]); let version = NativeEndian::read_u32(&payload[payload.len() - 4..]);
// -1 because of 0-terminator let Ok(interface) = std::str::from_utf8(&payload[8..8 + interface_len as usize]) else {
let Ok(interface) = std::str::from_utf8(&payload[8..8 + interface_len as usize - 1]) else {
return WaylandProtocolParsingOutcome::MalformedMessage; return WaylandProtocolParsingOutcome::MalformedMessage;
}; };

View file

@ -1,77 +1,67 @@
use std::sync::Arc;
use tracing::{debug, info};
use crate::{ use crate::{
codec::WlRawMsg, codec::WlRawMsg,
config::Config,
objects::{WlObjectType, WlObjects}, objects::{WlObjectType, WlObjects},
proto::{WlDisplayGetRegistry, WlRegistryBind, WlRegistryGlobalEvent}, proto::{
WaylandProtocolParsingOutcome, WlDisplayGetRegistry, WlRegistryBind, WlRegistryGlobalEvent,
},
}; };
macro_rules! reject_malformed {
($e:expr) => {
if let WaylandProtocolParsingOutcome::MalformedMessage = $e {
return false;
} else if let WaylandProtocolParsingOutcome::Ok(e) = $e {
Some(e)
} else {
None
}
};
}
pub struct WlMitmState { pub struct WlMitmState {
config: Arc<Config>,
objects: WlObjects, objects: WlObjects,
} }
impl WlMitmState { impl WlMitmState {
pub fn new(config: Arc<Config>) -> WlMitmState { pub fn new() -> WlMitmState {
WlMitmState { WlMitmState {
config,
objects: WlObjects::new(), objects: WlObjects::new(),
} }
} }
#[tracing::instrument(skip_all)] pub fn on_c2s_msg(&mut self, msg: &WlRawMsg) -> bool {
pub fn on_c2s_request(&mut self, msg: &WlRawMsg) -> bool { if let Some(get_registry_msg) =
decode_and_match_msg!( reject_malformed!(WlDisplayGetRegistry::try_from_msg(&self.objects, msg))
self.objects, {
match msg { self.objects
WlDisplayGetRegistry => { .record_object(WlObjectType::WlRegistry, get_registry_msg.registry_new_id);
self.objects } else if let Some(bind_msg) =
.record_object(WlObjectType::WlRegistry, msg.registry_new_id); reject_malformed!(WlRegistryBind::try_from_msg(&self.objects, msg))
} {
WlRegistryBind => { let Some(interface) = self.objects.lookup_global(bind_msg.name) else {
let Some(interface) = self.objects.lookup_global(msg.name) else { return false;
return false; };
}; println!(
info!( "Client binding interface {}, object id = {}",
interface = interface, interface, bind_msg.new_id
obj_id = msg.new_id, );
"Client binding interface" }
);
}
}
);
true true
} }
#[tracing::instrument(skip_all)] pub fn on_s2c_msg(&mut self, msg: &WlRawMsg) -> bool {
pub fn on_s2c_event(&mut self, msg: &WlRawMsg) -> bool { if let Some(global_msg) =
decode_and_match_msg!( reject_malformed!(WlRegistryGlobalEvent::try_from_msg(&self.objects, msg))
self.objects, {
match msg { println!(
WlRegistryGlobalEvent => { "got global: {}, name {}, version {}",
debug!( global_msg.interface, global_msg.name, global_msg.version
interface = msg.interface, );
name = msg.name,
version = msg.version,
"got global"
);
self.objects.record_global(msg.name, msg.interface); self.objects
.record_global(global_msg.name, global_msg.interface);
if !self.config.filter.allowed_globals.contains(msg.interface) { }
info!(
interface = msg.interface,
"Removing interface from published globals"
);
return false;
}
}
}
);
true true
} }