Compare commits

...

2 commits

4 changed files with 279 additions and 115 deletions

View file

@ -2,6 +2,9 @@ use proc_macro2::Span;
use quick_xml::events::Event; use quick_xml::events::Event;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
use syn::{Ident, LitStr, parse_macro_input}; use syn::{Ident, LitStr, parse_macro_input};
use types::WlArgType;
mod types;
#[proc_macro] #[proc_macro]
pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStream { pub fn wayland_proto_gen(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
@ -51,6 +54,7 @@ fn handle_interface(
.expect("No name attr found for interface"); .expect("No name attr found for interface");
let interface_name_snake = std::str::from_utf8(&name_attr.value).expect("utf8 encoding error"); let interface_name_snake = std::str::from_utf8(&name_attr.value).expect("utf8 encoding error");
let interface_name_camel = to_camel_case(interface_name_snake);
// Generate the implementation of the Wayland object type ID, consisting of a private struct // Generate the implementation of the Wayland object type ID, consisting of a private struct
// to act as a trait object, a public const that wraps the struct in `WlObjectType`, and a impl // to act as a trait object, a public const that wraps the struct in `WlObjectType`, and a impl
@ -59,7 +63,7 @@ fn handle_interface(
// struct WlDisplayTypeId; // struct WlDisplayTypeId;
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId); // pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
// impl WlObjectTypeId for WlDisplayTypeId { ... } // impl WlObjectTypeId for WlDisplayTypeId { ... }
let interface_type_id_name = format_ident!("{}TypeId", to_camel_case(interface_name_snake)); let interface_type_id_name = format_ident!("{}TypeId", interface_name_camel);
let interface_name_literal = LitStr::new(interface_name_snake, Span::call_site()); let interface_name_literal = LitStr::new(interface_name_snake, Span::call_site());
let interface_name_snake_upper = let interface_name_snake_upper =
Ident::new(&interface_name_snake.to_uppercase(), Span::call_site()); Ident::new(&interface_name_snake.to_uppercase(), Span::call_site());
@ -75,9 +79,42 @@ fn handle_interface(
} }
}; };
let mut event_opcode = 0;
let mut request_opcode = 0;
loop { loop {
match reader.read_event().expect("Unable to parse XML file") { match reader.read_event().expect("Unable to parse XML file") {
Event::Eof => panic!("Unexpected EOF"), Event::Eof => panic!("Unexpected EOF"),
Event::Start(e) => {
let start_tag =
str::from_utf8(e.local_name().into_inner()).expect("Unable to parse start tag");
let append = if start_tag == "event" {
event_opcode += 1;
handle_request_or_event(
reader,
&interface_name_camel,
&interface_name_snake_upper,
event_opcode - 1,
e,
)
} else if start_tag == "request" {
request_opcode += 1;
handle_request_or_event(
reader,
&interface_name_camel,
&interface_name_snake_upper,
request_opcode - 1,
e,
)
} else {
proc_macro2::TokenStream::new()
};
ret = quote! {
#ret
#append
}
}
Event::End(e) if e.local_name() == start.local_name() => break, Event::End(e) if e.local_name() == start.local_name() => break,
_ => continue, _ => continue,
} }
@ -86,6 +123,115 @@ fn handle_interface(
ret ret
} }
fn handle_request_or_event(
reader: &mut quick_xml::Reader<&[u8]>,
interface_name_camel: &str,
interface_name_snake_upper: &Ident,
opcode: u16,
start: quick_xml::events::BytesStart<'_>,
) -> proc_macro2::TokenStream {
let start_tag =
str::from_utf8(start.local_name().into_inner()).expect("Unable to parse start tag");
let start_tag_camel = to_camel_case(start_tag);
let name_attr = start
.attributes()
.map(|a| a.expect("attr parsing error"))
.find(|a| {
std::str::from_utf8(a.key.local_name().into_inner()).expect("utf8 encoding error")
== "name"
})
.expect("No name attr found for request/event");
let name_camel = to_camel_case(str::from_utf8(&name_attr.value).expect("utf8 encoding error"));
let mut args: Vec<(String, WlArgType)> = Vec::new();
loop {
match reader.read_event().expect("Unable to parse XML file") {
Event::Eof => panic!("Unexpected EOF"),
Event::Empty(e)
if str::from_utf8(e.local_name().into_inner()).expect("utf8 encoding error")
== "arg" =>
{
let mut name: Option<String> = None;
let mut tt: Option<WlArgType> = None;
for attr in e.attributes() {
let attr = attr.expect("attr parsing error");
let attr_name = str::from_utf8(attr.key.local_name().into_inner())
.expect("utf8 encoding error");
if attr_name == "name" {
name = Some(
str::from_utf8(&attr.value)
.expect("utf8 encoding error")
.to_string(),
);
} else if attr_name == "type" {
tt = Some(WlArgType::parse(
str::from_utf8(&attr.value).expect("utf8 encoding error"),
));
}
}
args.push((
name.expect("args must have a name"),
tt.expect("args must have a type"),
));
}
Event::End(e) if e.local_name() == start.local_name() => break,
_ => continue,
}
}
let (field_names, field_types): (Vec<_>, Vec<_>) = args
.iter()
.map(|(name, tt)| (format_ident!("{name}"), tt.to_rust_type()))
.unzip();
let struct_name = format_ident!("{interface_name_camel}{name_camel}{start_tag_camel}");
let struct_def = quote! {
pub struct #struct_name<'a> {
_phantom: std::marker::PhantomData<&'a ()>,
#( pub #field_names: #field_types, )*
}
};
let parser_code: Vec<_> = args
.into_iter()
.map(|(arg_name, arg_type)| {
let arg_name_ident = format_ident!("{arg_name}");
arg_type.generate_parser_code(arg_name_ident)
})
.collect();
let struct_impl = quote! {
impl<'a> WlParsedMessage<'a> for #struct_name<'a> {
fn opcode() -> u16 {
#opcode
}
fn object_type() -> WlObjectType {
#interface_name_snake_upper
}
fn try_from_msg_impl(msg: &crate::codec::WlRawMsg) -> WaylandProtocolParsingOutcome<#struct_name> {
let payload = msg.payload();
let mut pos = 0usize;
#( #parser_code )*
WaylandProtocolParsingOutcome::Ok(#struct_name {
_phantom: std::marker::PhantomData,
#( #field_names, )*
})
}
}
};
quote! {
#struct_def
#struct_impl
}
}
fn to_camel_case(s: &str) -> String { fn to_camel_case(s: &str) -> String {
s.split("_") s.split("_")
.map(|item| { .map(|item| {

123
protogen/src/types.rs Normal file
View file

@ -0,0 +1,123 @@
use quote::quote;
use syn::Ident;
pub(crate) enum WlArgType {
Int,
Uint,
Fixed,
Object,
NewId,
String,
Array,
Fd,
Enum,
}
impl WlArgType {
pub fn parse(s: &str) -> WlArgType {
match s {
"int" => WlArgType::Int,
"uint" => WlArgType::Uint,
"fixed" => WlArgType::Fixed,
"object" => WlArgType::Object,
"new_id" => WlArgType::NewId,
"string" => WlArgType::String,
"array" => WlArgType::Array,
"fd" => WlArgType::Fd,
"enum" => WlArgType::Enum,
_ => panic!("Unknown arg type!"),
}
}
pub fn to_rust_type(&self) -> proc_macro2::TokenStream {
match self {
WlArgType::Int => quote! { i32 },
// TODO: "fixed" is decoded directly as a u32. fix it
WlArgType::Uint
| WlArgType::Fixed
| WlArgType::Object
| WlArgType::NewId
| WlArgType::Enum => quote! { u32 },
WlArgType::String => quote! { &'a str },
WlArgType::Array => quote! { &'a [u8] },
WlArgType::Fd => quote! { std::os::fd::BorrowedFd<'a> },
}
}
pub fn generate_parser_code(&self, var_name: Ident) -> proc_macro2::TokenStream {
match self {
WlArgType::Int => quote! {
if payload.len() < pos + 4 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let #var_name: i32 = byteorder::NativeEndian::read_i32(&payload[pos..pos + 4]);
pos += 4;
},
WlArgType::Uint
| WlArgType::Fixed
| WlArgType::Object
| WlArgType::NewId
| WlArgType::Enum => quote! {
if payload.len() < pos + 4 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let #var_name: u32 = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]);
pos += 4;
},
WlArgType::String => quote! {
let #var_name: &str = {
if payload.len() < pos + 4 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let len = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]) as usize;
pos += 4;
if payload.len() < pos + len {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let Ok(#var_name) = std::str::from_utf8(&payload[pos..pos + len - 1]) else {
return WaylandProtocolParsingOutcome::MalformedMessage;
};
pos += len;
#var_name
};
},
WlArgType::Array => quote! {
let #var_name: &[u8] = {
if payload.len() < pos + 4 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let len = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]) as usize;
pos += 4;
if payload.len() < pos + len {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let #var_name = &payload[pos..pos + len];
pos += len;
#var_name
};
},
WlArgType::Fd => quote! {
if msg.fds.len() == 0 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[0]);
},
}
}
}

View file

@ -2,7 +2,7 @@
// ---------- wl_display --------- // ---------- wl_display ---------
use byteorder::{ByteOrder, NativeEndian}; use byteorder::ByteOrder;
use protogen::wayland_proto_gen; use protogen::wayland_proto_gen;
use crate::{ use crate::{
@ -61,114 +61,7 @@ pub trait WlParsedMessage<'a>: Sized {
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>; fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<Self>;
} }
wayland_proto_gen!("proto/wayland.xml");
/// The default object ID of wl_display /// The default object ID of wl_display
pub const WL_DISPLAY_OBJECT_ID: u32 = 1; pub const WL_DISPLAY_OBJECT_ID: u32 = 1;
/// Opcode for binding the wl_registry object
pub const WL_DISPLAY_GET_REGISTRY_OPCODE: u16 = 1;
pub struct WlDisplayGetRegistry { wayland_proto_gen!("proto/wayland.xml");
pub registry_new_id: u32,
}
impl WlParsedMessage<'_> for WlDisplayGetRegistry {
fn object_type() -> WlObjectType {
WL_DISPLAY
}
fn opcode() -> u16 {
WL_DISPLAY_GET_REGISTRY_OPCODE
}
fn try_from_msg_impl(msg: &WlRawMsg) -> WaylandProtocolParsingOutcome<WlDisplayGetRegistry> {
let payload = msg.payload();
if payload.len() != 4 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
WaylandProtocolParsingOutcome::Ok(WlDisplayGetRegistry {
registry_new_id: NativeEndian::read_u32(msg.payload()),
})
}
}
// ---------- wl_registry ---------
/// Opcode for server->client "global" events
pub const WL_REGISTRY_GLOBAL_OPCODE: u16 = 0;
/// Opcode for client->server "bind" requests
pub const WL_REGISTRY_BIND_OPCODE: u16 = 0;
pub struct WlRegistryGlobalEvent<'a> {
pub name: u32,
pub interface: &'a str,
pub version: u32,
}
impl<'a> WlParsedMessage<'a> for WlRegistryGlobalEvent<'a> {
fn opcode() -> u16 {
WL_REGISTRY_GLOBAL_OPCODE
}
fn object_type() -> WlObjectType {
WL_REGISTRY
}
fn try_from_msg_impl(
msg: &'a WlRawMsg,
) -> WaylandProtocolParsingOutcome<WlRegistryGlobalEvent<'a>> {
let payload = msg.payload();
if payload.len() < 8 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let name = NativeEndian::read_u32(&payload[0..4]);
let interface_len = NativeEndian::read_u32(&payload[4..8]);
if interface_len + 4 >= payload.len() as u32 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let version = NativeEndian::read_u32(&payload[payload.len() - 4..]);
// -1 because of 0-terminator
let Ok(interface) = std::str::from_utf8(&payload[8..8 + interface_len as usize - 1]) else {
return WaylandProtocolParsingOutcome::MalformedMessage;
};
WaylandProtocolParsingOutcome::Ok(WlRegistryGlobalEvent {
name,
interface,
version,
})
}
}
pub struct WlRegistryBind {
pub name: u32,
pub new_id: u32,
}
impl<'a> WlParsedMessage<'a> for WlRegistryBind {
fn opcode() -> u16 {
WL_REGISTRY_BIND_OPCODE
}
fn object_type() -> WlObjectType {
WL_REGISTRY
}
fn try_from_msg_impl(msg: &'a WlRawMsg) -> WaylandProtocolParsingOutcome<WlRegistryBind> {
let payload = msg.payload();
if payload.len() < 8 {
return WaylandProtocolParsingOutcome::MalformedMessage;
}
let name = NativeEndian::read_u32(&payload[..4]);
let new_id = NativeEndian::read_u32(&payload[4..8]);
WaylandProtocolParsingOutcome::Ok(WlRegistryBind { name, new_id })
}
}

View file

@ -6,7 +6,9 @@ use crate::{
codec::WlRawMsg, codec::WlRawMsg,
config::Config, config::Config,
objects::WlObjects, objects::WlObjects,
proto::{WL_REGISTRY, WlDisplayGetRegistry, WlRegistryBind, WlRegistryGlobalEvent}, proto::{
WL_REGISTRY, WlDisplayGetRegistryRequest, WlRegistryBindRequest, WlRegistryGlobalEvent,
},
}; };
pub struct WlMitmState { pub struct WlMitmState {
@ -27,16 +29,16 @@ impl WlMitmState {
decode_and_match_msg!( decode_and_match_msg!(
self.objects, self.objects,
match msg { match msg {
WlDisplayGetRegistry => { WlDisplayGetRegistryRequest => {
self.objects.record_object(WL_REGISTRY, msg.registry_new_id); self.objects.record_object(WL_REGISTRY, msg.registry);
} }
WlRegistryBind => { WlRegistryBindRequest => {
let Some(interface) = self.objects.lookup_global(msg.name) else { let Some(interface) = self.objects.lookup_global(msg.name) else {
return false; return false;
}; };
info!( info!(
interface = interface, interface = interface,
obj_id = msg.new_id, obj_id = msg.id,
"Client binding interface" "Client binding interface"
); );
} }