Compare commits

..

No commits in common. "c06110812d2849195fea5bac82562c3ae32de1cf" and "97f8a6ed23892b78ac26eecd81ec0cab2d8537b8" have entirely different histories.

4 changed files with 36 additions and 43 deletions

View file

@ -24,12 +24,12 @@ pub fn generate_from_dir(p: impl AsRef<Path>) -> String {
quote! { quote! {
#( #gen_code )* #( #gen_code )*
fn wl_init_parsers(event_parsers: &mut Vec<&'static dyn WlMsgParserFn>, request_parsers: &mut Vec<&'static dyn WlMsgParserFn>) { pub fn wl_init_parsers() {
#( #add_parsers_fn(event_parsers, request_parsers); )* #( #add_parsers_fn(); )*
} }
fn wl_init_known_types(object_types: &mut HashMap<&'static str, WlObjectType>) { pub fn wl_init_known_types() {
#( #add_object_types_fn(object_types); )* #( #add_object_types_fn(); )*
} }
} }
.to_string() .to_string()
@ -99,13 +99,13 @@ fn generate_from_xml_file(p: impl AsRef<Path>) -> (proc_macro2::TokenStream, (Id
let ret_code = quote! { let ret_code = quote! {
#( #code )* #( #code )*
fn #add_parsers_fn(event_parsers: &mut Vec<&'static dyn WlMsgParserFn>, request_parsers: &mut Vec<&'static dyn WlMsgParserFn>) { fn #add_parsers_fn() {
#( event_parsers.push(&#event_parsers); )* #( WL_EVENT_PARSERS.write().unwrap().push(&#event_parsers); )*
#( request_parsers.push(&#request_parsers); )* #( WL_REQUEST_PARSERS.write().unwrap().push(&#request_parsers); )*
} }
fn #add_object_types_fn(object_types: &mut HashMap<&'static str, WlObjectType>) { fn #add_object_types_fn() {
#( object_types.insert(#known_interface_names, #known_interface_consts); )* #( WL_KNOWN_OBJECT_TYPES.write().unwrap().insert(#known_interface_names, #known_interface_consts); )*
} }
}; };

View file

@ -146,8 +146,8 @@ impl WlMsg {
WlMsgType::#msg_type WlMsgType::#msg_type
} }
#[allow(unused, private_interfaces)] #[allow(unused)]
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg, _token: __private::WlParsedMessagePrivateToken) -> WaylandProtocolParsingOutcome<#struct_name> { fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload(); let payload = msg.payload();
let mut pos = 0usize; let mut pos = 0usize;
#( #parser_code )* #( #parser_code )*

View file

@ -17,8 +17,8 @@ use tracing::{Instrument, Level, debug, error, info, span};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
//proto::wl_init_parsers(); proto::wl_init_parsers();
//proto::wl_init_known_types(); proto::wl_init_known_types();
let mut conf_file = "config.toml"; let mut conf_file = "config.toml";

View file

