wl-mitm/protogen/src/types.rs

474 lines
16 KiB
Rust

use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{Ident, LitStr};
pub(crate) struct WlInterface {
pub name_snake: String,
pub msgs: Vec<WlMsg>,
}
impl WlInterface {
/// Name of the interface type's const representation, e.g. WL_WAYLAND
/// This can be used as a discriminant for interface types in Rust
pub fn type_const_name(&self) -> String {
self.name_snake.to_uppercase()
}
pub fn generate(&self) -> proc_macro2::TokenStream {
// Generate struct and parser impls for all messages belonging to this interface
let msg_impl = self.msgs.iter().map(|msg| msg.generate_struct_and_impl());
// Also generate a struct representing the type of this interface
// This is used to keep track of all objects in [objects]
// Example:
// struct WlDisplayTypeId;
// pub const WL_DISPLAY: WlObjectType = WlObjectType::new(&WlDisplayTypeId);
// impl WlObjectTypeId for WlDisplayTypeId { ... }
let interface_type_id_name =
format_ident!("{}TypeId", crate::to_camel_case(&self.name_snake));
let interface_name_literal = LitStr::new(&self.name_snake, Span::call_site());
let type_const_name = format_ident!("{}", self.type_const_name());
quote! {
struct #interface_type_id_name;
pub const #type_const_name: crate::objects::WlObjectType = crate::objects::WlObjectType::new(&#interface_type_id_name);
impl crate::objects::WlObjectTypeId for #interface_type_id_name {
fn interface(&self) -> &'static str {
#interface_name_literal
}
}
#( #msg_impl )*
}
}
}
/// Direction / kind of a Wayland message.
pub(crate) enum WlMsgType {
    Request,
    Event,
}

impl WlMsgType {
    /// CamelCase suffix for this message kind; it is appended to the names
    /// of generated message structs (see `WlMsg::struct_name`).
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Event => "Event",
            Self::Request => "Request",
        }
    }
}
/// Description of one message (request or event) of a Wayland interface.
pub(crate) struct WlMsg {
    /// snake_case name of the interface this message belongs to
    pub interface_name_snake: String,
    /// snake_case name of the message itself
    pub name_snake: String,
    /// Whether this is a request or an event
    pub msg_type: WlMsgType,
    /// Opcode reported by the generated `opcode()` function
    pub opcode: u16,
    /// Value reported by the generated `is_destructor()` function
    pub is_destructor: bool,
    /// Arguments in declaration order, as (name, type) pairs
    pub args: Vec<(String, WlArgType)>,
}

impl WlMsg {
    /// Name of the struct emitted for this message: CamelCase interface name,
    /// CamelCase message name, then the message kind,
    /// e.g. WlRegistryBindRequest
    pub fn struct_name(&self) -> String {
        let interface = crate::to_camel_case(&self.interface_name_snake);
        let message = crate::to_camel_case(&self.name_snake);
        let kind = self.msg_type.as_str();
        format!("{}{}{}", interface, message, kind)
    }

    /// Name of the unit struct emitted alongside the message struct that
    /// implements `WlMsgParserFn` for it.
    pub fn parser_fn_name(&self) -> String {
        format!("{}ParserFn", self.struct_name())
    }

