Compare commits

...

4 commits

4 changed files with 43 additions and 36 deletions

View file

@ -24,12 +24,12 @@ pub fn generate_from_dir(p: impl AsRef<Path>) -> String {
quote! { quote! {
#( #gen_code )* #( #gen_code )*
pub fn wl_init_parsers() { fn wl_init_parsers(event_parsers: &mut Vec<&'static dyn WlMsgParserFn>, request_parsers: &mut Vec<&'static dyn WlMsgParserFn>) {
#( #add_parsers_fn(); )* #( #add_parsers_fn(event_parsers, request_parsers); )*
} }
pub fn wl_init_known_types() { fn wl_init_known_types(object_types: &mut HashMap<&'static str, WlObjectType>) {
#( #add_object_types_fn(); )* #( #add_object_types_fn(object_types); )*
} }
} }
.to_string() .to_string()
@ -99,13 +99,13 @@ fn generate_from_xml_file(p: impl AsRef<Path>) -> (proc_macro2::TokenStream, (Id
let ret_code = quote! { let ret_code = quote! {
#( #code )* #( #code )*
fn #add_parsers_fn() { fn #add_parsers_fn(event_parsers: &mut Vec<&'static dyn WlMsgParserFn>, request_parsers: &mut Vec<&'static dyn WlMsgParserFn>) {
#( WL_EVENT_PARSERS.write().unwrap().push(&#event_parsers); )* #( event_parsers.push(&#event_parsers); )*
#( WL_REQUEST_PARSERS.write().unwrap().push(&#request_parsers); )* #( request_parsers.push(&#request_parsers); )*
} }
fn #add_object_types_fn() { fn #add_object_types_fn(object_types: &mut HashMap<&'static str, WlObjectType>) {
#( WL_KNOWN_OBJECT_TYPES.write().unwrap().insert(#known_interface_names, #known_interface_consts); )* #( object_types.insert(#known_interface_names, #known_interface_consts); )*
} }
}; };

View file

