Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 100 additions & 21 deletions askama_derive/src/filter_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::ops::ControlFlow;

use proc_macro2::{Ident, Span, TokenStream};
use proc_macro2::{Ident, Span, TokenStream, TokenTree};
use quote::{ToTokens, format_ident, quote, quote_spanned};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
Expand Down Expand Up @@ -79,19 +79,45 @@ struct FilterArgumentOptional {
default: Expr,
}

/// Internal representation for a filter function's lifetime.
#[derive(Clone)]
struct FilterLifetime {
lifetime: Lifetime,
bounds: Punctuated<Lifetime, Token![+]>,
used_by_extra_args: bool,
}

/// Internal representation for a filter function's generic argument.
#[derive(Clone)]
struct FilterArgumentGeneric {
ident: Ident,
bounds: Punctuated<TypeParamBound, Token![+]>,
}

fn get_lifetimes(stream: TokenStream, lifetimes: &mut HashSet<Ident>) {
let mut iterator = stream.into_iter().peekable();
while let Some(token) = iterator.next() {
match token {
TokenTree::Group(g) => get_lifetimes(g.stream(), lifetimes),
TokenTree::Punct(p) if p.as_char() == '\'' => {
// Lifetimes are represented as `[Punct('), Ident("a")]` in the `TokenStream`.
if let Some(TokenTree::Ident(i)) = iterator.peek() {
lifetimes.insert(i.clone());
}
}
TokenTree::Punct(_) | TokenTree::Ident(_) | TokenTree::Literal(_) => continue,
}
}
}

/// A freestanding method annotated with `askama::filter_fn` is parsed into an instance of this
/// struct, and then the resulting code is generated from there.
/// This struct serves as an intermediate representation after some preprocessing on the raw AST.
struct FilterSignature {
/// Name of the annotated freestanding filter function
ident: Ident,
/// Lifetime bounds.
lifetimes: Vec<FilterLifetime>,
/// Name of the input variable
arg_input: FilterArgumentRequired,
/// Name of the askama environment variable
Expand Down Expand Up @@ -127,9 +153,6 @@ impl FilterSignature {
if let Some(gc_arg) = sig.generics.const_params().next() {
p_err!(gc_arg.span() => "Const generics are currently not supported for filters")?;
}
if let Some(gl_arg) = sig.generics.lifetimes().next() {
p_err!(gl_arg.span() => "Lifetime generics are currently not supported for filters")?;
}
p_assert!(
matches!(sig.output, ReturnType::Type(_, _)),
sig.paren_token.span.close() => "Filter function is missing return type"
Expand Down Expand Up @@ -161,6 +184,7 @@ impl FilterSignature {
let mut args_required = vec![];
let mut args_optional = vec![];
let mut args_required_generics = HashMap::default();
let mut lifetimes_used_in_non_required = HashSet::default();
for (arg_idx, arg) in sig.inputs.iter().skip(2).enumerate() {
let FnArg::Typed(arg) = arg else {
continue;
Expand All @@ -172,6 +196,7 @@ impl FilterSignature {
!matches!(*arg.ty, Type::ImplTrait(_)),
arg.ty.span() => "Impl generics are currently not supported for filters"
)?;
get_lifetimes(arg.to_token_stream(), &mut lifetimes_used_in_non_required);

// reference-parameters without explicit lifetime, inherit the 'filter lifetime
let arg_type = patch_ref_with_lifetime(&arg.ty, &format_ident!("filter"));
Expand Down Expand Up @@ -220,11 +245,27 @@ impl FilterSignature {
}
}
}
// lifetimes
let lifetimes = sig
.generics
.lifetimes()
.map(|lt| {
let lifetime = lt.lifetime.clone();
let bounds = lt.bounds.clone();
let used_by_extra_args = lifetimes_used_in_non_required.contains(&lifetime.ident);
FilterLifetime {
lifetime,
bounds,
used_by_extra_args,
}
})
.collect::<Vec<_>>();

// ########################################

Ok(FilterSignature {
ident: sig.ident.clone(),
lifetimes,
arg_input,
arg_input_generics,
arg_env,
Expand Down Expand Up @@ -284,6 +325,36 @@ impl FilterSignature {
// code generation
// ##############################################################################################
impl FilterSignature {
/// Returns a tuple containing two items:
///
/// 1. The list of lifetimes with their bounds.
/// 2. The list of lifetimes without their bounds.
fn lifetimes_bounds<F: Fn(&FilterLifetime) -> bool>(
&self,
filter: F,
) -> (Vec<TokenStream>, Vec<&Lifetime>) {
let mut lifetimes = Vec::with_capacity(self.lifetimes.len());
let mut lifetimes_no_bounds = Vec::with_capacity(self.lifetimes.len());
for lt in &self.lifetimes {
if !filter(lt) {
continue;
}
let name = &lt.lifetime;
let bounds = &lt.bounds;
lifetimes.push(quote! { #name: #bounds });
lifetimes_no_bounds.push(name);
}
(lifetimes, lifetimes_no_bounds)
}

fn lifetimes_fillers<F: Fn(&FilterLifetime) -> bool>(&self, filter: F) -> Vec<TokenStream> {
self.lifetimes
.iter()
.filter(|l| filter(l))
.map(|_| quote! { '_ })
.collect()
}

/// Generates a struct named after the filter function.
/// This struct will contain all the filter's arguments (except input and env).
/// The struct is basically a builder pattern for the custom filter arguments.
Expand Down Expand Up @@ -325,16 +396,18 @@ impl FilterSignature {
let required_arg_cnt = self.args_required.len();
let optional_arg_cnt = self.args_optional.len();
let arg_cnt = required_arg_cnt + optional_arg_cnt;
let lifetimes_fillers = self.lifetimes_fillers(|l| l.used_by_extra_args);
let valid_arg_impls = (0..arg_cnt).map(|idx| {
quote! {
#[diagnostic::do_not_recommend]
impl askama::filters::ValidArgIdx<#idx> for #ident<'_> {}
impl askama::filters::ValidArgIdx<#idx> for #ident<'_, #(#lifetimes_fillers,)*> {}
}
});

let (_, lifetimes) = self.lifetimes_bounds(|l| l.used_by_extra_args);
quote! {
#[allow(non_camel_case_types)]
#vis struct #ident<'filter, #(#struct_generics = (),)* #(const #required_flags : bool = false,)*> {
#vis struct #ident<'filter, #(#lifetimes,)* #(#struct_generics = (),)* #(const #required_flags : bool = false,)*> {
_lifetime: std::marker::PhantomData<&'filter ()>,
/* required fields */
#(#required_fields,)*
Expand Down Expand Up @@ -366,9 +439,10 @@ impl FilterSignature {
let value = &a.default;
quote! { #ident: #value }
});
let lifetimes_fillers = self.lifetimes_fillers(|l| l.used_by_extra_args);

quote! {
impl std::default::Default for #ident<'_> {
impl std::default::Default for #ident<'_, #(#lifetimes_fillers,)*> {
fn default() -> Self {
Self {
_lifetime: std::marker::PhantomData::default(),
Expand Down Expand Up @@ -441,6 +515,7 @@ impl FilterSignature {
quote! { #ident: #bounds }
})
.collect();
let (_, lifetimes_no_bounds) = self.lifetimes_bounds(|l| l.used_by_extra_args);
// return type
let fn_return_ty = {
let required_generics_result =
Expand All @@ -456,7 +531,7 @@ impl FilterSignature {
false => format_ident!("REQUIRED_ARG_FLAG_{}", a.idx).to_token_stream(),
}
});
quote! { #ident<'filter, #(#required_generics_result,)* #(#required_flags_result,)*> }
quote! { #ident<'filter, #(#lifetimes_no_bounds,)* #(#required_generics_result,)* #(#required_flags_result,)*> }
};
// struct fields - (all fields, except that of current argument)
let other_required_fields = self
Expand All @@ -469,8 +544,8 @@ impl FilterSignature {

quote! {
#[allow(non_camel_case_types)]
impl<'filter, #(#required_generics_impl,)* #(const #required_flags: bool,)*>
#ident<'filter, #(#required_generics_impl,)* #(#required_flags,)*> {
impl<'filter, #(#lifetimes_no_bounds,)* #(#required_generics_impl,)* #(const #required_flags: bool,)*>
#ident<'filter, #(#lifetimes_no_bounds,)* #(#required_generics_impl,)* #(#required_flags,)*> {
// named setter
#[inline(always)]
pub fn #named_ident<#(#required_generics_fn,)*>(self, new_value: #cur_arg_ty) -> #fn_return_ty {
Expand Down Expand Up @@ -530,10 +605,11 @@ impl FilterSignature {
}
});

let (_, lifetimes_no_bounds) = self.lifetimes_bounds(|l| l.used_by_extra_args);
quote! {
#[allow(non_camel_case_types)]
impl<'filter, #(#required_generics,)* #(const #required_flags: bool,)*>
#ident<'filter, #(#required_generics,)* #(#required_flags,)*> {
impl<'filter, #(#lifetimes_no_bounds,)* #(#required_generics,)* #(const #required_flags: bool,)*>
#ident<'filter, #(#lifetimes_no_bounds,)* #(#required_generics,)* #(#required_flags,)*> {
#(#optional_setters)*
}
}
Expand Down Expand Up @@ -565,6 +641,8 @@ impl FilterSignature {
let bounds = &g.bounds;
quote! { #ident: #bounds }
});
let (all_lifetimes, _) = self.lifetimes_bounds(|_| true);
let (_, type_lifetimes) = self.lifetimes_bounds(|l| l.used_by_extra_args);
// env variable
let env_ident = &self.arg_env.ident;
let env_ty = &self.arg_env.ty;
Expand Down Expand Up @@ -596,13 +674,14 @@ impl FilterSignature {
});

let impl_generics = quote! { #(#required_generics: #required_generic_bounds,)* };
let impl_struct_generics = quote! { '_, #(#required_generics,)* #(#required_flags,)* };
let impl_struct_generics = quote! { #(#required_generics,)* #(#required_flags,)* };
let lifetimes_fillers = self.lifetimes_fillers(|l| l.used_by_extra_args);
quote! {
// if all required arguments have been supplied (P0 == true, P1 == true)
// ... the execute() method is "unlocked":
impl<#impl_generics> #ident<#impl_struct_generics> {
impl<#(#all_lifetimes,)* #impl_generics> #ident<'_, #(#type_lifetimes,)* #impl_struct_generics> {
#[inline(always)]
pub fn execute<#(#input_bounds,)*>(self, #input_mutability #input_ident: #input_ty, #env_ident: #env_ty) #result_ty {
pub fn execute< #(#input_bounds,)*>(self, #input_mutability #input_ident: #input_ty, #env_ident: #env_ty) #result_ty {
// map filter variables with original name into scope
#( #required_args )*
#( #optional_args )*
Expand All @@ -611,7 +690,7 @@ impl FilterSignature {
}
}

impl<#impl_generics> askama::filters::ValidFilterInvocation for #ident<#impl_struct_generics> {}
impl<#impl_generics> askama::filters::ValidFilterInvocation for #ident<'_, #(#lifetimes_fillers,)* #impl_struct_generics> {}
}
}
}
Expand All @@ -626,12 +705,12 @@ fn filter_fn_impl(attr: TokenStream, ffn: &ItemFn) -> Result<TokenStream, Compil

let fsig = FilterSignature::try_from_signature(&ffn.sig)?;

let mut arg_generics = HashMap::default();
for gp in &ffn.sig.generics.params {
if let GenericParam::Type(gp) = gp {
arg_generics.insert(gp.ident.clone(), gp.clone());
} else {
p_err!(gp.span() => "Only type generic arguments supported for now")?;
match gp {
GenericParam::Type(_) | GenericParam::Lifetime(_) => {}
GenericParam::Const(_) => {
p_err!(gp.span() => "Const generic arguments are not supported for now")?;
}
}
}

Expand Down
28 changes: 27 additions & 1 deletion testing/tests/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,6 @@ fn test_custom_filter_constructs() {
#[test]
fn filter_arguments_mutability() {
mod filters {

// Check mutability is kept for mandatory arguments.
#[askama::filter_fn]
pub fn a(mut value: u32, _: &dyn askama::Values) -> askama::Result<String> {
Expand Down Expand Up @@ -691,3 +690,30 @@ fn filter_arguments_mutability() {

assert_eq!(X.render().unwrap(), "2 9 4");
}

// Checks support for lifetimes.
#[test]
fn filter_lifetimes() {
mod filters {
use std::borrow::Cow;

#[askama::filter_fn]
pub fn a<'a: 'b, 'b>(
value: &'a str,
_: &dyn askama::Values,
extra: &'b str,
) -> askama::Result<Cow<'a, str>> {
if extra.is_empty() {
Ok(Cow::Borrowed(value))
} else {
Ok(Cow::Owned(format!("{value}-{extra}")))
}
}
}

#[derive(Template)]
#[template(ext = "txt", source = r#"{{ "a"|a("b") }}"#)]
struct X;

assert_eq!(X.render().unwrap(), "a-b");
}
5 changes: 0 additions & 5 deletions testing/tests/ui/filter-signature-validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ mod missing_required_args {
pub fn filter2(_: &dyn askama::Values) -> askama::Result<String> {}
}

mod lifetime_args {
#[askama::filter_fn]
pub fn filter0<'a>(input: usize, _: &dyn askama::Values, arg: &'a ()) -> askama::Result<String> {}
}

mod const_generic_args {
#[askama::filter_fn]
pub fn filter0<const T: bool>(input: usize, _: &dyn askama::Values) -> askama::Result<String> {}
Expand Down
Loading