@ -1,6 +1,10 @@
//! Protocol definitions necessary for this MITM proxy //! Protocol definitions necessary for this MITM proxy
use std::{collections::HashMap, sync::LazyLock}; use std::{
collections::HashMap,
hash::{BuildHasherDefault, DefaultHasher},
sync::RwLock,
};
use byteorder::ByteOrder; use byteorder::ByteOrder;
@ -55,10 +59,9 @@ impl<T> WaylandProtocolParsingOutcome<T> {
/// Internal module used to seal the [WlParsedMessage] trait /// Internal module used to seal the [WlParsedMessage] trait
mod __private { mod __private {
pub(super) trait WlParsedMessagePrivate {} pub(super) trait WlParsedMessagePrivate {}
pub(super) struct WlParsedMessagePrivateToken;
} }
#[allow(private_bounds, private_interfaces)] #[allow(private_bounds)]
pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate { pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
fn opcode() -> u16 fn opcode() -> u16
where where
@ -85,13 +88,10 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
return WaylandProtocolParsingOutcome::IncorrectOpcode; return WaylandProtocolParsingOutcome::IncorrectOpcode;
} }
Self::try_from_msg_impl(msg, __private::WlParsedMessagePrivateToken) Self::try_from_msg_impl(msg)
} }
fn try_from_msg_impl( fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>
msg: &'a WlRawMsg,
_token: __private::WlParsedMessagePrivateToken,
) -> WaylandProtocolParsingOutcome<Self>
where where
Self: Sized + 'a; Self: Sized + 'a;
@ -106,10 +106,8 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
/// it does not overlap with any other implementation of this trait. /// it does not overlap with any other implementation of this trait.
/// ///
/// In addition, any implementor also asserts that the type implementing this trait /// In addition, any implementor also asserts that the type implementing this trait
/// does not contain any lifetime other than 'a, and that the implenetor struct is /// does not contain any lifetime other than 'a. This is required for the soundness of
/// _covariant_ with respect to lifetime 'a. /// the downcast_ref implementation.
///
/// This is required for the soundness of the downcast_ref implementation.
pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {} pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {}
impl<'out, 'data: 'out> dyn AnyWlParsedMessage<'data> + 'data { impl<'out, 'data: 'out> dyn AnyWlParsedMessage<'data> + 'data {
@ -160,21 +158,13 @@ pub trait WlMsgParserFn: Send + Sync {
} }
/// A map from known interface names to their object types in Rust representation /// A map from known interface names to their object types in Rust representation
static WL_KNOWN_OBJECT_TYPES: LazyLock<HashMap<&'static str, WlObjectType>> = LazyLock::new(|| { static WL_KNOWN_OBJECT_TYPES: RwLock<
let mut ret = HashMap::new(); HashMap<&'static str, WlObjectType, BuildHasherDefault<DefaultHasher>>,
wl_init_known_types(&mut ret); > = RwLock::new(HashMap::with_hasher(BuildHasherDefault::new()));
ret /// Parsers for all known events
}); static WL_EVENT_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
/// Parsers for all known events / requests /// Parsers for all known requests
static WL_EVENT_REQUEST_PARSERS: LazyLock<( static WL_REQUEST_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
Vec<&'static dyn WlMsgParserFn>,
Vec<&'static dyn WlMsgParserFn>,
)> = LazyLock::new(|| {
let mut event_parsers = vec![];
let mut request_parsers = vec![];
wl_init_parsers(&mut event_parsers, &mut request_parsers);
(event_parsers, request_parsers)
});
/// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or /// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage] /// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
@ -185,7 +175,7 @@ pub fn decode_event<'obj, 'msg>(
objects: &'obj WlObjects, objects: &'obj WlObjects,
msg: &'msg WlRawMsg, msg: &'msg WlRawMsg,
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> { ) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
for p in WL_EVENT_REQUEST_PARSERS.0.iter() { for p in WL_EVENT_PARSERS.read().unwrap().iter() {
if let WaylandProtocolParsingOutcome::Ok(e) = if let WaylandProtocolParsingOutcome::Ok(e) =
bubble_malformed!(p.try_from_msg(objects, msg)) bubble_malformed!(p.try_from_msg(objects, msg))
{ {
@ -205,7 +195,7 @@ pub fn decode_request<'obj, 'msg>(
objects: &'obj WlObjects, objects: &'obj WlObjects,
msg: &'msg WlRawMsg, msg: &'msg WlRawMsg,
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> { ) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
for p in WL_EVENT_REQUEST_PARSERS.1.iter() { for p in WL_REQUEST_PARSERS.read().unwrap().iter() {
if let WaylandProtocolParsingOutcome::Ok(e) = if let WaylandProtocolParsingOutcome::Ok(e) =
bubble_malformed!(p.try_from_msg(objects, msg)) bubble_malformed!(p.try_from_msg(objects, msg))
{ {
@ -218,7 +208,10 @@ pub fn decode_request<'obj, 'msg>(
/// Look up a known object type from its name to its Rust [WlObjectType] representation /// Look up a known object type from its name to its Rust [WlObjectType] representation
pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> { pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> {
WL_KNOWN_OBJECT_TYPES.get(name).copied() WL_KNOWN_OBJECT_TYPES
.read()
.ok()
.and_then(|t| t.get(name).copied())
} }
/// The default object ID of wl_display /// The default object ID of wl_display