@ -146,8 +146,8 @@ impl WlMsg {
WlMsgType::#msg_type WlMsgType::#msg_type
} }
#[allow(unused)] #[allow(unused, private_interfaces)]
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> { fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload(); let payload = msg.payload();
let mut pos = 0usize; let mut pos = 0usize;
#( #parser_code )* #( #parser_code )*

View file

@ -17,8 +17,8 @@ use tracing::{Instrument, Level, debug, error, info, span};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
proto::wl_init_parsers(); //proto::wl_init_parsers();
proto::wl_init_known_types(); //proto::wl_init_known_types();
let mut conf_file = "config.toml"; let mut conf_file = "config.toml";

View file

@ -1,10 +1,6 @@
//! Protocol definitions necessary for this MITM proxy //! Protocol definitions necessary for this MITM proxy
use std::{ use std::{collections::HashMap, sync::LazyLock};
collections::HashMap,
hash::{BuildHasherDefault, DefaultHasher},
sync::RwLock,
};
use byteorder::ByteOrder; use byteorder::ByteOrder;
@ -59,9 +55,10 @@ impl<T> WaylandProtocolParsingOutcome<T> {
/// Internal module used to seal the [WlParsedMessage] trait /// Internal module used to seal the [WlParsedMessage] trait
mod __private { mod __private {
pub(super) trait WlParsedMessagePrivate {} pub(super) trait WlParsedMessagePrivate {}
pub(super) struct WlParsedMessagePrivateToken;
} }
#[allow(private_bounds)] #[allow(private_bounds, private_interfaces)]
pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate { pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
fn opcode() -> u16 fn opcode() -> u16
where where
@ -88,10 +85,13 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
return WaylandProtocolParsingOutcome::IncorrectOpcode; return WaylandProtocolParsingOutcome::IncorrectOpcode;
} }
Self::try_from_msg_impl(msg) Self::try_from_msg_impl(msg, __private::WlParsedMessagePrivateToken)
} }
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self> fn try_from_msg_impl(
msg: &'a WlRawMsg,
_token: __private::WlParsedMessagePrivateToken,
) -> WaylandProtocolParsingOutcome<Self>
where where
Self: Sized + 'a; Self: Sized + 'a;
@ -106,8 +106,10 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
/// it does not overlap with any other implementation of this trait. /// it does not overlap with any other implementation of this trait.
/// ///
/// In addition, any implementor also asserts that the type implementing this trait /// In addition, any implementor also asserts that the type implementing this trait
/// does not contain any lifetime other than 'a. This is required for the soundness of /// does not contain any lifetime other than 'a, and that the implenetor struct is
/// the downcast_ref implementation. /// _covariant_ with respect to lifetime 'a.
///
/// This is required for the soundness of the downcast_ref implementation.
pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {} pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {}
impl<'out, 'data: 'out> dyn AnyWlParsedMessage<'data> + 'data { impl<'out, 'data: 'out> dyn AnyWlParsedMessage<'data> + 'data {
@ -158,13 +160,21 @@ pub trait WlMsgParserFn: Send + Sync {
} }
/// A map from known interface names to their object types in Rust representation /// A map from known interface names to their object types in Rust representation
static WL_KNOWN_OBJECT_TYPES: RwLock< static WL_KNOWN_OBJECT_TYPES: LazyLock<HashMap<&'static str, WlObjectType>> = LazyLock::new(|| {
HashMap<&'static str, WlObjectType, BuildHasherDefault<DefaultHasher>>, let mut ret = HashMap::new();
> = RwLock::new(HashMap::with_hasher(BuildHasherDefault::new())); wl_init_known_types(&mut ret);
/// Parsers for all known events ret
static WL_EVENT_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new()); });
/// Parsers for all known requests /// Parsers for all known events / requests
static WL_REQUEST_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new()); static WL_EVENT_REQUEST_PARSERS: LazyLock<(
Vec<&'static dyn WlMsgParserFn>,
Vec<&'static dyn WlMsgParserFn>,
)> = LazyLock::new(|| {
let mut event_parsers = vec![];
let mut request_parsers = vec![];
wl_init_parsers(&mut event_parsers, &mut request_parsers);
(event_parsers, request_parsers)
});
/// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or /// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage] /// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
@ -175,7 +185,7 @@ pub fn decode_event<'obj, 'msg>(
objects: &'obj WlObjects, objects: &'obj WlObjects,
msg: &'msg WlRawMsg, msg: &'msg WlRawMsg,
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> { ) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
for p in WL_EVENT_PARSERS.read().unwrap().iter() { for p in WL_EVENT_REQUEST_PARSERS.0.iter() {
if let WaylandProtocolParsingOutcome::Ok(e) = if let WaylandProtocolParsingOutcome::Ok(e) =
bubble_malformed!(p.try_from_msg(objects, msg)) bubble_malformed!(p.try_from_msg(objects, msg))
{ {
@ -195,7 +205,7 @@ pub fn decode_request<'obj, 'msg>(
objects: &'obj WlObjects, objects: &'obj WlObjects,
msg: &'msg WlRawMsg, msg: &'msg WlRawMsg,
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> { ) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
for p in WL_REQUEST_PARSERS.read().unwrap().iter() { for p in WL_EVENT_REQUEST_PARSERS.1.iter() {
if let WaylandProtocolParsingOutcome::Ok(e) = if let WaylandProtocolParsingOutcome::Ok(e) =
bubble_malformed!(p.try_from_msg(objects, msg)) bubble_malformed!(p.try_from_msg(objects, msg))
{ {
@ -208,10 +218,7 @@ pub fn decode_request<'obj, 'msg>(
/// Look up a known object type from its name to its Rust [WlObjectType] representation /// Look up a known object type from its name to its Rust [WlObjectType] representation
pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> { pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> {
WL_KNOWN_OBJECT_TYPES WL_KNOWN_OBJECT_TYPES.get(name).copied()
.read()
.ok()
.and_then(|t| t.get(name).copied())
} }
/// The default object ID of wl_display /// The default object ID of wl_display