    /// Emits the struct for this message plus its impls: a constructor,
    /// [WlParsedMessage] (which contains the parser), the parser-fn unit
    /// struct, and the builder (`WlConstructableMessage`).
    pub fn generate_struct_and_impl(&self) -> proc_macro2::TokenStream {
        // Identifiers and constants describing the message as a whole
        let struct_name = format_ident!("{}", self.struct_name());
        let parser_fn_name = format_ident!("{}", self.parser_fn_name());
        let interface_name_snake_upper =
            format_ident!("{}", self.interface_name_snake.to_uppercase());
        let msg_name_snake = &self.name_snake;
        let opcode = self.opcode;
        let is_destructor = self.is_destructor;

        // Per-argument pieces, all gathered in a single pass over the args
        let arg_count = self.args.len();
        let mut field_names = Vec::with_capacity(arg_count);
        let mut field_types = Vec::with_capacity(arg_count);
        let mut field_attrs = Vec::with_capacity(arg_count);
        let mut parser_code = Vec::with_capacity(arg_count);
        let mut builder_code = Vec::with_capacity(arg_count);
        // New objects created by this msg whose object type (interface) is known
        let mut new_id_name = Vec::new();
        let mut new_id_type = Vec::new();
        let mut num_consumed_fds = 0usize;

        for (arg_name, arg_type) in &self.args {
            let field = format_ident!("{arg_name}");

            field_types.push(arg_type.to_rust_type());
            parser_code.push(arg_type.generate_parser_code(&field));
            builder_code.push(arg_type.generate_builder_code(&field));

            if matches!(arg_type, WlArgType::Fd) {
                num_consumed_fds += 1;
                // Can't serialize fds!
                field_attrs.push(quote! { #[serde(skip)] });
            } else {
                field_attrs.push(quote! {});
            }

            if let WlArgType::NewId(Some(interface)) = arg_type {
                new_id_name.push(field.clone());
                new_id_type.push(format_ident!("{}", interface.to_uppercase()));
            }

            field_names.push(field);
        }

        let known_objects_created = if new_id_name.is_empty() {
            quote! {
                None
            }
        } else {
            quote! {
                Some(vec![
                    #( (self.#new_id_name, crate::proto::#new_id_type), )*
                ])
            }
        };

        quote! {
            #[allow(unused, non_snake_case)]
            #[derive(Serialize)]
            pub struct #struct_name<'a> {
                #[serde(skip)]
                _phantom: std::marker::PhantomData<&'a ()>,
                obj_id: u32,
                #( #field_attrs pub #field_names: #field_types, )*
            }

            impl<'a> #struct_name<'a> {
                #[allow(unused, non_snake_case)]
                pub fn new(
                    obj_id: u32,
                    #( #field_names: #field_types, )*
                ) -> Self {
                    Self {
                        _phantom: std::marker::PhantomData,
                        obj_id,
                        #( #field_names, )*
                    }
                }
            }

            impl<'a> crate::proto::__private::WlParsedMessagePrivate for #struct_name<'a> {}

            impl<'a> crate::proto::WlParsedMessage<'a> for #struct_name<'a> {
                fn opcode() -> u16 {
                    #opcode
                }

                fn object_type() -> crate::objects::WlObjectType {
                    crate::proto::#interface_name_snake_upper
                }

                fn msg_name() -> &'static str {
                    #msg_name_snake
                }

                fn is_destructor() -> bool {
                    #is_destructor
                }

                fn num_consumed_fds() -> usize {
                    #num_consumed_fds
                }

                #[allow(unused, private_interfaces, non_snake_case)]
                fn try_from_msg_impl(
                    msg: &crate::codec::WlRawMsg, _token: crate::proto::__private::WlParsedMessagePrivateToken
                ) -> crate::proto::WaylandProtocolParsingOutcome<#struct_name> {
                    let payload = msg.payload();
                    let mut pos = 0usize;
                    let mut pos_fds = 0usize;
                    #( #parser_code )*
                    crate::proto::WaylandProtocolParsingOutcome::Ok(#struct_name {
                        _phantom: std::marker::PhantomData,
                        obj_id: msg.obj_id,
                        #( #field_names, )*
                    })
                }

                fn _obj_id(&self) -> u32 {
                    self.obj_id
                }

                fn _known_objects_created(&self) -> Option<Vec<(u32, crate::objects::WlObjectType)>> {
                    #known_objects_created
                }

                fn _to_json(&self) -> String {
                    serde_json::to_string(self).unwrap()
                }
            }

            unsafe impl<'a> crate::proto::DowncastableWlParsedMessage<'a> for #struct_name<'a> {
                type Static = #struct_name<'static>;
            }

            pub struct #parser_fn_name;

            impl crate::proto::WlMsgParserFn for #parser_fn_name {
                fn try_from_msg<'obj, 'msg>(
                    &self,
                    objects: &'obj crate::proto::WlObjects,
                    msg: &'msg crate::codec::WlRawMsg,
                ) -> crate::proto::WaylandProtocolParsingOutcome<Box<dyn crate::proto::AnyWlParsedMessage + 'msg>> {
                    #struct_name::try_from_msg(objects, msg).map(|r| Box::new(r) as Box<_>)
                }
            }

            impl<'a> crate::proto::WlConstructableMessage<'a> for #struct_name<'a> {
                #[allow(unused, non_snake_case)]
                fn build_inner(&self, buf: &mut bytes::BytesMut, fds: &mut Vec<std::os::fd::OwnedFd>) {
                    use bytes::BufMut;
                    use std::os::fd::BorrowedFd;
                    #( #builder_code )*
                }
            }
        }
    }
}
/// Type of a single argument of a Wayland message.
pub(crate) enum WlArgType {
    /// Signed 32-bit integer (`i32`)
    Int,
    /// Unsigned 32-bit integer (`u32`)
    Uint,
    /// Fixed-point number (24.8 signed, `fixed::types::I24F8`)
    Fixed,
    /// Object id, carried as a `u32`
    Object,
    /// Id of a newly created object, carried as a `u32`. The payload is the
    /// interface name of the new object, if it is known statically.
    NewId(Option<String>),
    /// Length-prefixed, NUL-terminated string (`&str`)
    String,
    /// Length-prefixed byte array (`&[u8]`)
    Array,
    /// File descriptor; taken from the message's fd list, not its payload
    Fd,
    /// Enum value, carried as a `u32`
    Enum,
}

impl WlArgType {
    /// Map a type name (`"int"`, `"uint"`, `"fixed"`, `"object"`,
    /// `"new_id"`, `"string"`, `"array"`, `"fd"`, `"enum"`) to its
    /// [WlArgType]. A `new_id` starts out without an interface; use
    /// [WlArgType::set_interface_name] to attach one.
    ///
    /// # Panics
    ///
    /// Panics if `s` is not one of the names listed above.
    pub fn parse(s: &str) -> WlArgType {
        match s {
            "int" => WlArgType::Int,
            "uint" => WlArgType::Uint,
            "fixed" => WlArgType::Fixed,
            "object" => WlArgType::Object,
            "new_id" => WlArgType::NewId(None),
            "string" => WlArgType::String,
            "array" => WlArgType::Array,
            "fd" => WlArgType::Fd,
            "enum" => WlArgType::Enum,
            _ => panic!("Unknown arg type!"),
        }
    }

    /// Attach a known, fixed interface name to `self`, if `self`
    /// is a [WlArgType::NewId].
    ///
    /// If a [WlArgType::NewId] does not come with a known interface
    /// tag, the caller is responsible for generating the additional
    /// args (interface, version) as required by Wayland's special
    /// serialization format for them
    ///
    /// We don't verify whether the interface is actually known here.
    /// Rather, if it isn't known, our emitted code will refer to
    /// an unknown type / const, which will cause a compile-time error.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [WlArgType::NewId].
    pub fn set_interface_name(&mut self, interface: String) {
        match self {
            WlArgType::NewId(slot) => *slot = Some(interface),
            _ => panic!("not a new_id but got interface tag!"),
        }
    }

    /// What's the Rust type corresponding to this WL protocol type?
    /// Returned as a token that can be used directly in quote! {}
    ///
    /// Borrowed types refer to the lifetime `'a`, which must be declared
    /// wherever the returned tokens are spliced in.
    pub fn to_rust_type(&self) -> proc_macro2::TokenStream {
        match self {
            WlArgType::Int => quote! { i32 },
            WlArgType::Uint | WlArgType::Object | WlArgType::NewId(_) | WlArgType::Enum => {
                quote! { u32 }
            }
            WlArgType::Fixed => quote! { fixed::types::I24F8 }, // wl fixed point is 24.8 signed
            WlArgType::String => quote! { &'a str },
            WlArgType::Array => quote! { &'a [u8] },
            WlArgType::Fd => quote! { std::os::fd::BorrowedFd<'a> },
        }
    }

    /// Generate code to be inserted into the parsing function. The parsing function is expected
    /// to set up these variables (with `msg` as the input [WlRawMsg]):
    ///
    ///   let payload: &[u8] = msg.payload();
    ///   let mut pos: usize = 0;
    ///   let mut pos_fds: usize = 0;
    ///
    /// `pos` records where we last read in the payload, `pos_fds` is the index of the next
    /// unconsumed entry of `msg.fds`.
    ///
    /// Code generated here will set up a variable with `var_name` containing the parsed result
    /// of the current argument. This `var_name` can then be used later to construct the event or
    /// request's struct. If the payload (or fd list) is too short, or a string is not valid
    /// UTF-8, the generated code returns `WaylandProtocolParsingOutcome::MalformedMessage`.
    pub fn generate_parser_code(&self, var_name: &Ident) -> proc_macro2::TokenStream {
        match self {
            WlArgType::Int => quote! {
                if payload.len() < pos + 4 {
                    return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                }

                let #var_name: i32 = byteorder::NativeEndian::read_i32(&payload[pos..pos + 4]);

                pos += 4;
            },
            WlArgType::Uint | WlArgType::Object | WlArgType::NewId(_) | WlArgType::Enum => quote! {
                if payload.len() < pos + 4 {
                    return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                }

                let #var_name: u32 = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]);

                pos += 4;
            },
            WlArgType::Fixed => quote! {
                if payload.len() < pos + 4 {
                    return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                }

                let #var_name = fixed::types::I24F8::from_bits(byteorder::NativeEndian::read_i32(&payload[pos..pos + 4]));

                pos += 4;
            },
            WlArgType::String => quote! {
                let #var_name: &str = {
                    if payload.len() < pos + 4 {
                        return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                    }

                    let len = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]) as usize;

                    pos += 4;

                    if len == 0 {
                        ""
                    } else {
                        // `pos <= payload.len()` holds here, so the subtraction cannot
                        // underflow, and unlike `pos + len` it cannot overflow for a
                        // hostile `len` either
                        if payload.len() - pos < len {
                            return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                        }

                        // `len` counts the trailing NUL, which is not part of the &str
                        let Ok(#var_name) = std::str::from_utf8(&payload[pos..pos + len - 1]) else {
                            return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                        };

                        // Content is padded to a multiple of 4 bytes
                        if len % 4 == 0 {
                            pos += len;
                        } else {
                            pos += len + (4 - len % 4);
                        }

                        #var_name
                    }
                };
            },
            WlArgType::Array => quote! {
                let #var_name: &[u8] = {
                    if payload.len() < pos + 4 {
                        return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                    }

                    let len = byteorder::NativeEndian::read_u32(&payload[pos..pos + 4]) as usize;

                    // The length prefix is consumed even when the array is empty;
                    // otherwise every following argument would be read from the
                    // wrong offset
                    pos += 4;

                    if len == 0 {
                        &[]
                    } else {
                        // `pos <= payload.len()` holds here, so the subtraction cannot
                        // underflow, and unlike `pos + len` it cannot overflow for a
                        // hostile `len` either
                        if payload.len() - pos < len {
                            return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                        }

                        let #var_name = &payload[pos..pos + len];

                        // Content is padded to a multiple of 4 bytes
                        if len % 4 == 0 {
                            pos += len;
                        } else {
                            pos += len + (4 - len % 4);
                        }

                        #var_name
                    }
                };
            },
            WlArgType::Fd => quote! {
                if msg.fds.len() < pos_fds + 1 {
                    return crate::proto::WaylandProtocolParsingOutcome::MalformedMessage;
                }

                let #var_name: std::os::fd::BorrowedFd<'_> = std::os::fd::AsFd::as_fd(&msg.fds[pos_fds]);

                pos_fds += 1;
            },
        }
    }

    /// Generate code to be inserted into the builder function, which serializes the field
    /// `self.#var_name` of a message struct. The builder function is expected to provide
    /// `buf` (the output byte buffer, with `bytes::BufMut` in scope) and `fds` (the list
    /// that file descriptors are appended to).
    ///
    /// Integers are written in native endianness; strings are written with a trailing NUL
    /// that is included in their length prefix; string and array contents are zero-padded
    /// to a multiple of 4 bytes. The code generated for fds panics if duplicating the fd fails.
    pub fn generate_builder_code(&self, var_name: &Ident) -> proc_macro2::TokenStream {
        match self {
            WlArgType::Int => quote! {
                buf.put_i32_ne(self.#var_name);
            },
            WlArgType::Uint | WlArgType::Object | WlArgType::NewId(_) | WlArgType::Enum => quote! {
                buf.put_u32_ne(self.#var_name);
            },
            WlArgType::Fixed => quote! {
                buf.extend_from_slice(&self.#var_name.to_ne_bytes());
            },
            WlArgType::String => quote! {
                let bytes = self.#var_name.as_bytes();
                let len = bytes.len() + 1;
                buf.put_u32_ne(len as u32);
                buf.extend_from_slice(bytes);
                buf.put_u8(0);
                if len % 4 != 0 {
                    buf.put_bytes(0, (4 - len % 4));
                }
            },
            WlArgType::Array => quote! {
                buf.put_u32_ne(self.#var_name.len() as u32);
                buf.extend_from_slice(self.#var_name);
                if self.#var_name.len() % 4 != 0 {
                    buf.put_bytes(0, (4 - self.#var_name.len() % 4));
                }
            },
            WlArgType::Fd => quote! {
                fds.push(self.#var_name.try_clone_to_owned().unwrap());
            },
        }
    }
}