Compare commits
No commits in common. "6698faef350b9f38ee1cad2c5a2f04c2eba28b16" and "5b9ce683ecfc45aa07404773acc10e01aa63f2a0" have entirely different histories.
6698faef35
...
5b9ce683ec
6 changed files with 148 additions and 477 deletions
|
@ -1,22 +1,19 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use proc_macro2::Span;
|
||||
use quick_xml::events::Event;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{Ident, LitStr, parse_macro_input};
|
||||
use types::{WlArgType, WlInterface, WlMsg, WlMsgType};
|
||||
use types::WlArgType;
|
||||
|
||||
mod types;
|
||||
|
||||
#[proc_macro]
|
||||
pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||
let input: LitStr = parse_macro_input!(item);
|
||||
let p = PathBuf::from(input.value());
|
||||
let file_name = p.file_stem().expect("No file name provided");
|
||||
let xml_str = std::fs::read_to_string(&p).expect("Unable to read from file");
|
||||
let xml_str = std::fs::read_to_string(input.value()).expect("Unable to read from file");
|
||||
let mut reader = quick_xml::Reader::from_str(&xml_str);
|
||||
reader.config_mut().trim_text(true);
|
||||
|
||||
let mut interfaces: Vec<WlInterface> = vec![];
|
||||
let mut ret = proc_macro2::TokenStream::new();
|
||||
|
||||
loop {
|
||||
match reader.read_event().expect("Unable to parse XML file") {
|
||||
|
@ -28,7 +25,11 @@ pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStre
|
|||
match name {
|
||||
"interface" => {
|
||||
// An <interface> section
|
||||
interfaces.push(handle_interface(&mut reader, e));
|
||||
let str = handle_interface(&mut reader, e);
|
||||
ret = quote! {
|
||||
#ret
|
||||
#str
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
@ -37,57 +38,13 @@ pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStre
|
|||
}
|
||||
}
|
||||
|
||||
let mut code: Vec<proc_macro2::TokenStream> = vec![];
|
||||
let mut event_parsers: Vec<Ident> = vec![];
|
||||
let mut request_parsers: Vec<Ident> = vec![];
|
||||
let (mut known_interface_names, mut known_interface_consts): (Vec<String>, Vec<Ident>) =
|
||||
(vec![], vec![]);
|
||||
|
||||
for i in interfaces.iter() {
|
||||
known_interface_names.push(i.name_snake.clone());
|
||||
known_interface_consts.push(format_ident!("{}", i.type_const_name()));
|
||||
|
||||
code.push(i.generate());
|
||||
|
||||
for m in i.msgs.iter() {
|
||||
let parser_name = format_ident!("{}", m.parser_fn_name());
|
||||
|
||||
match m.msg_type {
|
||||
WlMsgType::Event => {
|
||||
event_parsers.push(parser_name);
|
||||
}
|
||||
WlMsgType::Request => {
|
||||
request_parsers.push(parser_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A function to add all event/request parsers to WL_EVENT_PARSERS and WL_REQUEST_PARSERS
|
||||
let add_parsers_fn = format_ident!("wl_init_parsers_{}", file_name.to_str().unwrap());
|
||||
|
||||
// A function to add all known interfaces to the WL_KNOWN_OBJECT_TYPES map from name -> Rust type
|
||||
let add_object_types_fn = format_ident!("wl_init_known_types_{}", file_name.to_str().unwrap());
|
||||
|
||||
quote! {
|
||||
#( #code )*
|
||||
|
||||
fn #add_parsers_fn() {
|
||||
#( WL_EVENT_PARSERS.write().unwrap().push(&#event_parsers); )*
|
||||
#( WL_REQUEST_PARSERS.write().unwrap().push(&#request_parsers); )*
|
||||
}
|
||||
|
||||
fn #add_object_types_fn() {
|
||||
#( WL_KNOWN_OBJECT_TYPES.write().unwrap().insert(#known_interface_names, #known_interface_consts); )*
|
||||
}
|
||||
}
|
||||
.into()
|
||||
ret.into()
|
||||
}
|
||||
|
||||
fn handle_interface(
|
||||
reader: &mut quick_xml::Reader<&[u8]>,
|
||||
start: quick_xml::events::BytesStart<'_>,
|
||||
) -> WlInterface {
|
||||
) -> proc_macro2::TokenStream {
|
||||
let name_attr = start
|
||||
.attributes()
|
||||
.map(|a| a.expect("attr parsing error"))
|
||||
|
@ -98,8 +55,30 @@ fn handle_interface(
|
|||
.expect("No name attr found for interface");
|
||||
|
||||
let interface_name_snake = std::str::from_utf8(&name_attr.value).expect("utf8 encoding error");
|
||||
let interface_name_camel = to_camel_case(interface_name_snake);
|
||||
|
||||
let mut msgs: Vec<WlMsg> = vec![];
|
||||
// Generate the implementation of the Wayland object type ID, consisting of a private struct
|
||||
// to act as a trait object, a public const that wraps the struct in `WlObjectType`, and a impl
|
||||
// of `WlObjectTypeId`.
|
||||
// Example:
|
||||
// struct WlDisplayTypeId;
|
||||
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
|
||||
// impl WlObjectTypeId for WlDisplayTypeId { ... }
|
||||
let interface_type_id_name = format_ident!("{}TypeId", interface_name_camel);
|
||||
let interface_name_literal = LitStr::new(interface_name_snake, Span::call_site());
|
||||
let interface_name_snake_upper =
|
||||
Ident::new(&interface_name_snake.to_uppercase(), Span::call_site());
|
||||
let mut ret: proc_macro2::TokenStream = quote! {
|
||||
struct #interface_type_id_name;
|
||||
|
||||
pub const #interface_name_snake_upper: WlObjectType = WlObjectType::new(&#interface_type_id_name);
|
||||
|
||||
impl WlObjectTypeId for #interface_type_id_name {
|
||||
fn interface(&self) -> &'static str {
|
||||
#interface_name_literal
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Opcodes are tracked separately, in order, for each type (event or request)
|
||||
let mut event_opcode = 0;
|
||||
|
@ -111,46 +90,53 @@ fn handle_interface(
|
|||
Event::Start(e) => {
|
||||
let start_tag =
|
||||
str::from_utf8(e.local_name().into_inner()).expect("Unable to parse start tag");
|
||||
if start_tag == "event" {
|
||||
let append = if start_tag == "event" {
|
||||
// An event! Increment our opcode tracker for it!
|
||||
event_opcode += 1;
|
||||
msgs.push(handle_request_or_event(
|
||||
handle_request_or_event(
|
||||
reader,
|
||||
&interface_name_camel,
|
||||
&interface_name_snake_upper,
|
||||
event_opcode - 1,
|
||||
WlMsgType::Event,
|
||||
interface_name_snake,
|
||||
e,
|
||||
));
|
||||
)
|
||||
} else if start_tag == "request" {
|
||||
// A request! Increment our opcode tracker for it!
|
||||
request_opcode += 1;
|
||||
msgs.push(handle_request_or_event(
|
||||
handle_request_or_event(
|
||||
reader,
|
||||
&interface_name_camel,
|
||||
&interface_name_snake_upper,
|
||||
request_opcode - 1,
|
||||
WlMsgType::Request,
|
||||
interface_name_snake,
|
||||
e,
|
||||
));
|
||||
)
|
||||
} else {
|
||||
proc_macro2::TokenStream::new()
|
||||
};
|
||||
|
||||
ret = quote! {
|
||||
#ret
|
||||
#append
|
||||
}
|
||||
}
|
||||
Event::End(e) if e.local_name() == start.local_name() => break,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
WlInterface {
|
||||
name_snake: interface_name_snake.to_string(),
|
||||
msgs,
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
fn handle_request_or_event(
|
||||
reader: &mut quick_xml::Reader<&[u8]>,
|
||||
interface_name_camel: &str,
|
||||
interface_name_snake_upper: &Ident,
|
||||
opcode: u16,
|
||||
msg_type: WlMsgType,
|
||||
interface_name_snake: &str,
|
||||
start: quick_xml::events::BytesStart<'_>,
|
||||
) -> WlMsg {
|
||||
) -> proc_macro2::TokenStream {
|
||||
let start_tag =
|
||||
str::from_utf8(start.local_name().into_inner()).expect("Unable to parse start tag");
|
||||
let start_tag_camel = to_camel_case(start_tag);
|
||||
let name_attr = start
|
||||
.attributes()
|
||||
.map(|a| a.expect("attr parsing error"))
|
||||
|
@ -159,6 +145,8 @@ fn handle_request_or_event(
|
|||
== "name"
|
||||
})
|
||||
.expect("No name attr found for request/event");
|
||||
let name_camel = to_camel_case(str::from_utf8(&name_attr.value).expect("utf8 encoding error"));
|
||||
|
||||
// Load arguments and their types from XML
|
||||
let mut args: Vec<(String, WlArgType)> = Vec::new();
|
||||
|
||||
|
@ -199,18 +187,66 @@ fn handle_request_or_event(
|
|||
}
|
||||
}
|
||||
|
||||
WlMsg {
|
||||
interface_name_snake: interface_name_snake.to_string(),
|
||||
name_snake: str::from_utf8(&name_attr.value)
|
||||
.expect("utf8 encoding error")
|
||||
.to_string(),
|
||||
msg_type,
|
||||
opcode,
|
||||
args,
|
||||
let (field_names, field_types): (Vec<_>, Vec<_>) = args
|
||||
.iter()
|
||||
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
||||
.unzip();
|
||||
|
||||
let struct_name = format_ident!("{interface_name_camel}{name_camel}{start_tag_camel}");
|
||||
|
||||
// Struct definition, such as:
|
||||
//
|
||||
// pub struct WlDisplayGetRegistryRequest<'a> {
|
||||
// _phantom: PhantomData<&'a ()>,
|
||||
// registry: u32
|
||||
// }
|
||||
//
|
||||
// The 'a lifetime is added across the board for consistency.
|
||||
let struct_def = quote! {
|
||||
pub struct #struct_name<'a> {
|
||||
_phantom: std::marker::PhantomData<&'a ()>,
|
||||
#( pub #field_names: #field_types, )*
|
||||
}
|
||||
};
|
||||
|
||||
// Generate code to include in the parser for every field
|
||||
let parser_code: Vec<_> = args
|
||||
.into_iter()
|
||||
.map(|(arg_name, arg_type)| {
|
||||
let arg_name_ident = format_ident!("{arg_name}");
|
||||
arg_type.generate_parser_code(arg_name_ident)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let struct_impl = quote! {
|
||||
impl<'a> WlParsedMessage<'a> for #struct_name<'a> {
|
||||
fn opcode() -> u16 {
|
||||
#opcode
|
||||
}
|
||||
|
||||
fn object_type() -> WlObjectType {
|
||||
#interface_name_snake_upper
|
||||
}
|
||||
|
||||
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
|
||||
let payload = msg.payload();
|
||||
let mut pos = 0usize;
|
||||
#( #parser_code )*
|
||||
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
||||
_phantom: std::marker::PhantomData,
|
||||
#( #field_names, )*
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
quote! {
|
||||
#struct_def
|
||||
#struct_impl
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn to_camel_case(s: &str) -> String {
|
||||
fn to_camel_case(s: &str) -> String {
|
||||
s.split("_")
|
||||
.map(|item| {
|
||||
item.char_indices()
|
||||
|
|
|
@ -1,177 +1,5 @@
|
|||
use proc_macro2::Span;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{Ident, LitStr};
|
||||
|
||||
pub(crate) struct WlInterface {
|
||||
pub name_snake: String,
|
||||
pub msgs: Vec<WlMsg>,
|
||||
}
|
||||
|
||||
impl WlInterface {
|
||||
/// Name of the interface type's const representation, e.g. WL_WAYLAND
|
||||
/// This can be used as a discriminant for interface types in Rust
|
||||
pub fn type_const_name(&self) -> String {
|
||||
self.name_snake.to_uppercase()
|
||||
}
|
||||
|
||||
pub fn generate(&self) -> proc_macro2::TokenStream {
|
||||
// Generate struct and parser impls for all messages belonging to this interface
|
||||
let msg_impl = self.msgs.iter().map(|msg| msg.generate_struct_and_impl());
|
||||
|
||||
// Also generate a struct representing the type of this interface
|
||||
// This is used to keep track of all objects in [objects]
|
||||
// Example:
|
||||
// struct WlDisplayTypeId;
|
||||
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
|
||||
// impl WlObjectTypeId for WlDisplayTypeId { ... }
|
||||
let interface_type_id_name =
|
||||
format_ident!("{}TypeId", crate::to_camel_case(&self.name_snake));
|
||||
let interface_name_literal = LitStr::new(&self.name_snake, Span::call_site());
|
||||
let type_const_name = format_ident!("{}", self.type_const_name());
|
||||
|
||||
quote! {
|
||||
struct #interface_type_id_name;
|
||||
|
||||
pub const #type_const_name: WlObjectType = WlObjectType::new(&#interface_type_id_name);
|
||||
|
||||
impl WlObjectTypeId for #interface_type_id_name {
|
||||
fn interface(&self) -> &'static str {
|
||||
#interface_name_literal
|
||||
}
|
||||
}
|
||||
|
||||
#( #msg_impl )*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum WlMsgType {
|
||||
Request,
|
||||
Event,
|
||||
}
|
||||
|
||||
impl WlMsgType {
|
||||
fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
WlMsgType::Request => "Request",
|
||||
WlMsgType::Event => "Event",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct WlMsg {
|
||||
pub interface_name_snake: String,
|
||||
pub name_snake: String,
|
||||
pub msg_type: WlMsgType,
|
||||
pub opcode: u16,
|
||||
pub args: Vec<(String, WlArgType)>,
|
||||
}
|
||||
|
||||
impl WlMsg {
|
||||
/// Get the name of the structure generated for this message
|
||||
/// e.g. WlRegistryBindRequest
|
||||
pub fn struct_name(&self) -> String {
|
||||
format!(
|
||||
"{}{}{}",
|
||||
crate::to_camel_case(&self.interface_name_snake),
|
||||
crate::to_camel_case(&self.name_snake),
|
||||
self.msg_type.as_str()
|
||||
)
|
||||
}
|
||||
|
||||
pub fn parser_fn_name(&self) -> String {
|
||||
format!("{}ParserFn", self.struct_name())
|
||||
}
|
||||
|
||||
/// Generates a struct corresponding to the message type and a impl for [WlParsedMessage]
|
||||
/// that includes a parser
|
||||
pub fn generate_struct_and_impl(&self) -> proc_macro2::TokenStream {
|
||||
let opcode = self.opcode;
|
||||
let interface_name_snake_upper =
|
||||
format_ident!("{}", self.interface_name_snake.to_uppercase());
|
||||
let msg_type = format_ident!("{}", self.msg_type.as_str());
|
||||
|
||||
let struct_name = format_ident!("{}", self.struct_name());
|
||||
|
||||
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
|
||||
|
||||
// Build all field names and their corresponding Rust type identifiers
|
||||
let (field_names, field_types): (Vec<_>, Vec<_>) = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
||||
.unzip();
|
||||
|
||||
// Generate code to include in the parser for every field
|
||||
let parser_code: Vec<_> = self
|
||||
.args
|
||||
.iter()
|
||||
.map(|(arg_name, arg_type)| {
|
||||
let arg_name_ident = format_ident!("{arg_name}");
|
||||
arg_type.generate_parser_code(arg_name_ident)
|
||||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
pub struct #struct_name<'a> {
|
||||
_phantom: std::marker::PhantomData<&'a ()>,
|
||||
#( pub #field_names: #field_types, )*
|
||||
}
|
||||
|
||||
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
|
||||
|
||||
impl<'a> WlParsedMessage<'a> for #struct_name<'a> {
|
||||
fn opcode() -> u16 {
|
||||
#opcode
|
||||
}
|
||||
|
||||
fn self_opcode(&self) -> u16 {
|
||||
#opcode
|
||||
}
|
||||
|
||||
fn object_type() -> WlObjectType {
|
||||
#interface_name_snake_upper
|
||||
}
|
||||
|
||||
fn self_object_type(&self) -> WlObjectType {
|
||||
#interface_name_snake_upper
|
||||
}
|
||||
|
||||
fn msg_type() -> WlMsgType {
|
||||
WlMsgType::#msg_type
|
||||
}
|
||||
|
||||
fn self_msg_type(&self) -> WlMsgType {
|
||||
WlMsgType::#msg_type
|
||||
}
|
||||
|
||||
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
|
||||
let payload = msg.payload();
|
||||
let mut pos = 0usize;
|
||||
#( #parser_code )*
|
||||
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
||||
_phantom: std::marker::PhantomData,
|
||||
#( #field_names, )*
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
|
||||
|
||||
pub struct #parser_fn_name;
|
||||
|
||||
impl WlMsgParserFn for #parser_fn_name {
|
||||
fn try_from_msg<'obj, 'msg>(
|
||||
&self,
|
||||
objects: &'obj WlObjects,
|
||||
msg: &'msg WlRawMsg,
|
||||
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||
#struct_name::try_from_msg(objects, msg).map(|r| Box::new(r) as Box<_>)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
use quote::quote;
|
||||
use syn::Ident;
|
||||
|
||||
pub(crate) enum WlArgType {
|
||||
Int,
|
||||
|
|
|
@ -17,8 +17,6 @@ use tracing::{Instrument, Level, debug, error, info, span};
|
|||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt::init();
|
||||
proto::wl_init_parsers();
|
||||
proto::wl_init_known_types();
|
||||
|
||||
let mut conf_file = "config.toml";
|
||||
|
||||
|
|
|
@ -64,10 +64,6 @@ impl WlObjects {
|
|||
self.objects.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub fn remove_object(&mut self, id: u32) {
|
||||
self.objects.remove(&id);
|
||||
}
|
||||
|
||||
pub fn record_global(&mut self, name: u32, interface: &str) {
|
||||
self.global_names.insert(name, interface.to_string());
|
||||
}
|
||||
|
|
200
src/proto.rs
200
src/proto.rs
|
@ -1,10 +1,6 @@
|
|||
//! Protocol definitions necessary for this MITM proxy
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
hash::{BuildHasherDefault, DefaultHasher},
|
||||
sync::RwLock,
|
||||
};
|
||||
// ---------- wl_display ---------
|
||||
|
||||
use byteorder::ByteOrder;
|
||||
use protogen::wayland_proto_gen;
|
||||
|
@ -14,33 +10,26 @@ use crate::{
|
|||
objects::{WlObjectType, WlObjectTypeId, WlObjects},
|
||||
};
|
||||
|
||||
macro_rules! bubble_malformed {
|
||||
($e:expr) => {{
|
||||
let e = $e;
|
||||
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = e {
|
||||
return WaylandProtocolParsingOutcome::MalformedMessage;
|
||||
macro_rules! reject_malformed {
|
||||
($e:expr) => {
|
||||
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = $e {
|
||||
return false;
|
||||
} else if let crate::proto::WaylandProtocolParsingOutcome::Ok(e) = $e {
|
||||
Some(e)
|
||||
} else {
|
||||
e
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! match_decoded {
|
||||
(match $decoded:ident {$($t:ty => $act:block$(,)?)+}) => {
|
||||
if let crate::proto::WaylandProtocolParsingOutcome::Ok($decoded) = $decoded {
|
||||
$(
|
||||
if let Some($decoded) = $decoded.downcast_ref::<$t>() {
|
||||
$act
|
||||
}
|
||||
)+
|
||||
None
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub enum WlMsgType {
|
||||
Request,
|
||||
Event,
|
||||
macro_rules! decode_and_match_msg {
|
||||
($objects:expr, match $msg:ident {$($t:ty => $act:block$(,)?)+}) => {
|
||||
$(
|
||||
if let Some($msg) = reject_malformed!(<$t as crate::proto::WlParsedMessage>::try_from_msg(&$objects, $msg)) {
|
||||
$act
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
pub enum WaylandProtocolParsingOutcome<T> {
|
||||
|
@ -48,50 +37,15 @@ pub enum WaylandProtocolParsingOutcome<T> {
|
|||
MalformedMessage,
|
||||
IncorrectObject,
|
||||
IncorrectOpcode,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl<T> WaylandProtocolParsingOutcome<T> {
|
||||
pub fn map<U>(self, f: impl Fn(T) -> U) -> WaylandProtocolParsingOutcome<U> {
|
||||
match self {
|
||||
WaylandProtocolParsingOutcome::Ok(t) => WaylandProtocolParsingOutcome::Ok(f(t)),
|
||||
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||
WaylandProtocolParsingOutcome::MalformedMessage
|
||||
}
|
||||
WaylandProtocolParsingOutcome::IncorrectObject => {
|
||||
WaylandProtocolParsingOutcome::IncorrectObject
|
||||
}
|
||||
WaylandProtocolParsingOutcome::IncorrectOpcode => {
|
||||
WaylandProtocolParsingOutcome::IncorrectOpcode
|
||||
}
|
||||
WaylandProtocolParsingOutcome::Unknown => WaylandProtocolParsingOutcome::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal module used to seal the [WlParsedMessage] trait
|
||||
mod __private {
|
||||
pub(super) trait WlParsedMessagePrivate {}
|
||||
}
|
||||
|
||||
#[allow(private_bounds)]
|
||||
pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
|
||||
fn opcode() -> u16
|
||||
where
|
||||
Self: Sized;
|
||||
fn object_type() -> WlObjectType
|
||||
where
|
||||
Self: Sized;
|
||||
fn msg_type() -> WlMsgType
|
||||
where
|
||||
Self: Sized;
|
||||
pub trait WlParsedMessage<'a>: Sized {
|
||||
fn opcode() -> u16;
|
||||
fn object_type() -> WlObjectType;
|
||||
fn try_from_msg<'obj>(
|
||||
objects: &'obj WlObjects,
|
||||
msg: &'a WlRawMsg,
|
||||
) -> WaylandProtocolParsingOutcome<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
) -> WaylandProtocolParsingOutcome<Self> {
|
||||
// Verify object type and opcode
|
||||
if objects.lookup_object(msg.obj_id) != Some(Self::object_type()) {
|
||||
return WaylandProtocolParsingOutcome::IncorrectObject;
|
||||
|
@ -104,122 +58,10 @@ pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
|
|||
Self::try_from_msg_impl(msg)
|
||||
}
|
||||
|
||||
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
// dyn-available methods
|
||||
fn self_opcode(&self) -> u16;
|
||||
fn self_object_type(&self) -> WlObjectType;
|
||||
fn self_msg_type(&self) -> WlMsgType;
|
||||
}
|
||||
|
||||
/// A version of [WlParsedMessage] that supports downcasting. By implementing this
|
||||
/// trait, you promise that the (object_type, msg_type, opcode) triple is unique, i.e.
|
||||
/// it does not overlap with any other implementation of this trait.
|
||||
pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {}
|
||||
|
||||
impl<'a> dyn AnyWlParsedMessage<'a> + 'a {
|
||||
pub fn downcast_ref<T: AnyWlParsedMessage<'a>>(&self) -> Option<&T> {
|
||||
if self.self_opcode() != T::opcode() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.self_object_type() != T::object_type() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if self.self_msg_type() != T::msg_type() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// SAFETY: We have verified the opcode, type, and msg type all match up
|
||||
// As per safety guarantee of [AnyWlParsedMessage], we've now narrowed
|
||||
// [self] down to one concrete type.
|
||||
Some(unsafe { &*(self as *const dyn AnyWlParsedMessage as *const T) })
|
||||
}
|
||||
}
|
||||
|
||||
/// A dyn-compatible wrapper over a specific [WlParsedMessage] type's static methods.
|
||||
/// The only exposed method, [try_from_msg], attempts to parse the message
|
||||
/// to the given type. This is used as members of [WL_EVENT_PARSERS]
|
||||
/// and [WL_REQUEST_PARSERS] to facilitate automatic parsing of all
|
||||
/// known message types.
|
||||
pub trait WlMsgParserFn: Send + Sync {
|
||||
fn try_from_msg<'obj, 'msg>(
|
||||
&self,
|
||||
objects: &'obj WlObjects,
|
||||
msg: &'msg WlRawMsg,
|
||||
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>>;
|
||||
}
|
||||
|
||||
/// A map from known interface names to their object types in Rust representation
|
||||
static WL_KNOWN_OBJECT_TYPES: RwLock<
|
||||
HashMap<&'static str, WlObjectType, BuildHasherDefault<DefaultHasher>>,
|
||||
> = RwLock::new(HashMap::with_hasher(BuildHasherDefault::new()));
|
||||
/// Parsers for all known events
|
||||
static WL_EVENT_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
|
||||
/// Parsers for all known requests
|
||||
static WL_REQUEST_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
|
||||
|
||||
/// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or
|
||||
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
|
||||
/// for malformed messages.
|
||||
///
|
||||
/// To downcast the parse result to a concrete message type, use [<dyn AnyWlParsedMessage>::downcast_ref]
|
||||
pub fn decode_event<'obj, 'msg>(
|
||||
objects: &'obj WlObjects,
|
||||
msg: &'msg WlRawMsg,
|
||||
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||
for p in WL_EVENT_PARSERS.read().unwrap().iter() {
|
||||
if let WaylandProtocolParsingOutcome::Ok(e) =
|
||||
bubble_malformed!(p.try_from_msg(objects, msg))
|
||||
{
|
||||
return WaylandProtocolParsingOutcome::Ok(e);
|
||||
}
|
||||
}
|
||||
|
||||
WaylandProtocolParsingOutcome::Unknown
|
||||
}
|
||||
|
||||
/// Decode a Wayland request from a [WlRawMsg], returning the type-erased result, or
|
||||
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
|
||||
/// for malformed messages.
|
||||
///
|
||||
/// To downcast the parse result to a concrete message type, use [<dyn AnyWlParsedMessage>::downcast_ref]
|
||||
pub fn decode_request<'obj, 'msg>(
|
||||
objects: &'obj WlObjects,
|
||||
msg: &'msg WlRawMsg,
|
||||
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||
for p in WL_REQUEST_PARSERS.read().unwrap().iter() {
|
||||
if let WaylandProtocolParsingOutcome::Ok(e) =
|
||||
bubble_malformed!(p.try_from_msg(objects, msg))
|
||||
{
|
||||
return WaylandProtocolParsingOutcome::Ok(e);
|
||||
}
|
||||
}
|
||||
|
||||
WaylandProtocolParsingOutcome::Unknown
|
||||
}
|
||||
|
||||
/// Look up a known object type from its name to its Rust [WlObjectType] representation
|
||||
pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> {
|
||||
WL_KNOWN_OBJECT_TYPES
|
||||
.read()
|
||||
.ok()
|
||||
.and_then(|t| t.get(name).copied())
|
||||
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>;
|
||||
}
|
||||
|
||||
/// The default object ID of wl_display
|
||||
pub const WL_DISPLAY_OBJECT_ID: u32 = 1;
|
||||
|
||||
wayland_proto_gen!("proto/wayland.xml");
|
||||
|
||||
/// Install all available Wayland protocol parsers for use by [decode_event] and [decode_request].
|
||||
pub fn wl_init_parsers() {
|
||||
wl_init_parsers_wayland();
|
||||
}
|
||||
|
||||
pub fn wl_init_known_types() {
|
||||
wl_init_known_types_wayland();
|
||||
}
|
||||
|
|
47
src/state.rs
47
src/state.rs
|
@ -7,8 +7,7 @@ use crate::{
|
|||
config::Config,
|
||||
objects::WlObjects,
|
||||
proto::{
|
||||
WL_REGISTRY, WlDisplayDeleteIdEvent, WlDisplayGetRegistryRequest, WlRegistryBindRequest,
|
||||
WlRegistryGlobalEvent,
|
||||
WL_REGISTRY, WlDisplayGetRegistryRequest, WlRegistryBindRequest, WlRegistryGlobalEvent,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -26,18 +25,9 @@ impl WlMitmState {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||
let msg = crate::proto::decode_request(&self.objects, raw_msg);
|
||||
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
|
||||
debug!(
|
||||
obj_id = raw_msg.obj_id,
|
||||
opcode = raw_msg.opcode,
|
||||
"Malformed request"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
match_decoded! {
|
||||
pub fn on_c2s_request(&mut self, msg: &WlRawMsg) -> bool {
|
||||
decode_and_match_msg!(
|
||||
self.objects,
|
||||
match msg {
|
||||
WlDisplayGetRegistryRequest => {
|
||||
self.objects.record_object(WL_REGISTRY, msg.registry);
|
||||
|
@ -46,36 +36,22 @@ impl WlMitmState {
|
|||
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
info!(
|
||||
interface = interface,
|
||||
obj_id = msg.id,
|
||||
"Client binding interface"
|
||||
);
|
||||
|
||||
if let Some(t) = crate::proto::lookup_known_object_type(interface) {
|
||||
//self.objects.record_object(t, msg.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||
let msg = crate::proto::decode_event(&self.objects, raw_msg);
|
||||
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
|
||||
debug!(
|
||||
obj_id = raw_msg.obj_id,
|
||||
opcode = raw_msg.opcode,
|
||||
"Malformed event"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
match_decoded! {
|
||||
pub fn on_s2c_event(&mut self, msg: &WlRawMsg) -> bool {
|
||||
decode_and_match_msg!(
|
||||
self.objects,
|
||||
match msg {
|
||||
WlRegistryGlobalEvent => {
|
||||
debug!(
|
||||
|
@ -95,13 +71,8 @@ impl WlMitmState {
|
|||
return false;
|
||||
}
|
||||
}
|
||||
WlDisplayDeleteIdEvent => {
|
||||
// When an object is acknowledged to be deleted, remove it from our
|
||||
// internal cache of all registered objects
|
||||
//self.objects.remove_object(msg.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
true
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue