Compare commits
6 commits
5b9ce683ec
...
6698faef35
Author | SHA1 | Date | |
---|---|---|---|
6698faef35 | |||
22d17147fb | |||
5a466b5af8 | |||
5cd2e0ba2d | |||
cb85a248a4 | |||
11dc8b8119 |
6 changed files with 477 additions and 148 deletions
|
@ -1,19 +1,22 @@
|
||||||
use proc_macro2::Span;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use quick_xml::events::Event;
|
use quick_xml::events::Event;
|
||||||
use quote::{format_ident, quote};
|
use quote::{format_ident, quote};
|
||||||
use syn::{Ident, LitStr, parse_macro_input};
|
use syn::{Ident, LitStr, parse_macro_input};
|
||||||
use types::WlArgType;
|
use types::{WlArgType, WlInterface, WlMsg, WlMsgType};
|
||||||
|
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
#[proc_macro]
|
#[proc_macro]
|
||||||
pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||||
let input: LitStr = parse_macro_input!(item);
|
let input: LitStr = parse_macro_input!(item);
|
||||||
let xml_str = std::fs::read_to_string(input.value()).expect("Unable to read from file");
|
let p = PathBuf::from(input.value());
|
||||||
|
let file_name = p.file_stem().expect("No file name provided");
|
||||||
|
let xml_str = std::fs::read_to_string(&p).expect("Unable to read from file");
|
||||||
let mut reader = quick_xml::Reader::from_str(&xml_str);
|
let mut reader = quick_xml::Reader::from_str(&xml_str);
|
||||||
reader.config_mut().trim_text(true);
|
reader.config_mut().trim_text(true);
|
||||||
|
|
||||||
let mut ret = proc_macro2::TokenStream::new();
|
let mut interfaces: Vec<WlInterface> = vec![];
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match reader.read_event().expect("Unable to parse XML file") {
|
match reader.read_event().expect("Unable to parse XML file") {
|
||||||
|
@ -25,11 +28,7 @@ pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStre
|
||||||
match name {
|
match name {
|
||||||
"interface" => {
|
"interface" => {
|
||||||
// An <interface> section
|
// An <interface> section
|
||||||
let str = handle_interface(&mut reader, e);
|
interfaces.push(handle_interface(&mut reader, e));
|
||||||
ret = quote! {
|
|
||||||
#ret
|
|
||||||
#str
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
@ -38,13 +37,57 @@ pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStre
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ret.into()
|
let mut code: Vec<proc_macro2::TokenStream> = vec![];
|
||||||
|
let mut event_parsers: Vec<Ident> = vec![];
|
||||||
|
let mut request_parsers: Vec<Ident> = vec![];
|
||||||
|
let (mut known_interface_names, mut known_interface_consts): (Vec<String>, Vec<Ident>) =
|
||||||
|
(vec![], vec![]);
|
||||||
|
|
||||||
|
for i in interfaces.iter() {
|
||||||
|
known_interface_names.push(i.name_snake.clone());
|
||||||
|
known_interface_consts.push(format_ident!("{}", i.type_const_name()));
|
||||||
|
|
||||||
|
code.push(i.generate());
|
||||||
|
|
||||||
|
for m in i.msgs.iter() {
|
||||||
|
let parser_name = format_ident!("{}", m.parser_fn_name());
|
||||||
|
|
||||||
|
match m.msg_type {
|
||||||
|
WlMsgType::Event => {
|
||||||
|
event_parsers.push(parser_name);
|
||||||
|
}
|
||||||
|
WlMsgType::Request => {
|
||||||
|
request_parsers.push(parser_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A function to add all event/request parsers to WL_EVENT_PARSERS and WL_REQUEST_PARSERS
|
||||||
|
let add_parsers_fn = format_ident!("wl_init_parsers_{}", file_name.to_str().unwrap());
|
||||||
|
|
||||||
|
// A function to add all known interfaces to the WL_KNOWN_OBJECT_TYPES map from name -> Rust type
|
||||||
|
let add_object_types_fn = format_ident!("wl_init_known_types_{}", file_name.to_str().unwrap());
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
#( #code )*
|
||||||
|
|
||||||
|
fn #add_parsers_fn() {
|
||||||
|
#( WL_EVENT_PARSERS.write().unwrap().push(&#event_parsers); )*
|
||||||
|
#( WL_REQUEST_PARSERS.write().unwrap().push(&#request_parsers); )*
|
||||||
|
}
|
||||||
|
|
||||||
|
fn #add_object_types_fn() {
|
||||||
|
#( WL_KNOWN_OBJECT_TYPES.write().unwrap().insert(#known_interface_names, #known_interface_consts); )*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_interface(
|
fn handle_interface(
|
||||||
reader: &mut quick_xml::Reader<&[u8]>,
|
reader: &mut quick_xml::Reader<&[u8]>,
|
||||||
start: quick_xml::events::BytesStart<'_>,
|
start: quick_xml::events::BytesStart<'_>,
|
||||||
) -> proc_macro2::TokenStream {
|
) -> WlInterface {
|
||||||
let name_attr = start
|
let name_attr = start
|
||||||
.attributes()
|
.attributes()
|
||||||
.map(|a| a.expect("attr parsing error"))
|
.map(|a| a.expect("attr parsing error"))
|
||||||
|
@ -55,30 +98,8 @@ fn handle_interface(
|
||||||
.expect("No name attr found for interface");
|
.expect("No name attr found for interface");
|
||||||
|
|
||||||
let interface_name_snake = std::str::from_utf8(&name_attr.value).expect("utf8 encoding error");
|
let interface_name_snake = std::str::from_utf8(&name_attr.value).expect("utf8 encoding error");
|
||||||
let interface_name_camel = to_camel_case(interface_name_snake);
|
|
||||||
|
|
||||||
// Generate the implementation of the Wayland object type ID, consisting of a private struct
|
let mut msgs: Vec<WlMsg> = vec![];
|
||||||
// to act as a trait object, a public const that wraps the struct in `WlObjectType`, and a impl
|
|
||||||
// of `WlObjectTypeId`.
|
|
||||||
// Example:
|
|
||||||
// struct WlDisplayTypeId;
|
|
||||||
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
|
|
||||||
// impl WlObjectTypeId for WlDisplayTypeId { ... }
|
|
||||||
let interface_type_id_name = format_ident!("{}TypeId", interface_name_camel);
|
|
||||||
let interface_name_literal = LitStr::new(interface_name_snake, Span::call_site());
|
|
||||||
let interface_name_snake_upper =
|
|
||||||
Ident::new(&interface_name_snake.to_uppercase(), Span::call_site());
|
|
||||||
let mut ret: proc_macro2::TokenStream = quote! {
|
|
||||||
struct #interface_type_id_name;
|
|
||||||
|
|
||||||
pub const #interface_name_snake_upper: WlObjectType = WlObjectType::new(&#interface_type_id_name);
|
|
||||||
|
|
||||||
impl WlObjectTypeId for #interface_type_id_name {
|
|
||||||
fn interface(&self) -> &'static str {
|
|
||||||
#interface_name_literal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Opcodes are tracked separately, in order, for each type (event or request)
|
// Opcodes are tracked separately, in order, for each type (event or request)
|
||||||
let mut event_opcode = 0;
|
let mut event_opcode = 0;
|
||||||
|
@ -90,53 +111,46 @@ fn handle_interface(
|
||||||
Event::Start(e) => {
|
Event::Start(e) => {
|
||||||
let start_tag =
|
let start_tag =
|
||||||
str::from_utf8(e.local_name().into_inner()).expect("Unable to parse start tag");
|
str::from_utf8(e.local_name().into_inner()).expect("Unable to parse start tag");
|
||||||
let append = if start_tag == "event" {
|
if start_tag == "event" {
|
||||||
// An event! Increment our opcode tracker for it!
|
// An event! Increment our opcode tracker for it!
|
||||||
event_opcode += 1;
|
event_opcode += 1;
|
||||||
handle_request_or_event(
|
msgs.push(handle_request_or_event(
|
||||||
reader,
|
reader,
|
||||||
&interface_name_camel,
|
|
||||||
&interface_name_snake_upper,
|
|
||||||
event_opcode - 1,
|
event_opcode - 1,
|
||||||
|
WlMsgType::Event,
|
||||||
|
interface_name_snake,
|
||||||
e,
|
e,
|
||||||
)
|
));
|
||||||
} else if start_tag == "request" {
|
} else if start_tag == "request" {
|
||||||
// A request! Increment our opcode tracker for it!
|
// A request! Increment our opcode tracker for it!
|
||||||
request_opcode += 1;
|
request_opcode += 1;
|
||||||
handle_request_or_event(
|
msgs.push(handle_request_or_event(
|
||||||
reader,
|
reader,
|
||||||
&interface_name_camel,
|
|
||||||
&interface_name_snake_upper,
|
|
||||||
request_opcode - 1,
|
request_opcode - 1,
|
||||||
|
WlMsgType::Request,
|
||||||
|
interface_name_snake,
|
||||||
e,
|
e,
|
||||||
)
|
));
|
||||||
} else {
|
|
||||||
proc_macro2::TokenStream::new()
|
|
||||||
};
|
};
|
||||||
|
|
||||||
ret = quote! {
|
|
||||||
#ret
|
|
||||||
#append
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Event::End(e) if e.local_name() == start.local_name() => break,
|
Event::End(e) if e.local_name() == start.local_name() => break,
|
||||||
_ => continue,
|
_ => continue,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ret
|
WlInterface {
|
||||||
|
name_snake: interface_name_snake.to_string(),
|
||||||
|
msgs,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_request_or_event(
|
fn handle_request_or_event(
|
||||||
reader: &mut quick_xml::Reader<&[u8]>,
|
reader: &mut quick_xml::Reader<&[u8]>,
|
||||||
interface_name_camel: &str,
|
|
||||||
interface_name_snake_upper: &Ident,
|
|
||||||
opcode: u16,
|
opcode: u16,
|
||||||
|
msg_type: WlMsgType,
|
||||||
|
interface_name_snake: &str,
|
||||||
start: quick_xml::events::BytesStart<'_>,
|
start: quick_xml::events::BytesStart<'_>,
|
||||||
) -> proc_macro2::TokenStream {
|
) -> WlMsg {
|
||||||
let start_tag =
|
|
||||||
str::from_utf8(start.local_name().into_inner()).expect("Unable to parse start tag");
|
|
||||||
let start_tag_camel = to_camel_case(start_tag);
|
|
||||||
let name_attr = start
|
let name_attr = start
|
||||||
.attributes()
|
.attributes()
|
||||||
.map(|a| a.expect("attr parsing error"))
|
.map(|a| a.expect("attr parsing error"))
|
||||||
|
@ -145,8 +159,6 @@ fn handle_request_or_event(
|
||||||
== "name"
|
== "name"
|
||||||
})
|
})
|
||||||
.expect("No name attr found for request/event");
|
.expect("No name attr found for request/event");
|
||||||
let name_camel = to_camel_case(str::from_utf8(&name_attr.value).expect("utf8 encoding error"));
|
|
||||||
|
|
||||||
// Load arguments and their types from XML
|
// Load arguments and their types from XML
|
||||||
let mut args: Vec<(String, WlArgType)> = Vec::new();
|
let mut args: Vec<(String, WlArgType)> = Vec::new();
|
||||||
|
|
||||||
|
@ -187,66 +199,18 @@ fn handle_request_or_event(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (field_names, field_types): (Vec<_>, Vec<_>) = args
|
WlMsg {
|
||||||
.iter()
|
interface_name_snake: interface_name_snake.to_string(),
|
||||||
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
name_snake: str::from_utf8(&name_attr.value)
|
||||||
.unzip();
|
.expect("utf8 encoding error")
|
||||||
|
.to_string(),
|
||||||
let struct_name = format_ident!("{interface_name_camel}{name_camel}{start_tag_camel}");
|
msg_type,
|
||||||
|
opcode,
|
||||||
// Struct definition, such as:
|
args,
|
||||||
//
|
|
||||||
// pub struct WlDisplayGetRegistryRequest<'a> {
|
|
||||||
// _phantom: PhantomData<&'a ()>,
|
|
||||||
// registry: u32
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// The 'a lifetime is added across the board for consistency.
|
|
||||||
let struct_def = quote! {
|
|
||||||
pub struct #struct_name<'a> {
|
|
||||||
_phantom: std::marker::PhantomData<&'a ()>,
|
|
||||||
#( pub #field_names: #field_types, )*
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Generate code to include in the parser for every field
|
|
||||||
let parser_code: Vec<_> = args
|
|
||||||
.into_iter()
|
|
||||||
.map(|(arg_name, arg_type)| {
|
|
||||||
let arg_name_ident = format_ident!("{arg_name}");
|
|
||||||
arg_type.generate_parser_code(arg_name_ident)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let struct_impl = quote! {
|
|
||||||
impl<'a> WlParsedMessage<'a> for #struct_name<'a> {
|
|
||||||
fn opcode() -> u16 {
|
|
||||||
#opcode
|
|
||||||
}
|
|
||||||
|
|
||||||
fn object_type() -> WlObjectType {
|
|
||||||
#interface_name_snake_upper
|
|
||||||
}
|
|
||||||
|
|
||||||
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
|
|
||||||
let payload = msg.payload();
|
|
||||||
let mut pos = 0usize;
|
|
||||||
#( #parser_code )*
|
|
||||||
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
|
||||||
_phantom: std::marker::PhantomData,
|
|
||||||
#( #field_names, )*
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
quote! {
|
|
||||||
#struct_def
|
|
||||||
#struct_impl
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_camel_case(s: &str) -> String {
|
pub(crate) fn to_camel_case(s: &str) -> String {
|
||||||
s.split("_")
|
s.split("_")
|
||||||
.map(|item| {
|
.map(|item| {
|
||||||
item.char_indices()
|
item.char_indices()
|
||||||
|
|
|
@ -1,5 +1,177 @@
|
||||||
use quote::quote;
|
use proc_macro2::Span;
|
||||||
use syn::Ident;
|
use quote::{format_ident, quote};
|
||||||
|
use syn::{Ident, LitStr};
|
||||||
|
|
||||||
|
pub(crate) struct WlInterface {
|
||||||
|
pub name_snake: String,
|
||||||
|
pub msgs: Vec<WlMsg>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WlInterface {
|
||||||
|
/// Name of the interface type's const representation, e.g. WL_WAYLAND
|
||||||
|
/// This can be used as a discriminant for interface types in Rust
|
||||||
|
pub fn type_const_name(&self) -> String {
|
||||||
|
self.name_snake.to_uppercase()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate(&self) -> proc_macro2::TokenStream {
|
||||||
|
// Generate struct and parser impls for all messages belonging to this interface
|
||||||
|
let msg_impl = self.msgs.iter().map(|msg| msg.generate_struct_and_impl());
|
||||||
|
|
||||||
|
// Also generate a struct representing the type of this interface
|
||||||
|
// This is used to keep track of all objects in [objects]
|
||||||
|
// Example:
|
||||||
|
// struct WlDisplayTypeId;
|
||||||
|
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
|
||||||
|
// impl WlObjectTypeId for WlDisplayTypeId { ... }
|
||||||
|
let interface_type_id_name =
|
||||||
|
format_ident!("{}TypeId", crate::to_camel_case(&self.name_snake));
|
||||||
|
let interface_name_literal = LitStr::new(&self.name_snake, Span::call_site());
|
||||||
|
let type_const_name = format_ident!("{}", self.type_const_name());
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
struct #interface_type_id_name;
|
||||||
|
|
||||||
|
pub const #type_const_name: WlObjectType = WlObjectType::new(&#interface_type_id_name);
|
||||||
|
|
||||||
|
impl WlObjectTypeId for #interface_type_id_name {
|
||||||
|
fn interface(&self) -> &'static str {
|
||||||
|
#interface_name_literal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#( #msg_impl )*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) enum WlMsgType {
|
||||||
|
Request,
|
||||||
|
Event,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WlMsgType {
|
||||||
|
fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
WlMsgType::Request => "Request",
|
||||||
|
WlMsgType::Event => "Event",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) struct WlMsg {
|
||||||
|
pub interface_name_snake: String,
|
||||||
|
pub name_snake: String,
|
||||||
|
pub msg_type: WlMsgType,
|
||||||
|
pub opcode: u16,
|
||||||
|
pub args: Vec<(String, WlArgType)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WlMsg {
|
||||||
|
/// Get the name of the structure generated for this message
|
||||||
|
/// e.g. WlRegistryBindRequest
|
||||||
|
pub fn struct_name(&self) -> String {
|
||||||
|
format!(
|
||||||
|
"{}{}{}",
|
||||||
|
crate::to_camel_case(&self.interface_name_snake),
|
||||||
|
crate::to_camel_case(&self.name_snake),
|
||||||
|
self.msg_type.as_str()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parser_fn_name(&self) -> String {
|
||||||
|
format!("{}ParserFn", self.struct_name())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a struct corresponding to the message type and a impl for [WlParsedMessage]
|
||||||
|
/// that includes a parser
|
||||||
|
pub fn generate_struct_and_impl(&self) -> proc_macro2::TokenStream {
|
||||||
|
let opcode = self.opcode;
|
||||||
|
let interface_name_snake_upper =
|
||||||
|
format_ident!("{}", self.interface_name_snake.to_uppercase());
|
||||||
|
let msg_type = format_ident!("{}", self.msg_type.as_str());
|
||||||
|
|
||||||
|
let struct_name = format_ident!("{}", self.struct_name());
|
||||||
|
|
||||||
|
let parser_fn_name = format_ident!("{}", self.parser_fn_name());
|
||||||
|
|
||||||
|
// Build all field names and their corresponding Rust type identifiers
|
||||||
|
let (field_names, field_types): (Vec<_>, Vec<_>) = self
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
// Generate code to include in the parser for every field
|
||||||
|
let parser_code: Vec<_> = self
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|(arg_name, arg_type)| {
|
||||||
|
let arg_name_ident = format_ident!("{arg_name}");
|
||||||
|
arg_type.generate_parser_code(arg_name_ident)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
quote! {
|
||||||
|
pub struct #struct_name<'a> {
|
||||||
|
_phantom: std::marker::PhantomData<&'a ()>,
|
||||||
|
#( pub #field_names: #field_types, )*
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> __private::WlParsedMessagePrivate for #struct_name<'a> {}
|
||||||
|
|
||||||
|
impl<'a> WlParsedMessage<'a> for #struct_name<'a> {
|
||||||
|
fn opcode() -> u16 {
|
||||||
|
#opcode
|
||||||
|
}
|
||||||
|
|
||||||
|
fn self_opcode(&self) -> u16 {
|
||||||
|
#opcode
|
||||||
|
}
|
||||||
|
|
||||||
|
fn object_type() -> WlObjectType {
|
||||||
|
#interface_name_snake_upper
|
||||||
|
}
|
||||||
|
|
||||||
|
fn self_object_type(&self) -> WlObjectType {
|
||||||
|
#interface_name_snake_upper
|
||||||
|
}
|
||||||
|
|
||||||
|
fn msg_type() -> WlMsgType {
|
||||||
|
WlMsgType::#msg_type
|
||||||
|
}
|
||||||
|
|
||||||
|
fn self_msg_type(&self) -> WlMsgType {
|
||||||
|
WlMsgType::#msg_type
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
|
||||||
|
let payload = msg.payload();
|
||||||
|
let mut pos = 0usize;
|
||||||
|
#( #parser_code )*
|
||||||
|
WaylandProtocolParsingOutcome::Ok(#struct_name {
|
||||||
|
_phantom: std::marker::PhantomData,
|
||||||
|
#( #field_names, )*
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe impl<'a> AnyWlParsedMessage<'a> for #struct_name<'a> {}
|
||||||
|
|
||||||
|
pub struct #parser_fn_name;
|
||||||
|
|
||||||
|
impl WlMsgParserFn for #parser_fn_name {
|
||||||
|
fn try_from_msg<'obj, 'msg>(
|
||||||
|
&self,
|
||||||
|
objects: &'obj WlObjects,
|
||||||
|
msg: &'msg WlRawMsg,
|
||||||
|
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||||
|
#struct_name::try_from_msg(objects, msg).map(|r| Box::new(r) as Box<_>)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) enum WlArgType {
|
pub(crate) enum WlArgType {
|
||||||
Int,
|
Int,
|
||||||
|
|
|
@ -17,6 +17,8 @@ use tracing::{Instrument, Level, debug, error, info, span};
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
proto::wl_init_parsers();
|
||||||
|
proto::wl_init_known_types();
|
||||||
|
|
||||||
let mut conf_file = "config.toml";
|
let mut conf_file = "config.toml";
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,10 @@ impl WlObjects {
|
||||||
self.objects.get(&id).cloned()
|
self.objects.get(&id).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn remove_object(&mut self, id: u32) {
|
||||||
|
self.objects.remove(&id);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn record_global(&mut self, name: u32, interface: &str) {
|
pub fn record_global(&mut self, name: u32, interface: &str) {
|
||||||
self.global_names.insert(name, interface.to_string());
|
self.global_names.insert(name, interface.to_string());
|
||||||
}
|
}
|
||||||
|
|
200
src/proto.rs
200
src/proto.rs
|
@ -1,6 +1,10 @@
|
||||||
//! Protocol definitions necessary for this MITM proxy
|
//! Protocol definitions necessary for this MITM proxy
|
||||||
|
|
||||||
// ---------- wl_display ---------
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
hash::{BuildHasherDefault, DefaultHasher},
|
||||||
|
sync::RwLock,
|
||||||
|
};
|
||||||
|
|
||||||
use byteorder::ByteOrder;
|
use byteorder::ByteOrder;
|
||||||
use protogen::wayland_proto_gen;
|
use protogen::wayland_proto_gen;
|
||||||
|
@ -10,26 +14,33 @@ use crate::{
|
||||||
objects::{WlObjectType, WlObjectTypeId, WlObjects},
|
objects::{WlObjectType, WlObjectTypeId, WlObjects},
|
||||||
};
|
};
|
||||||
|
|
||||||
macro_rules! reject_malformed {
|
macro_rules! bubble_malformed {
|
||||||
($e:expr) => {
|
($e:expr) => {{
|
||||||
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = $e {
|
let e = $e;
|
||||||
return false;
|
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = e {
|
||||||
} else if let crate::proto::WaylandProtocolParsingOutcome::Ok(e) = $e {
|
return WaylandProtocolParsingOutcome::MalformedMessage;
|
||||||
Some(e)
|
|
||||||
} else {
|
} else {
|
||||||
None
|
e
|
||||||
|
}
|
||||||
|
}};
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! match_decoded {
|
||||||
|
(match $decoded:ident {$($t:ty => $act:block$(,)?)+}) => {
|
||||||
|
if let crate::proto::WaylandProtocolParsingOutcome::Ok($decoded) = $decoded {
|
||||||
|
$(
|
||||||
|
if let Some($decoded) = $decoded.downcast_ref::<$t>() {
|
||||||
|
$act
|
||||||
|
}
|
||||||
|
)+
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! decode_and_match_msg {
|
#[derive(PartialEq, Eq)]
|
||||||
($objects:expr, match $msg:ident {$($t:ty => $act:block$(,)?)+}) => {
|
pub enum WlMsgType {
|
||||||
$(
|
Request,
|
||||||
if let Some($msg) = reject_malformed!(<$t as crate::proto::WlParsedMessage>::try_from_msg(&$objects, $msg)) {
|
Event,
|
||||||
$act
|
|
||||||
}
|
|
||||||
)+
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum WaylandProtocolParsingOutcome<T> {
|
pub enum WaylandProtocolParsingOutcome<T> {
|
||||||
|
@ -37,15 +48,50 @@ pub enum WaylandProtocolParsingOutcome<T> {
|
||||||
MalformedMessage,
|
MalformedMessage,
|
||||||
IncorrectObject,
|
IncorrectObject,
|
||||||
IncorrectOpcode,
|
IncorrectOpcode,
|
||||||
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait WlParsedMessage<'a>: Sized {
|
impl<T> WaylandProtocolParsingOutcome<T> {
|
||||||
fn opcode() -> u16;
|
pub fn map<U>(self, f: impl Fn(T) -> U) -> WaylandProtocolParsingOutcome<U> {
|
||||||
fn object_type() -> WlObjectType;
|
match self {
|
||||||
|
WaylandProtocolParsingOutcome::Ok(t) => WaylandProtocolParsingOutcome::Ok(f(t)),
|
||||||
|
WaylandProtocolParsingOutcome::MalformedMessage => {
|
||||||
|
WaylandProtocolParsingOutcome::MalformedMessage
|
||||||
|
}
|
||||||
|
WaylandProtocolParsingOutcome::IncorrectObject => {
|
||||||
|
WaylandProtocolParsingOutcome::IncorrectObject
|
||||||
|
}
|
||||||
|
WaylandProtocolParsingOutcome::IncorrectOpcode => {
|
||||||
|
WaylandProtocolParsingOutcome::IncorrectOpcode
|
||||||
|
}
|
||||||
|
WaylandProtocolParsingOutcome::Unknown => WaylandProtocolParsingOutcome::Unknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Internal module used to seal the [WlParsedMessage] trait
|
||||||
|
mod __private {
|
||||||
|
pub(super) trait WlParsedMessagePrivate {}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(private_bounds)]
|
||||||
|
pub trait WlParsedMessage<'a>: __private::WlParsedMessagePrivate {
|
||||||
|
fn opcode() -> u16
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
fn object_type() -> WlObjectType
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
fn msg_type() -> WlMsgType
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
fn try_from_msg<'obj>(
|
fn try_from_msg<'obj>(
|
||||||
objects: &'obj WlObjects,
|
objects: &'obj WlObjects,
|
||||||
msg: &'a WlRawMsg,
|
msg: &'a WlRawMsg,
|
||||||
) -> WaylandProtocolParsingOutcome<Self> {
|
) -> WaylandProtocolParsingOutcome<Self>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
// Verify object type and opcode
|
// Verify object type and opcode
|
||||||
if objects.lookup_object(msg.obj_id) != Some(Self::object_type()) {
|
if objects.lookup_object(msg.obj_id) != Some(Self::object_type()) {
|
||||||
return WaylandProtocolParsingOutcome::IncorrectObject;
|
return WaylandProtocolParsingOutcome::IncorrectObject;
|
||||||
|
@ -58,10 +104,122 @@ pub trait WlParsedMessage<'a>: Sized {
|
||||||
Self::try_from_msg_impl(msg)
|
Self::try_from_msg_impl(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>;
|
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>
|
||||||
|
where
|
||||||
|
Self: Sized;
|
||||||
|
|
||||||
|
// dyn-available methods
|
||||||
|
fn self_opcode(&self) -> u16;
|
||||||
|
fn self_object_type(&self) -> WlObjectType;
|
||||||
|
fn self_msg_type(&self) -> WlMsgType;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A version of [WlParsedMessage] that supports downcasting. By implementing this
|
||||||
|
/// trait, you promise that the (object_type, msg_type, opcode) triple is unique, i.e.
|
||||||
|
/// it does not overlap with any other implementation of this trait.
|
||||||
|
pub unsafe trait AnyWlParsedMessage<'a>: WlParsedMessage<'a> {}
|
||||||
|
|
||||||
|
impl<'a> dyn AnyWlParsedMessage<'a> + 'a {
|
||||||
|
pub fn downcast_ref<T: AnyWlParsedMessage<'a>>(&self) -> Option<&T> {
|
||||||
|
if self.self_opcode() != T::opcode() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.self_object_type() != T::object_type() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.self_msg_type() != T::msg_type() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: We have verified the opcode, type, and msg type all match up
|
||||||
|
// As per safety guarantee of [AnyWlParsedMessage], we've now narrowed
|
||||||
|
// [self] down to one concrete type.
|
||||||
|
Some(unsafe { &*(self as *const dyn AnyWlParsedMessage as *const T) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A dyn-compatible wrapper over a specific [WlParsedMessage] type's static methods.
|
||||||
|
/// The only exposed method, [try_from_msg], attempts to parse the message
|
||||||
|
/// to the given type. This is used as members of [WL_EVENT_PARSERS]
|
||||||
|
/// and [WL_REQUEST_PARSERS] to facilitate automatic parsing of all
|
||||||
|
/// known message types.
|
||||||
|
pub trait WlMsgParserFn: Send + Sync {
|
||||||
|
fn try_from_msg<'obj, 'msg>(
|
||||||
|
&self,
|
||||||
|
objects: &'obj WlObjects,
|
||||||
|
msg: &'msg WlRawMsg,
|
||||||
|
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A map from known interface names to their object types in Rust representation
|
||||||
|
static WL_KNOWN_OBJECT_TYPES: RwLock<
|
||||||
|
HashMap<&'static str, WlObjectType, BuildHasherDefault<DefaultHasher>>,
|
||||||
|
> = RwLock::new(HashMap::with_hasher(BuildHasherDefault::new()));
|
||||||
|
/// Parsers for all known events
|
||||||
|
static WL_EVENT_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
|
||||||
|
/// Parsers for all known requests
|
||||||
|
static WL_REQUEST_PARSERS: RwLock<Vec<&'static dyn WlMsgParserFn>> = RwLock::new(Vec::new());
|
||||||
|
|
||||||
|
/// Decode a Wayland event from a [WlRawMsg], returning the type-erased result, or
|
||||||
|
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
|
||||||
|
/// for malformed messages.
|
||||||
|
///
|
||||||
|
/// To downcast the parse result to a concrete message type, use [<dyn AnyWlParsedMessage>::downcast_ref]
|
||||||
|
pub fn decode_event<'obj, 'msg>(
|
||||||
|
objects: &'obj WlObjects,
|
||||||
|
msg: &'msg WlRawMsg,
|
||||||
|
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||||
|
for p in WL_EVENT_PARSERS.read().unwrap().iter() {
|
||||||
|
if let WaylandProtocolParsingOutcome::Ok(e) =
|
||||||
|
bubble_malformed!(p.try_from_msg(objects, msg))
|
||||||
|
{
|
||||||
|
return WaylandProtocolParsingOutcome::Ok(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
WaylandProtocolParsingOutcome::Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a Wayland request from a [WlRawMsg], returning the type-erased result, or
|
||||||
|
/// [WaylandProtocolParsingOutcome::Unknown] for unknown messages, [WaylandProtocolParsingOutcome::MalformedMessage]
|
||||||
|
/// for malformed messages.
|
||||||
|
///
|
||||||
|
/// To downcast the parse result to a concrete message type, use [<dyn AnyWlParsedMessage>::downcast_ref]
|
||||||
|
pub fn decode_request<'obj, 'msg>(
|
||||||
|
objects: &'obj WlObjects,
|
||||||
|
msg: &'msg WlRawMsg,
|
||||||
|
) -> WaylandProtocolParsingOutcome<Box<dyn AnyWlParsedMessage<'msg> + 'msg>> {
|
||||||
|
for p in WL_REQUEST_PARSERS.read().unwrap().iter() {
|
||||||
|
if let WaylandProtocolParsingOutcome::Ok(e) =
|
||||||
|
bubble_malformed!(p.try_from_msg(objects, msg))
|
||||||
|
{
|
||||||
|
return WaylandProtocolParsingOutcome::Ok(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
WaylandProtocolParsingOutcome::Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Look up a known object type from its name to its Rust [WlObjectType] representation
|
||||||
|
pub fn lookup_known_object_type(name: &str) -> Option<WlObjectType> {
|
||||||
|
WL_KNOWN_OBJECT_TYPES
|
||||||
|
.read()
|
||||||
|
.ok()
|
||||||
|
.and_then(|t| t.get(name).copied())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The default object ID of wl_display
|
/// The default object ID of wl_display
|
||||||
pub const WL_DISPLAY_OBJECT_ID: u32 = 1;
|
pub const WL_DISPLAY_OBJECT_ID: u32 = 1;
|
||||||
|
|
||||||
wayland_proto_gen!("proto/wayland.xml");
|
wayland_proto_gen!("proto/wayland.xml");
|
||||||
|
|
||||||
|
/// Install all available Wayland protocol parsers for use by [decode_event] and [decode_request].
|
||||||
|
pub fn wl_init_parsers() {
|
||||||
|
wl_init_parsers_wayland();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wl_init_known_types() {
|
||||||
|
wl_init_known_types_wayland();
|
||||||
|
}
|
||||||
|
|
47
src/state.rs
47
src/state.rs
|
@ -7,7 +7,8 @@ use crate::{
|
||||||
config::Config,
|
config::Config,
|
||||||
objects::WlObjects,
|
objects::WlObjects,
|
||||||
proto::{
|
proto::{
|
||||||
WL_REGISTRY, WlDisplayGetRegistryRequest, WlRegistryBindRequest, WlRegistryGlobalEvent,
|
WL_REGISTRY, WlDisplayDeleteIdEvent, WlDisplayGetRegistryRequest, WlRegistryBindRequest,
|
||||||
|
WlRegistryGlobalEvent,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -25,9 +26,18 @@ impl WlMitmState {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub fn on_c2s_request(&mut self, msg: &WlRawMsg) -> bool {
|
pub fn on_c2s_request(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||||
decode_and_match_msg!(
|
let msg = crate::proto::decode_request(&self.objects, raw_msg);
|
||||||
self.objects,
|
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
|
||||||
|
debug!(
|
||||||
|
obj_id = raw_msg.obj_id,
|
||||||
|
opcode = raw_msg.opcode,
|
||||||
|
"Malformed request"
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
match_decoded! {
|
||||||
match msg {
|
match msg {
|
||||||
WlDisplayGetRegistryRequest => {
|
WlDisplayGetRegistryRequest => {
|
||||||
self.objects.record_object(WL_REGISTRY, msg.registry);
|
self.objects.record_object(WL_REGISTRY, msg.registry);
|
||||||
|
@ -36,22 +46,36 @@ impl WlMitmState {
|
||||||
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
let Some(interface) = self.objects.lookup_global(msg.name) else {
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
interface = interface,
|
interface = interface,
|
||||||
obj_id = msg.id,
|
obj_id = msg.id,
|
||||||
"Client binding interface"
|
"Client binding interface"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if let Some(t) = crate::proto::lookup_known_object_type(interface) {
|
||||||
|
//self.objects.record_object(t, msg.id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
}
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all)]
|
#[tracing::instrument(skip_all)]
|
||||||
pub fn on_s2c_event(&mut self, msg: &WlRawMsg) -> bool {
|
pub fn on_s2c_event(&mut self, raw_msg: &WlRawMsg) -> bool {
|
||||||
decode_and_match_msg!(
|
let msg = crate::proto::decode_event(&self.objects, raw_msg);
|
||||||
self.objects,
|
if let crate::proto::WaylandProtocolParsingOutcome::MalformedMessage = msg {
|
||||||
|
debug!(
|
||||||
|
obj_id = raw_msg.obj_id,
|
||||||
|
opcode = raw_msg.opcode,
|
||||||
|
"Malformed event"
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
match_decoded! {
|
||||||
match msg {
|
match msg {
|
||||||
WlRegistryGlobalEvent => {
|
WlRegistryGlobalEvent => {
|
||||||
debug!(
|
debug!(
|
||||||
|
@ -71,8 +95,13 @@ impl WlMitmState {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
WlDisplayDeleteIdEvent => {
|
||||||
|
// When an object is acknowledged to be deleted, remove it from our
|
||||||
|
// internal cache of all registered objects
|
||||||
|
//self.objects.remove_object(msg.id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
);
|
}
|
||||||
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue