diff --git a/traversable/src/combinator.rs b/traversable/src/combinator.rs new file mode 100644 index 0000000..2476666 --- /dev/null +++ b/traversable/src/combinator.rs @@ -0,0 +1,203 @@ +// Copyright 2025 FastLabs Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Combinators for Visitors. + +use core::ops::ControlFlow; + +use crate::Visitor; +use crate::VisitorMut; + +/// Extension trait for [`Visitor`]s. +pub trait VisitorExt: Visitor { + /// Combines two visitors into one that runs them in sequence. + /// + /// If the first visitor returns [`ControlFlow::Break`], the combined visitor returns that break + /// immediately. Otherwise, it runs the second visitor. + /// + /// # Examples + /// + /// ``` + /// # #[cfg(not(feature = "derive"))] + /// # fn main() {} + /// # + /// # #[cfg(feature = "derive")] + /// # fn main() { + /// use std::ops::ControlFlow; + /// + /// use traversable::Traversable; + /// use traversable::Visitor; + /// use traversable::combinator::VisitorExt; + /// use traversable::function::make_visitor_enter; + /// + /// #[derive(Traversable)] + /// struct Foo(i32); + /// + /// #[derive(Traversable)] + /// struct Bar(i32); + /// + /// #[derive(Traversable)] + /// struct Data { + /// foo: Foo, + /// bar: Bar, + /// } + /// + /// let data = Data { + /// foo: Foo(1), + /// bar: Bar(2), + /// }; + /// + /// let v1 = make_visitor_enter(|foo: &Foo| { + /// println!("Visiting Foo: {}", foo.0); + /// ControlFlow::<()>::Continue(()) + /// }); + /// + /// let v2 = make_visitor_enter(|bar: &Bar| { + /// println!("Visiting Bar: {}", bar.0); + /// ControlFlow::<()>::Continue(()) + /// }); + /// + /// // v1 runs first, then v2. + /// let mut combined = v1.or(v2); + /// data.traverse(&mut combined); + /// # } + /// ``` + fn or(self, other: V) -> OrVisitor + where + Self: Sized, + V: Visitor, + { + OrVisitor { + visitor1: self, + visitor2: other, + } + } +} + +impl VisitorExt for V {} + +/// Extension trait for [`VisitorMut`]s. +pub trait VisitorMutExt: VisitorMut { + /// Combines two mutable visitors into one that runs them in sequence. + /// + /// If the first visitor returns [`ControlFlow::Break`], the combined visitor returns that break + /// immediately. Otherwise, it runs the second visitor. + /// + /// # Examples + /// + /// ``` + /// # #[cfg(not(feature = "derive"))] + /// # fn main() {} + /// # + /// # #[cfg(feature = "derive")] + /// # fn main() { + /// use std::ops::ControlFlow; + /// + /// use traversable::TraversableMut; + /// use traversable::VisitorMut; + /// use traversable::combinator::VisitorMutExt; + /// use traversable::function::make_visitor_enter_mut; + /// + /// #[derive(TraversableMut)] + /// struct Foo(i32); + /// + /// #[derive(TraversableMut)] + /// struct Bar(i32); + /// + /// #[derive(TraversableMut)] + /// struct Data { + /// foo: Foo, + /// bar: Bar, + /// } + /// + /// let mut data = Data { + /// foo: Foo(1), + /// bar: Bar(2), + /// }; + /// + /// let v1 = make_visitor_enter_mut(|foo: &mut Foo| { + /// foo.0 += 1; + /// ControlFlow::<()>::Continue(()) + /// }); + /// + /// let v2 = make_visitor_enter_mut(|bar: &mut Bar| { + /// bar.0 *= 2; + /// ControlFlow::<()>::Continue(()) + /// }); + /// + /// let mut combined = v1.or(v2); + /// data.traverse_mut(&mut combined); + /// + /// assert_eq!(data.foo.0, 2); + /// assert_eq!(data.bar.0, 4); + /// # } + /// ``` + fn or(self, other: V) -> OrVisitor + where + Self: Sized, + V: VisitorMut, + { + OrVisitor { + visitor1: self, + visitor2: other, + } + } +} + +impl VisitorMutExt for V {} + +/// A visitor that runs two visitors in sequence. +/// +/// This struct is created by the [`or`](VisitorExt::or) method on [`VisitorExt`] or +/// [`VisitorMutExt`]. +pub struct OrVisitor { + visitor1: V1, + visitor2: V2, +} + +impl Visitor for OrVisitor +where + V1: Visitor, + V2: Visitor, +{ + type Break = V1::Break; + + fn enter(&mut self, this: &dyn core::any::Any) -> ControlFlow { + self.visitor1.enter(this)?; + self.visitor2.enter(this) + } + + fn leave(&mut self, this: &dyn core::any::Any) -> ControlFlow { + self.visitor1.leave(this)?; + self.visitor2.leave(this) + } +} + +impl VisitorMut for OrVisitor +where + V1: VisitorMut, + V2: VisitorMut, +{ + type Break = V1::Break; + + fn enter_mut(&mut self, this: &mut dyn core::any::Any) -> ControlFlow { + self.visitor1.enter_mut(this)?; + self.visitor2.enter_mut(this) + } + + fn leave_mut(&mut self, this: &mut dyn core::any::Any) -> ControlFlow { + self.visitor1.leave_mut(this)?; + self.visitor2.leave_mut(this) + } +} diff --git a/traversable/src/lib.rs b/traversable/src/lib.rs index 9626cce..0bc2992 100644 --- a/traversable/src/lib.rs +++ b/traversable/src/lib.rs @@ -30,6 +30,7 @@ pub use traversable_derive::Traversable; /// See [`TraversableMut`]. pub use traversable_derive::TraversableMut; +pub mod combinator; pub mod function; /// Implementations for third-party library types. diff --git a/traversable/tests/test_combinator.rs b/traversable/tests/test_combinator.rs new file mode 100644 index 0000000..f53125c --- /dev/null +++ b/traversable/tests/test_combinator.rs @@ -0,0 +1,212 @@ +// Copyright 2025 FastLabs Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(all(feature = "std", feature = "derive"))] + +use std::any::Any; +use std::cell::RefCell; +use std::ops::ControlFlow; +use std::rc::Rc; + +use traversable::Traversable; +use traversable::TraversableMut; +use traversable::Visitor; +use traversable::VisitorMut; +use traversable::combinator::VisitorExt; +use traversable::combinator::VisitorMutExt; + +#[derive(Traversable, TraversableMut)] +struct Data { + val: i32, + child: Option>, +} + +#[derive(Clone)] +struct RecordVisitor { + log: Rc>>, + name: String, + break_on: Option, +} + +impl Visitor for RecordVisitor { + type Break = i32; + + fn enter(&mut self, this: &dyn Any) -> ControlFlow { + if let Some(d) = this.downcast_ref::() { + self.log + .borrow_mut() + .push(format!("{}: enter {}", self.name, d.val)); + if Some(d.val) == self.break_on { + return ControlFlow::Break(d.val); + } + } + ControlFlow::Continue(()) + } + + fn leave(&mut self, this: &dyn Any) -> ControlFlow { + if let Some(d) = this.downcast_ref::() { + self.log + .borrow_mut() + .push(format!("{}: leave {}", self.name, d.val)); + } + ControlFlow::Continue(()) + } +} + +#[test] +fn test_visitor_or() { + let data = Data { + val: 1, + child: Some(Box::new(Data { + val: 2, + child: None, + })), + }; + + let log = Rc::new(RefCell::new(Vec::new())); + let v1 = RecordVisitor { + log: log.clone(), + name: "v1".to_string(), + break_on: None, + }; + let v2 = RecordVisitor { + log: log.clone(), + name: "v2".to_string(), + break_on: None, + }; + + let mut combined = v1.or(v2); + let result = data.traverse(&mut combined); + assert!(result.is_continue()); + + let expected = vec![ + "v1: enter 1", + "v2: enter 1", + "v1: enter 2", + "v2: enter 2", + "v1: leave 2", + "v2: leave 2", + "v1: leave 1", + "v2: leave 1", + ]; + assert_eq!(*log.borrow(), expected); +} + +#[test] +fn test_visitor_or_break_v1() { + let data = Data { + val: 1, + child: Some(Box::new(Data { + val: 2, + child: None, + })), + }; + + let log = Rc::new(RefCell::new(Vec::new())); + // v1 breaks on 2 + let v1 = RecordVisitor { + log: log.clone(), + name: "v1".to_string(), + break_on: Some(2), + }; + let v2 = RecordVisitor { + log: log.clone(), + name: "v2".to_string(), + break_on: None, + }; + + let mut combined = v1.or(v2); + let result = data.traverse(&mut combined); + assert_eq!(result, ControlFlow::Break(2)); + + let expected = vec![ + "v1: enter 1", + "v2: enter 1", + "v1: enter 2", // v1 breaks here, v2 not called for 2, and traversal stops + ]; + assert_eq!(*log.borrow(), expected); +} + +#[test] +fn test_visitor_or_break_v2() { + let data = Data { + val: 1, + child: Some(Box::new(Data { + val: 2, + child: None, + })), + }; + + let log = Rc::new(RefCell::new(Vec::new())); + // v2 breaks on 2 + let v1 = RecordVisitor { + log: log.clone(), + name: "v1".to_string(), + break_on: None, + }; + let v2 = RecordVisitor { + log: log.clone(), + name: "v2".to_string(), + break_on: Some(2), + }; + + let mut combined = v1.or(v2); + let result = data.traverse(&mut combined); + assert_eq!(result, ControlFlow::Break(2)); + + let expected = vec![ + "v1: enter 1", + "v2: enter 1", + "v1: enter 2", + "v2: enter 2", // v2 breaks here + ]; + assert_eq!(*log.borrow(), expected); +} + +struct MutVisitor { + val_multiplier: i32, +} + +impl VisitorMut for MutVisitor { + type Break = (); + fn enter_mut(&mut self, this: &mut dyn Any) -> ControlFlow<()> { + if let Some(d) = this.downcast_mut::() { + d.val *= self.val_multiplier; + } + ControlFlow::Continue(()) + } +} + +#[test] +fn test_visitor_mut_or() { + let mut data = Data { + val: 1, + child: Some(Box::new(Data { + val: 2, + child: None, + })), + }; + + let v1 = MutVisitor { val_multiplier: 2 }; + let v2 = MutVisitor { val_multiplier: 3 }; + + let mut combined = v1.or(v2); + let result = data.traverse_mut(&mut combined); + assert!(result.is_continue()); + + // 1 * 2 * 3 = 6 + assert_eq!(data.val, 6); + // 2 * 2 * 3 = 12 + assert_eq!(data.child.unwrap().val, 12); +}