diff --git a/Cargo.toml b/Cargo.toml index 878eec1c..c38a1156 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ edition = "2021" match_same_arms = "warn" unused_async = "warn" uninlined_format_args = "warn" +manual_let_else = "warn" [workspace.lints.rust] unreachable_pub = "warn" -manual_let_else = "warn" \ No newline at end of file diff --git a/crates/orchestrator/src/api/routes/storage.rs b/crates/orchestrator/src/api/routes/storage.rs index ce56e667..57a30748 100644 --- a/crates/orchestrator/src/api/routes/storage.rs +++ b/crates/orchestrator/src/api/routes/storage.rs @@ -325,7 +325,6 @@ mod tests { use std::sync::Arc; use super::*; - use crate::plugins::StatusUpdatePlugin; use crate::{ api::tests::helper::{create_test_app_state, create_test_app_state_with_nodegroups}, models::node::{NodeStatus, OrchestratorNode}, @@ -575,12 +574,14 @@ mod tests { None, None, )); - let _ = plugin.clone().register_observer().await; - - let _ = plugin - .handle_status_change(&node, &NodeStatus::Healthy) + let _ = app_state + .store_context + .task_store + .add_observer(plugin.clone()) .await; + let _ = plugin.handle_status_change(&node).await; + let task = Task { id: Uuid::new_v4(), image: "test-image".to_string(), diff --git a/crates/orchestrator/src/api/routes/task.rs b/crates/orchestrator/src/api/routes/task.rs index 150ac2be..7cff4b6d 100644 --- a/crates/orchestrator/src/api/routes/task.rs +++ b/crates/orchestrator/src/api/routes/task.rs @@ -1,4 +1,5 @@ use crate::api::server::AppState; +use crate::plugins::node_groups::get_task_topologies; use actix_web::{ web::{self, delete, get, post, Data}, HttpResponse, Scope, @@ -64,8 +65,8 @@ async fn create_task(task: web::Json, app_state: Data) -> } }; - if let Some(group_plugin) = &app_state.node_groups_plugin { - match group_plugin.get_task_topologies(&task) { + if app_state.node_groups_plugin.is_some() { + match get_task_topologies(&task) { Ok(topology) => { if topology.is_empty() { return HttpResponse::BadRequest().json(json!({"success": false, "error": "No topology found for task but grouping plugin is active."})); diff --git a/crates/orchestrator/src/discovery/monitor.rs b/crates/orchestrator/src/discovery/monitor.rs index cbf868aa..5462f994 100644 --- a/crates/orchestrator/src/discovery/monitor.rs +++ b/crates/orchestrator/src/discovery/monitor.rs @@ -25,7 +25,7 @@ pub struct DiscoveryMonitor { heartbeats: Arc, http_client: reqwest::Client, max_healthy_nodes_with_same_endpoint: u32, - status_change_handlers: Vec>, + status_change_handlers: Vec, } impl DiscoveryMonitor { @@ -38,7 +38,7 @@ impl DiscoveryMonitor { store_context: Arc, heartbeats: Arc, max_healthy_nodes_with_same_endpoint: u32, - status_change_handlers: Vec>, + status_change_handlers: Vec, ) -> Self { Self { coordinator_wallet, diff --git a/crates/orchestrator/src/events/mod.rs b/crates/orchestrator/src/events/mod.rs deleted file mode 100644 index e3ae2175..00000000 --- a/crates/orchestrator/src/events/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -use anyhow::Result; -use shared::models::task::Task; - -pub trait TaskObserver: Send + Sync { - fn on_task_created(&self, task: &Task) -> Result<()>; - fn on_task_deleted(&self, task: Option) -> Result<()>; -} diff --git a/crates/orchestrator/src/lib.rs b/crates/orchestrator/src/lib.rs index 1efd8f48..5f82d58d 100644 --- a/crates/orchestrator/src/lib.rs +++ b/crates/orchestrator/src/lib.rs @@ -1,6 +1,5 @@ mod api; mod discovery; -mod events; mod metrics; mod models; mod node; diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs index 080347ea..f9beaccb 100644 --- a/crates/orchestrator/src/main.rs +++ b/crates/orchestrator/src/main.rs @@ -154,8 +154,8 @@ async fn main() -> Result<()> { .unwrap(); let group_store_context = store_context.clone(); - let mut scheduler_plugins: Vec> = Vec::new(); - let mut status_update_plugins: Vec> = vec![]; + let mut scheduler_plugins: Vec = Vec::new(); + let mut status_update_plugins: Vec = vec![]; let mut node_groups_plugin: Option> = None; let mut webhook_plugins: Vec = vec![]; @@ -167,7 +167,7 @@ async fn main() -> Result<()> { let plugin = WebhookPlugin::new(config); let plugin_clone = plugin.clone(); webhook_plugins.push(plugin_clone); - status_update_plugins.push(Box::new(plugin)); + status_update_plugins.push(plugin.into()); info!("Plugin: Webhook plugin initialized"); } } @@ -201,26 +201,27 @@ async fn main() -> Result<()> { match serde_json::from_str::>(&node_group_configs) { Ok(configs) if !configs.is_empty() => { let node_groups_heartbeats = heartbeats.clone(); - let group_plugin = NodeGroupsPlugin::new( + + let group_plugin = Arc::new(NodeGroupsPlugin::new( configs, store.clone(), group_store_context.clone(), Some(node_groups_heartbeats.clone()), Some(webhook_plugins.clone()), - ); + )); + + // Register the plugin as a task observer + group_store_context + .task_store + .add_observer(group_plugin.clone()) + .await; let status_group_plugin = group_plugin.clone(); let group_plugin_for_server = group_plugin.clone(); - let group_plugin_arc = Arc::new(group_plugin_for_server); - - // Register the plugin as a task observer - if let Err(e) = group_plugin_arc.clone().register_observer().await { - error!("Failed to register node groups plugin as observer: {e}"); - } - node_groups_plugin = Some(group_plugin_arc); - scheduler_plugins.push(Box::new(group_plugin)); - status_update_plugins.push(Box::new(status_group_plugin)); + node_groups_plugin = Some(group_plugin_for_server); + scheduler_plugins.push(group_plugin.into()); + status_update_plugins.push(status_group_plugin.into()); info!("Plugin: Node group plugin initialized"); } Ok(_) => { @@ -263,16 +264,16 @@ async fn main() -> Result<()> { } // Create status_update_plugins for discovery monitor - let mut discovery_status_update_plugins: Vec> = vec![]; + let mut discovery_status_update_plugins: Vec = vec![]; // Add webhook plugins to discovery status update plugins for plugin in &webhook_plugins { - discovery_status_update_plugins.push(Box::new(plugin.clone())); + discovery_status_update_plugins.push(plugin.into()); } // Add node groups plugin if available if let Some(group_plugin) = node_groups_plugin.clone() { - discovery_status_update_plugins.push(Box::new(group_plugin.as_ref().clone())); + discovery_status_update_plugins.push(group_plugin.into()); } let discovery_store_context = store_context.clone(); @@ -316,16 +317,16 @@ async fn main() -> Result<()> { }); // Create status_update_plugins for status updater - let mut status_updater_plugins: Vec> = vec![]; + let mut status_updater_plugins: Vec = vec![]; // Add webhook plugins to status updater plugins for plugin in &webhook_plugins { - status_updater_plugins.push(Box::new(plugin.clone())); + status_updater_plugins.push(plugin.into()); } // Add node groups plugin if available if let Some(group_plugin) = node_groups_plugin.clone() { - status_updater_plugins.push(Box::new(group_plugin.as_ref().clone())); + status_updater_plugins.push(group_plugin.into()); } let status_update_store_context = store_context.clone(); diff --git a/crates/orchestrator/src/plugins/mod.rs b/crates/orchestrator/src/plugins/mod.rs index 256de0cb..f2e086d0 100644 --- a/crates/orchestrator/src/plugins/mod.rs +++ b/crates/orchestrator/src/plugins/mod.rs @@ -1,8 +1,103 @@ -mod traits; -pub use traits::*; +use crate::plugins::newest_task::NewestTaskPlugin; +use alloy::primitives::Address; +use anyhow::Result; +use shared::models::task::Task; +use std::sync::Arc; -pub(crate) mod node_groups; +use crate::{ + models::node::{NodeStatus, OrchestratorNode}, + plugins::node_groups::NodeGroupsPlugin, + plugins::webhook::WebhookPlugin, +}; pub(crate) mod newest_task; - +pub(crate) mod node_groups; pub(crate) mod webhook; + +#[derive(Clone)] +pub enum StatusUpdatePlugin { + NodeGroupsPlugin(Arc), + WebhookPlugin(WebhookPlugin), +} + +impl StatusUpdatePlugin { + pub(crate) async fn handle_status_change( + &self, + node: &OrchestratorNode, + status: &NodeStatus, + ) -> Result<()> { + match self { + StatusUpdatePlugin::NodeGroupsPlugin(plugin) => plugin.handle_status_change(node).await, + StatusUpdatePlugin::WebhookPlugin(plugin) => plugin.handle_status_change(node, status), + } + } +} + +impl From> for StatusUpdatePlugin { + fn from(plugin: Arc) -> Self { + StatusUpdatePlugin::NodeGroupsPlugin(plugin) + } +} + +impl From<&Arc> for StatusUpdatePlugin { + fn from(plugin: &Arc) -> Self { + StatusUpdatePlugin::NodeGroupsPlugin(plugin.clone()) + } +} + +impl From for StatusUpdatePlugin { + fn from(plugin: WebhookPlugin) -> Self { + StatusUpdatePlugin::WebhookPlugin(plugin) + } +} + +impl From<&WebhookPlugin> for StatusUpdatePlugin { + fn from(plugin: &WebhookPlugin) -> Self { + StatusUpdatePlugin::WebhookPlugin(plugin.clone()) + } +} + +#[derive(Clone)] +pub enum SchedulerPlugin { + NodeGroupsPlugin(Arc), + NewestTaskPlugin(NewestTaskPlugin), +} + +impl SchedulerPlugin { + pub(crate) async fn filter_tasks( + &self, + tasks: &[Task], + node_address: &Address, + ) -> Result> { + match self { + SchedulerPlugin::NodeGroupsPlugin(plugin) => { + plugin.filter_tasks(tasks, node_address).await + } + SchedulerPlugin::NewestTaskPlugin(plugin) => plugin.filter_tasks(tasks), + } + } +} + +impl From> for SchedulerPlugin { + fn from(plugin: Arc) -> Self { + SchedulerPlugin::NodeGroupsPlugin(plugin) + } +} + +impl From<&Arc> for SchedulerPlugin { + fn from(plugin: &Arc) -> Self { + SchedulerPlugin::NodeGroupsPlugin(plugin.clone()) + } +} + +impl From for SchedulerPlugin { + fn from(plugin: NewestTaskPlugin) -> Self { + SchedulerPlugin::NewestTaskPlugin(plugin) + } +} + +impl From<&NewestTaskPlugin> for SchedulerPlugin { + fn from(plugin: &NewestTaskPlugin) -> Self { + SchedulerPlugin::NewestTaskPlugin(plugin.clone()) + } +} diff --git a/crates/orchestrator/src/plugins/newest_task/mod.rs b/crates/orchestrator/src/plugins/newest_task/mod.rs index 4a7d2d0c..9cb40662 100644 --- a/crates/orchestrator/src/plugins/newest_task/mod.rs +++ b/crates/orchestrator/src/plugins/newest_task/mod.rs @@ -1,17 +1,11 @@ -use alloy::primitives::Address; use anyhow::Result; -use async_trait::async_trait; use shared::models::task::Task; -use super::{Plugin, SchedulerPlugin}; +#[derive(Clone)] +pub struct NewestTaskPlugin; -pub(crate) struct NewestTaskPlugin; - -impl Plugin for NewestTaskPlugin {} - -#[async_trait] -impl SchedulerPlugin for NewestTaskPlugin { - async fn filter_tasks(&self, tasks: &[Task], _node_address: &Address) -> Result> { +impl NewestTaskPlugin { + pub(crate) fn filter_tasks(&self, tasks: &[Task]) -> Result> { if tasks.is_empty() { return Ok(vec![]); } @@ -32,8 +26,8 @@ mod tests { use super::*; - #[tokio::test] - async fn test_filter_tasks() { + #[test] + fn test_filter_tasks() { let plugin = NewestTaskPlugin; let tasks = vec![ Task { @@ -54,7 +48,7 @@ mod tests { }, ]; - let filtered_tasks = plugin.filter_tasks(&tasks, &Address::ZERO).await.unwrap(); + let filtered_tasks = plugin.filter_tasks(&tasks).unwrap(); assert_eq!(filtered_tasks.len(), 1); assert_eq!(filtered_tasks[0].id, tasks[1].id); } diff --git a/crates/orchestrator/src/plugins/node_groups/mod.rs b/crates/orchestrator/src/plugins/node_groups/mod.rs index be99ef63..205f1576 100644 --- a/crates/orchestrator/src/plugins/node_groups/mod.rs +++ b/crates/orchestrator/src/plugins/node_groups/mod.rs @@ -1,6 +1,4 @@ use super::webhook::WebhookPlugin; -use super::{Plugin, SchedulerPlugin}; -use crate::events::TaskObserver; use crate::models::node::{NodeStatus, OrchestratorNode}; use crate::store::core::{RedisStore, StoreContext}; use crate::utils::loop_heartbeats::LoopHeartbeats; @@ -99,7 +97,6 @@ impl Default for TaskSwitchingPolicy { } } -#[derive(Clone)] pub struct NodeGroupsPlugin { configuration_templates: Vec, store: Arc, @@ -177,13 +174,32 @@ impl NodeGroupsPlugin { } } - /// Register this plugin as a task observer (async) - pub async fn register_observer(self: Arc) -> Result<()> { - self.store_context - .task_store - .add_observer(self.clone()) - .await; - Ok(()) + // TODO: this should consume self; refactor this to separate running logic + // and other components. it appears quite a lot of different logic is + // combined into this one type + pub async fn run_group_management_loop(&self, duration: u64) -> Result<(), Error> { + let mut interval = tokio::time::interval(Duration::from_secs(duration)); + + loop { + let start = std::time::Instant::now(); + interval.tick().await; + + // First, form new groups with optimal sizing + if let Err(e) = self.try_form_new_groups().await { + error!("Error in group formation: {e}"); + } + + if let Err(e) = self.try_merge_solo_groups().await { + error!("Error in group merging: {e}"); + } + + if let Some(heartbeats) = &self.node_groups_heartbeats { + heartbeats.update_node_groups(); + } + + let elapsed = start.elapsed(); + log::info!("Group management loop completed in {elapsed:?}"); + } } /// Check if a node is compatible with a configuration's compute requirements @@ -289,7 +305,7 @@ impl NodeGroupsPlugin { pipe.atomic(); // Store group data - let group_key = Self::get_group_key(&group.id); + let group_key = get_group_key(&group.id); let group_data = serde_json::to_string(group)?; pipe.set(&group_key, group_data); @@ -305,22 +321,12 @@ impl NodeGroupsPlugin { Ok(()) } - fn generate_group_id() -> String { - use rand::Rng; - let mut rng = rand::rng(); - format!("{:x}", rng.random::()) - } - - fn get_group_key(group_id: &str) -> String { - format!("{GROUP_KEY_PREFIX}{group_id}") - } - pub async fn get_node_group(&self, node_addr: &str) -> Result, Error> { let mut conn = self.store.client.get_multiplexed_async_connection().await?; let group_id: Option = conn.hget(NODE_GROUP_MAP_KEY, node_addr).await?; if let Some(group_id) = group_id { - let group_key = Self::get_group_key(&group_id); + let group_key = get_group_key(&group_id); let group_data: Option = conn.get(&group_key).await?; if let Some(group_data) = group_data { return Ok(Some(serde_json::from_str(&group_data)?)); @@ -357,7 +363,7 @@ impl NodeGroupsPlugin { let group_data: HashMap = if !unique_group_ids.is_empty() { let group_keys: Vec = unique_group_ids .iter() - .map(|id| Self::get_group_key(id)) + .map(|id| get_group_key(id)) .collect(); let group_values: Vec> = conn.mget(&group_keys).await?; @@ -415,20 +421,6 @@ impl NodeGroupsPlugin { self.configuration_templates.clone() } - pub async fn enable_configuration(&self, configuration_name: &str) -> Result<(), Error> { - let mut conn = self.store.client.get_multiplexed_async_connection().await?; - conn.sadd::<_, _, ()>("available_node_group_configs", configuration_name) - .await?; - Ok(()) - } - - pub async fn disable_configuration(&self, configuration_name: &str) -> Result<(), Error> { - let mut conn = self.store.client.get_multiplexed_async_connection().await?; - conn.srem::<_, _, ()>("available_node_group_configs", configuration_name) - .await?; - Ok(()) - } - pub fn get_idx_in_group( &self, node_group: &NodeGroup, @@ -574,7 +566,7 @@ impl NodeGroupsPlugin { } // Create new group - let group_id = Self::generate_group_id(); + let group_id = generate_group_id(); debug!("Generating new group with ID: {group_id}"); let group = NodeGroup { @@ -892,7 +884,7 @@ impl NodeGroupsPlugin { groups_to_merge.iter().map(|g| g.id.clone()).collect(); // Create new merged group - let new_group_id = Self::generate_group_id(); + let new_group_id = generate_group_id(); let merged_group = NodeGroup { id: new_group_id.clone(), nodes: merged_nodes.clone(), @@ -910,7 +902,7 @@ impl NodeGroupsPlugin { // Dissolve old groups for group_id in &group_ids_to_dissolve { // Get group data for webhook notifications (done before deletion) - let group_key = Self::get_group_key(group_id); + let group_key = get_group_key(group_id); // Remove nodes from group mapping if let Some(old_group) = groups_to_merge.iter().find(|g| &g.id == group_id) { @@ -931,7 +923,7 @@ impl NodeGroupsPlugin { } // Create new merged group - let group_key = Self::get_group_key(&new_group_id); + let group_key = get_group_key(&new_group_id); let group_data = serde_json::to_string(&merged_group)?; pipe.set(&group_key, group_data); @@ -1007,123 +999,11 @@ impl NodeGroupsPlugin { } } - pub async fn run_group_management_loop(&self, duration: u64) -> Result<(), Error> { - let mut interval = tokio::time::interval(Duration::from_secs(duration)); - - loop { - let start = std::time::Instant::now(); - interval.tick().await; - - // First, form new groups with optimal sizing - if let Err(e) = self.try_form_new_groups().await { - error!("Error in group formation: {e}"); - } - - if let Err(e) = self.try_merge_solo_groups().await { - error!("Error in group merging: {e}"); - } - - if let Some(heartbeats) = &self.node_groups_heartbeats { - heartbeats.update_node_groups(); - } - - let elapsed = start.elapsed(); - log::info!("Group management loop completed in {elapsed:?}"); - } + pub(crate) async fn dissolve_group(&self, group_id: &str) -> Result<(), Error> { + dissolve_group(group_id, &self.store, &self.webhook_plugins).await } - pub async fn dissolve_group(&self, group_id: &str) -> Result<(), Error> { - debug!("Attempting to dissolve group: {group_id}"); - let mut conn = self.store.client.get_multiplexed_async_connection().await?; - - let group_key = Self::get_group_key(group_id); - let group_data: Option = conn.get(&group_key).await?; - - if let Some(group_data) = group_data { - let group: NodeGroup = serde_json::from_str(&group_data)?; - debug!("Found group to dissolve: {group:?}"); - - // Use a Redis transaction to atomically dissolve the group - let mut pipe = redis::pipe(); - pipe.atomic(); - - // Remove all nodes from the group mapping - debug!("Removing {} nodes from group mapping", group.nodes.len()); - for node in &group.nodes { - pipe.hdel(NODE_GROUP_MAP_KEY, node); - } - - // Remove group ID from groups index - pipe.srem(GROUPS_INDEX_KEY, group_id); - - // Delete group task assignment - let task_key = format!("{GROUP_TASK_KEY_PREFIX}{group_id}"); - debug!("Deleting group task assignment from key: {task_key}"); - pipe.del(&task_key); - - // Delete group - debug!("Deleting group data from key: {group_key}"); - pipe.del(&group_key); - - // Execute all operations atomically - pipe.query_async::<()>(&mut conn).await?; - - info!( - "Dissolved group {} with {} nodes", - group_id, - group.nodes.len() - ); - if let Some(plugins) = &self.webhook_plugins { - for plugin in plugins.iter() { - if let Err(e) = plugin.send_group_destroyed( - group.id.clone(), - group.configuration_name.clone(), - group.nodes.iter().cloned().collect(), - ) { - error!("Failed to send group dissolved webhook: {e}"); - } - } - } - } else { - debug!("No group found with ID: {group_id}"); - } - - Ok(()) - } - - pub fn get_task_topologies(&self, task: &Task) -> Result, Error> { - debug!("Getting topologies for task: {task:?}"); - if let Some(config) = &task.scheduling_config { - if let Some(plugins) = &config.plugins { - if let Some(node_groups) = plugins.get("node_groups") { - if let Some(allowed_topologies) = node_groups.get("allowed_topologies") { - debug!("Found allowed topologies: {allowed_topologies:?}"); - return Ok(allowed_topologies.iter().map(|t| t.to_string()).collect()); - } - } - } - } - debug!("No topologies found for task"); - Ok(vec![]) - } - - pub async fn get_all_tasks_for_topology(&self, topology: &str) -> Result, Error> { - debug!("Getting all tasks for topology: {topology}"); - let all_tasks = self.store_context.task_store.get_all_tasks().await?; - debug!("Found {} total tasks to check", all_tasks.len()); - - let mut tasks = Vec::new(); - for task in all_tasks { - let topologies = self.get_task_topologies(&task)?; - if topologies.contains(&topology.to_string()) { - tasks.push(task); - } - } - debug!("Found {} tasks for topology {}", tasks.len(), topology); - Ok(tasks) - } - - pub async fn get_all_groups(&self) -> Result, Error> { + pub(crate) async fn get_all_groups(&self) -> Result, Error> { debug!("Getting all groups"); let mut conn = self.store.client.get_multiplexed_async_connection().await?; @@ -1138,7 +1018,7 @@ impl NodeGroupsPlugin { debug!("Found {} group IDs in index", group_ids.len()); // Use MGET to batch fetch all group data - let group_keys: Vec = group_ids.iter().map(|id| Self::get_group_key(id)).collect(); + let group_keys: Vec = group_ids.iter().map(|id| get_group_key(id)).collect(); let group_values: Vec> = conn.mget(&group_keys).await?; @@ -1163,9 +1043,9 @@ impl NodeGroupsPlugin { Ok(groups) } - pub async fn get_group_by_id(&self, group_id: &str) -> Result, Error> { + pub(crate) async fn get_group_by_id(&self, group_id: &str) -> Result, Error> { let mut conn = self.store.client.get_multiplexed_async_connection().await?; - let group_key = Self::get_group_key(group_id); + let group_key = get_group_key(group_id); if let Some(group_data) = conn.get::<_, Option>(&group_key).await? { Ok(Some(serde_json::from_str(&group_data)?)) @@ -1174,61 +1054,27 @@ impl NodeGroupsPlugin { } } - pub async fn get_all_node_group_mappings(&self) -> Result, Error> { + pub(crate) async fn get_all_node_group_mappings( + &self, + ) -> Result, Error> { let mut conn = self.store.client.get_multiplexed_async_connection().await?; let mappings: HashMap = conn.hgetall(NODE_GROUP_MAP_KEY).await?; Ok(mappings) } - /// Get all groups assigned to a specific task - /// Returns a list of group IDs that are currently working on the given task - pub async fn get_groups_for_task(&self, task_id: &str) -> Result, Error> { - debug!("Getting all groups for task: {task_id}"); - let mut conn = self.store.client.get_multiplexed_async_connection().await?; - - // First, collect all group_task keys - let pattern = format!("{GROUP_TASK_KEY_PREFIX}*"); - let mut iter: redis::AsyncIter = conn.scan_match(&pattern).await?; - let mut all_keys = Vec::new(); - - while let Some(key) = iter.next_item().await { - all_keys.push(key); - } - - // Drop the iterator to release the borrow on conn - drop(iter); - - // Now check which keys point to our task_id - let mut group_ids = Vec::new(); - for key in all_keys { - if let Some(stored_task_id) = conn.get::<_, Option>(&key).await? { - if stored_task_id == task_id { - // Extract group_id from the key (remove the prefix) - if let Some(group_id) = key.strip_prefix(GROUP_TASK_KEY_PREFIX) { - group_ids.push(group_id.to_string()); - } - } - } - } - - debug!( - "Found {} groups for task {}: {:?}", - group_ids.len(), - task_id, - group_ids - ); - Ok(group_ids) - } - /// Validate that a group still exists before task assignment (for scheduler integration) - pub async fn validate_group_exists(&self, group_id: &str) -> Result { + pub(crate) async fn validate_group_exists(&self, group_id: &str) -> Result { let group = self.get_group_by_id(group_id).await?; Ok(group.is_some()) } /// Handle the case where a group was dissolved while processing - for scheduler integration - pub async fn handle_group_not_found(&self, group_id: &str, task_id: &str) -> Result<(), Error> { + pub(crate) async fn handle_group_not_found( + &self, + group_id: &str, + task_id: &str, + ) -> Result<(), Error> { warn!( "Group {group_id} not found during task assignment for task {task_id}, attempting recovery" ); @@ -1343,17 +1189,17 @@ impl NodeGroupsPlugin { } #[cfg(test)] - pub async fn test_try_form_new_groups(&self) -> Result, Error> { + pub(crate) async fn test_try_form_new_groups(&self) -> Result, Error> { self.try_form_new_groups().await } #[cfg(test)] - pub async fn test_try_merge_solo_groups(&self) -> Result, Error> { + pub(crate) async fn test_try_merge_solo_groups(&self) -> Result, Error> { self.try_merge_solo_groups().await } #[cfg(test)] - pub async fn test_should_switch_tasks( + pub(crate) async fn test_should_switch_tasks( &self, current_groups: &[NodeGroup], potential_merged_size: usize, @@ -1363,7 +1209,7 @@ impl NodeGroupsPlugin { } #[cfg(test)] - pub async fn test_find_best_task_for_group( + pub(crate) async fn test_find_best_task_for_group( &self, group: &NodeGroup, ) -> Result, Error> { @@ -1371,26 +1217,22 @@ impl NodeGroupsPlugin { } #[cfg(test)] - pub fn test_get_task_switching_policy(&self) -> &TaskSwitchingPolicy { + pub(crate) fn test_get_task_switching_policy(&self) -> &TaskSwitchingPolicy { &self.task_switching_policy } -} -impl Plugin for NodeGroupsPlugin {} - -impl TaskObserver for NodeGroupsPlugin { - fn on_task_created(&self, task: &Task) -> Result<()> { + pub(crate) fn on_task_created(&self, task: &Task) -> Result<()> { debug!("Task created event received: {task:?}"); - let topologies = self.get_task_topologies(task)?; + let topologies = get_task_topologies(task)?; debug!("Found {} topologies for new task", topologies.len()); for topology in topologies { debug!("Enabling configuration for topology: {topology}"); + let store = self.store.clone(); tokio::spawn({ - let plugin = self.clone(); let topology = topology.clone(); async move { - if let Err(e) = plugin.enable_configuration(&topology).await { + if let Err(e) = enable_configuration(&store, &topology).await { error!("Failed to enable configuration for topology {topology}: {e}"); } } @@ -1400,21 +1242,23 @@ impl TaskObserver for NodeGroupsPlugin { Ok(()) } - fn on_task_deleted(&self, task: Option) -> Result<()> { + pub(crate) fn on_task_deleted(&self, task: Option) -> Result<()> { if let Some(task) = task { debug!("Task deleted event received: {task:?}"); let task_id = task.id.to_string(); - let topologies = self.get_task_topologies(&task)?; + let topologies = get_task_topologies(&task)?; debug!("Found {} topologies for task cleanup", topologies.len()); + let store = self.store.clone(); + let store_context = self.store_context.clone(); + let webhook_plugins = self.webhook_plugins.clone(); tokio::spawn({ - let plugin = self.clone(); let task_id = task_id.clone(); let topologies = topologies.clone(); async move { // Immediately dissolve all groups assigned to this specific task debug!("Dissolving groups for deleted task: {task_id}"); - let groups_for_task = match plugin.get_groups_for_task(&task_id).await { + let groups_for_task = match get_groups_for_task(store.clone(), &task_id).await { Ok(groups) => groups, Err(e) => { error!("Failed to get groups for task {task_id}: {e}"); @@ -1431,7 +1275,8 @@ impl TaskObserver for NodeGroupsPlugin { for group_id in &groups_for_task { debug!("Dissolving group {group_id} for deleted task {task_id}"); - if let Err(e) = plugin.dissolve_group(group_id).await { + if let Err(e) = dissolve_group(group_id, &store, &webhook_plugins).await + { error!( "Failed to dissolve group {group_id} for task {task_id}: {e}" ); @@ -1449,20 +1294,24 @@ impl TaskObserver for NodeGroupsPlugin { // This is secondary to the immediate group dissolution above for topology in topologies { debug!("Checking topology {topology} for configuration cleanup"); - let remaining_tasks = - match plugin.get_all_tasks_for_topology(&topology).await { - Ok(tasks) => tasks, - Err(e) => { - error!("Failed to get tasks for topology {topology}: {e}"); - continue; - } - }; + let remaining_tasks = match get_all_tasks_for_topology( + store_context.clone(), + &topology, + ) + .await + { + Ok(tasks) => tasks, + Err(e) => { + error!("Failed to get tasks for topology {topology}: {e}"); + continue; + } + }; if remaining_tasks.is_empty() { debug!( "No tasks remaining for topology {topology}, disabling configuration" ); - if let Err(e) = plugin.disable_configuration(&topology).await { + if let Err(e) = disable_configuration(&store, &topology).await { error!( "Failed to disable configuration for topology {topology}: {e}" ); @@ -1475,3 +1324,174 @@ impl TaskObserver for NodeGroupsPlugin { Ok(()) } } + +pub(crate) async fn enable_configuration( + store: &Arc, + configuration_name: &str, +) -> Result<(), Error> { + let mut conn = store.client.get_multiplexed_async_connection().await?; + conn.sadd::<_, _, ()>("available_node_group_configs", configuration_name) + .await?; + Ok(()) +} + +pub(crate) async fn disable_configuration( + store: &Arc, + configuration_name: &str, +) -> Result<(), Error> { + let mut conn = store.client.get_multiplexed_async_connection().await?; + conn.srem::<_, _, ()>("available_node_group_configs", configuration_name) + .await?; + Ok(()) +} + +/// Get all groups assigned to a specific task +/// Returns a list of group IDs that are currently working on the given task +async fn get_groups_for_task(store: Arc, task_id: &str) -> Result, Error> { + debug!("Getting all groups for task: {task_id}"); + let mut conn = store.client.get_multiplexed_async_connection().await?; + + // First, collect all group_task keys + let pattern = format!("{GROUP_TASK_KEY_PREFIX}*"); + let mut iter: redis::AsyncIter = conn.scan_match(&pattern).await?; + let mut all_keys = Vec::new(); + + while let Some(key) = iter.next_item().await { + all_keys.push(key); + } + + // Drop the iterator to release the borrow on conn + drop(iter); + + // Now check which keys point to our task_id + let mut group_ids = Vec::new(); + for key in all_keys { + if let Some(stored_task_id) = conn.get::<_, Option>(&key).await? { + if stored_task_id == task_id { + // Extract group_id from the key (remove the prefix) + if let Some(group_id) = key.strip_prefix(GROUP_TASK_KEY_PREFIX) { + group_ids.push(group_id.to_string()); + } + } + } + } + + debug!( + "Found {} groups for task {}: {:?}", + group_ids.len(), + task_id, + group_ids + ); + Ok(group_ids) +} + +async fn get_all_tasks_for_topology( + store_context: Arc, + topology: &str, +) -> Result, Error> { + debug!("Getting all tasks for topology: {topology}"); + let all_tasks = store_context.task_store.get_all_tasks().await?; + debug!("Found {} total tasks to check", all_tasks.len()); + + let mut tasks = Vec::new(); + for task in all_tasks { + let topologies = get_task_topologies(&task)?; + if topologies.contains(&topology.to_string()) { + tasks.push(task); + } + } + debug!("Found {} tasks for topology {}", tasks.len(), topology); + Ok(tasks) +} + +pub(crate) fn get_task_topologies(task: &Task) -> Result, Error> { + debug!("Getting topologies for task: {task:?}"); + if let Some(config) = &task.scheduling_config { + if let Some(plugins) = &config.plugins { + if let Some(node_groups) = plugins.get("node_groups") { + if let Some(allowed_topologies) = node_groups.get("allowed_topologies") { + debug!("Found allowed topologies: {allowed_topologies:?}"); + return Ok(allowed_topologies.iter().map(|t| t.to_string()).collect()); + } + } + } + } + debug!("No topologies found for task"); + Ok(vec![]) +} + +async fn dissolve_group( + group_id: &str, + store: &Arc, + webhook_plugins: &Option>, +) -> Result<(), Error> { + debug!("Attempting to dissolve group: {group_id}"); + let mut conn = store.client.get_multiplexed_async_connection().await?; + + let group_key = get_group_key(group_id); + let group_data: Option = conn.get(&group_key).await?; + + if let Some(group_data) = group_data { + let group: NodeGroup = serde_json::from_str(&group_data)?; + debug!("Found group to dissolve: {group:?}"); + + // Use a Redis transaction to atomically dissolve the group + let mut pipe = redis::pipe(); + pipe.atomic(); + + // Remove all nodes from the group mapping + debug!("Removing {} nodes from group mapping", group.nodes.len()); + for node in &group.nodes { + pipe.hdel(NODE_GROUP_MAP_KEY, node); + } + + // Remove group ID from groups index + pipe.srem(GROUPS_INDEX_KEY, group_id); + + // Delete group task assignment + let task_key = format!("{GROUP_TASK_KEY_PREFIX}{group_id}"); + debug!("Deleting group task assignment from key: {task_key}"); + pipe.del(&task_key); + + // Delete group + debug!("Deleting group data from key: {group_key}"); + pipe.del(&group_key); + + // Execute all operations atomically + pipe.query_async::<()>(&mut conn).await?; + + info!( + "Dissolved group {} with {} nodes", + group_id, + group.nodes.len() + ); + + if let Some(plugins) = webhook_plugins { + for plugin in plugins.iter() { + let plugin_clone = plugin.clone(); + let group_clone = group.clone(); + if let Err(e) = plugin_clone.send_group_destroyed( + group_clone.id.to_string(), + group_clone.configuration_name.to_string(), + group_clone.nodes.iter().cloned().collect(), + ) { + error!("Failed to send group dissolved webhook: {e}"); + } + } + } + } else { + debug!("No group found with ID: {group_id}"); + } + + Ok(()) +} + +fn generate_group_id() -> String { + use rand::Rng; + let mut rng = rand::rng(); + format!("{:x}", rng.random::()) +} + +fn get_group_key(group_id: &str) -> String { + format!("{GROUP_KEY_PREFIX}{group_id}") +} diff --git a/crates/orchestrator/src/plugins/node_groups/scheduler_impl.rs b/crates/orchestrator/src/plugins/node_groups/scheduler_impl.rs index dcb302ac..3aaa27c2 100644 --- a/crates/orchestrator/src/plugins/node_groups/scheduler_impl.rs +++ b/crates/orchestrator/src/plugins/node_groups/scheduler_impl.rs @@ -1,16 +1,18 @@ -use super::{NodeGroupsPlugin, SchedulerPlugin}; +use super::NodeGroupsPlugin; use alloy::primitives::Address; use anyhow::Result; -use async_trait::async_trait; use log::{error, info}; use rand::seq::IteratorRandom; use redis::AsyncCommands; use shared::models::task::Task; use std::str::FromStr; -#[async_trait] -impl SchedulerPlugin for NodeGroupsPlugin { - async fn filter_tasks(&self, tasks: &[Task], node_address: &Address) -> Result> { +impl NodeGroupsPlugin { + pub(crate) async fn filter_tasks( + &self, + tasks: &[Task], + node_address: &Address, + ) -> Result> { if let Ok(Some(group)) = self.get_node_group(&node_address.to_string()).await { info!( "Node {} is in group {} with {} nodes", @@ -203,10 +205,7 @@ impl SchedulerPlugin for NodeGroupsPlugin { return Ok(vec![task_clone]); } } - info!( - "Node {} is not in a group, skipping all tasks", - node_address - ); + info!("Node {node_address} is not in a group, skipping all tasks"); Ok(vec![]) } } diff --git a/crates/orchestrator/src/plugins/node_groups/status_update_impl.rs b/crates/orchestrator/src/plugins/node_groups/status_update_impl.rs index 568b297f..658f6d66 100644 --- a/crates/orchestrator/src/plugins/node_groups/status_update_impl.rs +++ b/crates/orchestrator/src/plugins/node_groups/status_update_impl.rs @@ -1,17 +1,11 @@ use crate::models::node::{NodeStatus, OrchestratorNode}; use crate::plugins::node_groups::NodeGroupsPlugin; -use crate::plugins::StatusUpdatePlugin; use anyhow::Error; use anyhow::Result; use log::info; -#[async_trait::async_trait] -impl StatusUpdatePlugin for NodeGroupsPlugin { - async fn handle_status_change( - &self, - node: &OrchestratorNode, - _old_status: &NodeStatus, - ) -> Result<(), Error> { +impl NodeGroupsPlugin { + pub(crate) async fn handle_status_change(&self, node: &OrchestratorNode) -> Result<(), Error> { let node_addr = node.address.to_string(); info!( diff --git a/crates/orchestrator/src/plugins/node_groups/tests.rs b/crates/orchestrator/src/plugins/node_groups/tests.rs index 8ca7c086..a7d73b36 100644 --- a/crates/orchestrator/src/plugins/node_groups/tests.rs +++ b/crates/orchestrator/src/plugins/node_groups/tests.rs @@ -1,5 +1,3 @@ -use crate::plugins::traits::SchedulerPlugin; -use crate::plugins::traits::StatusUpdatePlugin; use crate::{ models::node::{NodeStatus, OrchestratorNode}, plugins::node_groups::{ @@ -19,6 +17,8 @@ use shared::models::{ use std::collections::BTreeSet; use std::{collections::HashMap, str::FromStr, sync::Arc}; +use crate::plugins::node_groups::enable_configuration; +use crate::plugins::node_groups::get_task_topologies; use uuid::Uuid; fn create_test_node( @@ -117,11 +117,11 @@ async fn test_group_formation_and_dissolution() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; let task = Task { scheduling_config: Some(SchedulingConfig { @@ -184,9 +184,7 @@ async fn test_group_formation_and_dissolution() { .node_store .update_node_status(&node1_dead.address, NodeStatus::Dead) .await; - let _ = plugin - .handle_status_change(&node1_dead, &NodeStatus::Dead) - .await; + let _ = plugin.handle_status_change(&node1_dead).await; let _ = plugin.try_form_new_groups().await; @@ -220,11 +218,12 @@ async fn test_group_formation_with_multiple_configs() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config_s, config_xs], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let task = Task { scheduling_config: Some(SchedulingConfig { plugins: Some(HashMap::from([( @@ -319,11 +318,12 @@ async fn test_group_formation_with_requirements_and_single_node() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let task = Task { scheduling_config: Some(SchedulingConfig { plugins: Some(HashMap::from([( @@ -402,11 +402,12 @@ async fn test_group_formation_with_requirements_and_multiple_nodes() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let task = Task { scheduling_config: Some(SchedulingConfig { plugins: Some(HashMap::from([( @@ -519,11 +520,11 @@ async fn test_group_scheduling() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; let task = Task { scheduling_config: Some(SchedulingConfig { @@ -689,11 +690,12 @@ async fn test_group_scheduling_without_tasks() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let node1 = create_test_node( "0x1234567890123456789012345678901234567890", NodeStatus::Healthy, @@ -744,11 +746,12 @@ async fn test_group_formation_with_max_size() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let task = Task { scheduling_config: Some(SchedulingConfig { plugins: Some(HashMap::from([( @@ -897,11 +900,11 @@ async fn test_node_groups_with_allowed_topologies() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; let node1 = create_test_node( "0x1234567890123456789012345678901234567890", @@ -1002,11 +1005,11 @@ async fn test_node_cannot_be_in_multiple_groups() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; let all_nodes = plugin.store_context.node_store.get_nodes().await.unwrap(); assert_eq!(all_nodes.len(), 0, "No nodes should be in the store"); @@ -1222,11 +1225,12 @@ async fn test_reformation_on_death() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + let task = Task { scheduling_config: Some(SchedulingConfig { plugins: Some(HashMap::from([( @@ -1308,7 +1312,7 @@ async fn test_reformation_on_death() { .node_store .update_node_status(&node2.address, NodeStatus::Dead) .await; - let _ = plugin.handle_status_change(&node2, &NodeStatus::Dead).await; + let _ = plugin.handle_status_change(&node2).await; let _ = plugin.try_form_new_groups().await; @@ -1486,7 +1490,7 @@ async fn test_task_observer() { None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; let node = create_test_node( "0x1234567890123456789012345678901234567890", @@ -1543,10 +1547,10 @@ async fn test_task_observer() { println!("All tasks: {:?}", all_tasks); assert_eq!(all_tasks.len(), 2); assert!(all_tasks[0].id != all_tasks[1].id); - let topologies = plugin.get_task_topologies(&task).unwrap(); + let topologies = get_task_topologies(&task).unwrap(); assert_eq!(topologies.len(), 1); assert_eq!(topologies[0], "test-config"); - let topologies = plugin.get_task_topologies(&task2).unwrap(); + let topologies = get_task_topologies(&task2).unwrap(); assert_eq!(topologies.len(), 1); assert_eq!(topologies[0], "test-config2"); @@ -1658,7 +1662,7 @@ async fn test_building_largest_possible_groups() { None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create and add 3 nodes let node1 = create_test_node( @@ -1819,11 +1823,11 @@ async fn test_group_formation_priority() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config_large, config_small], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Add 4 healthy nodes let nodes: Vec<_> = (1..=4) @@ -1915,11 +1919,11 @@ async fn test_multiple_groups_same_configuration() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create task that requires this configuration let task = Task { @@ -2034,7 +2038,10 @@ async fn test_task_switching_policy() { disabled_policy, ProximityOptimizationPolicy::default(), )); - let _ = plugin_disabled.clone().register_observer().await; + let _ = store_context + .task_store + .add_observer(plugin_disabled.clone()) + .await; let solo_group = NodeGroup { id: "solo1".to_string(), @@ -2064,7 +2071,10 @@ async fn test_task_switching_policy() { no_prefer_policy, ProximityOptimizationPolicy::default(), )); - let _ = plugin_no_prefer.clone().register_observer().await; + let _ = store_context + .task_store + .add_observer(plugin_no_prefer.clone()) + .await; // Add a node and create a solo group with a task let node1 = create_test_node( @@ -2131,11 +2141,14 @@ async fn test_task_switching_policy() { let plugin_default = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin_default.clone().register_observer().await; + let _ = store_context + .task_store + .add_observer(plugin_default.clone()) + .await; let policy = plugin_default.test_get_task_switching_policy(); assert_eq!( @@ -2171,11 +2184,11 @@ async fn test_merge_solo_groups_with_active_tasks() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create 3 nodes let node1 = create_test_node( @@ -2338,11 +2351,11 @@ async fn test_task_assignment_during_merge() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create 2 nodes let node1 = create_test_node( @@ -2478,11 +2491,11 @@ async fn test_merge_only_compatible_groups() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config1, config2], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create 4 nodes: 2 with GPU specs, 2 without let node1_no_gpu = create_test_node( @@ -2640,13 +2653,13 @@ async fn test_no_merge_when_policy_disabled() { let plugin = Arc::new(NodeGroupsPlugin::new_with_policy( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, disabled_policy, ProximityOptimizationPolicy::default(), )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create 3 nodes let nodes: Vec<_> = (1..=3) @@ -2710,11 +2723,11 @@ async fn test_edge_case_no_available_tasks() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create 2 nodes let node1 = create_test_node( @@ -2742,7 +2755,7 @@ async fn test_edge_case_no_available_tasks() { // DON'T create any tasks - test what happens during group formation with no available tasks // Manually enable the configuration since there are no tasks to trigger it - let _ = plugin.enable_configuration("no-tasks-config").await; + let _ = enable_configuration(&plugin.store, "no-tasks-config").await; // Form groups (should create groups even without tasks) let _ = plugin.test_try_form_new_groups().await; @@ -2782,11 +2795,11 @@ async fn test_scheduler_integration_with_dissolved_groups() { let plugin = Arc::new(NodeGroupsPlugin::new( vec![config], store.clone(), - store_context, + store_context.clone(), None, None, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; // Create node let node1 = create_test_node( @@ -2808,7 +2821,7 @@ async fn test_scheduler_integration_with_dissolved_groups() { assert!(!exists, "Non-existent group should return false"); // Manually enable the configuration to ensure group formation works - let _ = plugin.enable_configuration("scheduler-test").await; + let _ = enable_configuration(&plugin.store, "scheduler-test").await; // Create a group and test validation let _ = plugin.test_try_form_new_groups().await; @@ -2874,7 +2887,8 @@ async fn test_proximity_merging_prevents_wrong_nodes_grouping() { TaskSwitchingPolicy::default(), ProximityOptimizationPolicy { enabled: true }, )); - let _ = plugin.clone().register_observer().await; + let _ = store_context.task_store.add_observer(plugin.clone()).await; + // Create tasks that allow both configurations let task1 = Task { scheduling_config: Some(SchedulingConfig { diff --git a/crates/orchestrator/src/plugins/traits.rs b/crates/orchestrator/src/plugins/traits.rs deleted file mode 100644 index 522ffa71..00000000 --- a/crates/orchestrator/src/plugins/traits.rs +++ /dev/null @@ -1,23 +0,0 @@ -use alloy::primitives::Address; -use anyhow::Error; -use anyhow::Result; -use async_trait::async_trait; -use shared::models::task::Task; - -use crate::models::node::{NodeStatus, OrchestratorNode}; - -pub trait Plugin {} - -#[async_trait] -pub trait StatusUpdatePlugin: Plugin + Send + Sync { - async fn handle_status_change( - &self, - node: &OrchestratorNode, - status: &NodeStatus, - ) -> Result<(), Error>; -} - -#[async_trait] -pub trait SchedulerPlugin: Plugin + Send + Sync { - async fn filter_tasks(&self, tasks: &[Task], node_address: &Address) -> Result>; -} diff --git a/crates/orchestrator/src/plugins/webhook/mod.rs b/crates/orchestrator/src/plugins/webhook/mod.rs index c57a3ff4..3ddf5942 100644 --- a/crates/orchestrator/src/plugins/webhook/mod.rs +++ b/crates/orchestrator/src/plugins/webhook/mod.rs @@ -5,7 +5,6 @@ use serde::{Deserialize, Serialize}; use crate::models::node::{NodeStatus, OrchestratorNode}; -use super::{Plugin, StatusUpdatePlugin}; use log::{error, info, warn}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -274,13 +273,8 @@ impl WebhookPlugin { let event = WebhookEvent::MetricsUpdated { pool_id, metrics }; self.send_event(event) } -} - -impl Plugin for WebhookPlugin {} -#[async_trait::async_trait] -impl StatusUpdatePlugin for WebhookPlugin { - async fn handle_status_change( + pub(crate) fn handle_status_change( &self, node: &OrchestratorNode, old_status: &NodeStatus, @@ -350,7 +344,7 @@ mod tests { bearer_token: None, }); let node = create_test_node(NodeStatus::Healthy); - let result = plugin.handle_status_change(&node, &NodeStatus::Dead).await; + let result = plugin.handle_status_change(&node, &NodeStatus::Dead); assert!(result.is_ok()); } diff --git a/crates/orchestrator/src/scheduler/mod.rs b/crates/orchestrator/src/scheduler/mod.rs index ced7d104..711f313f 100644 --- a/crates/orchestrator/src/scheduler/mod.rs +++ b/crates/orchestrator/src/scheduler/mod.rs @@ -8,13 +8,13 @@ use anyhow::Result; pub struct Scheduler { store_context: Arc, - plugins: Vec>, + plugins: Vec, } impl Scheduler { - pub fn new(store_context: Arc, plugins: Vec>) -> Self { + pub fn new(store_context: Arc, plugins: Vec) -> Self { let mut plugins = plugins; if plugins.is_empty() { - plugins.push(Box::new(NewestTaskPlugin)); + plugins.push(NewestTaskPlugin.into()); } Self { diff --git a/crates/orchestrator/src/status_update/mod.rs b/crates/orchestrator/src/status_update/mod.rs index 2567f69e..b2738488 100644 --- a/crates/orchestrator/src/status_update/mod.rs +++ b/crates/orchestrator/src/status_update/mod.rs @@ -3,6 +3,7 @@ use crate::models::node::{NodeStatus, OrchestratorNode}; use crate::plugins::StatusUpdatePlugin; use crate::store::core::StoreContext; use crate::utils::loop_heartbeats::LoopHeartbeats; +use futures::stream::FuturesUnordered; use log::{debug, error, info}; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::WalletProvider; @@ -19,7 +20,7 @@ pub struct NodeStatusUpdater { pool_id: u32, disable_ejection: bool, heartbeats: Arc, - plugins: Vec>, + plugins: Vec, metrics: Arc, } @@ -33,7 +34,7 @@ impl NodeStatusUpdater { pool_id: u32, disable_ejection: bool, heartbeats: Arc, - plugins: Vec>, + plugins: Vec, metrics: Arc, ) -> Self { Self { @@ -141,201 +142,251 @@ impl NodeStatusUpdater { } pub async fn process_nodes(&self) -> Result<(), anyhow::Error> { + use futures::StreamExt as _; + let nodes = self.store_context.node_store.get_nodes().await?; + let futures = FuturesUnordered::new(); for node in nodes { - let start_time = Instant::now(); - let node = node.clone(); - let old_status = node.status.clone(); - let heartbeat = self - .store_context - .heartbeat_store - .get_heartbeat(&node.address) - .await?; - let unhealthy_counter: u32 = self - .store_context - .heartbeat_store - .get_unhealthy_counter(&node.address) - .await?; - - #[cfg(test)] - let is_node_in_pool = self.is_node_in_pool(&node); - #[cfg(not(test))] - let is_node_in_pool = self.is_node_in_pool(&node).await; - let mut status_changed = false; - let mut new_status = node.status.clone(); - - match heartbeat { - Some(beat) => { - // Update version if necessary - if let Some(version) = &beat.version { - if node.version.as_ref() != Some(version) { - if let Err(e) = self - .store_context - .node_store - .update_node_version(&node.address, version) - .await - { - error!("Error updating node version: {e}"); - } - } - } + let store_context = self.store_context.clone(); + let contracts = self.contracts.clone(); + let pool_id = self.pool_id; + let missing_heartbeat_threshold = self.missing_heartbeat_threshold; + let plugins = self.plugins.clone(); + let metrics = self.metrics.clone(); + + futures.push(async move { + let address = node.address; + ( + process_node( + node, + store_context, + contracts, + pool_id, + missing_heartbeat_threshold, + plugins, + metrics, + ) + .await, + address, + ) + }); + } - // Check if the node is in the pool (needed for status transitions) + let results: Vec<_> = futures.collect().await; + for result in results { + match result { + (Ok(()), address) => { + debug!("Successfully processed node: {address:?}"); + } + (Err(e), address) => { + error!("Error processing node {address:?}: {e}"); + } + } + } - // If node is Unhealthy or WaitingForHeartbeat: - if node.status == NodeStatus::Unhealthy - || node.status == NodeStatus::WaitingForHeartbeat - { - if is_node_in_pool { - new_status = NodeStatus::Healthy; - } else { - // Reset to discovered to init re-invite to pool - new_status = NodeStatus::Discovered; - } - status_changed = true; - } - // If node is Discovered or Dead: - else if node.status == NodeStatus::Discovered - || node.status == NodeStatus::Dead - { - if is_node_in_pool { - new_status = NodeStatus::Healthy; - } else { - new_status = NodeStatus::Discovered; - } - status_changed = true; - } + Ok(()) + } +} - // Clear unhealthy counter on heartbeat receipt - if let Err(e) = self - .store_context - .heartbeat_store - .clear_unhealthy_counter(&node.address) +async fn process_node( + node: OrchestratorNode, + store_context: Arc, + contracts: Contracts, + pool_id: u32, + missing_heartbeat_threshold: u32, + plugins: Vec, + metrics: Arc, +) -> Result<(), anyhow::Error> { + let start_time = Instant::now(); + let old_status = node.status.clone(); + let heartbeat = store_context + .heartbeat_store + .get_heartbeat(&node.address) + .await?; + let unhealthy_counter: u32 = store_context + .heartbeat_store + .get_unhealthy_counter(&node.address) + .await?; + + let is_node_in_pool = is_node_in_pool(contracts, pool_id, &node).await; + let mut status_changed = false; + let mut new_status = node.status.clone(); + + match heartbeat { + Some(beat) => { + // Update version if necessary + if let Some(version) = &beat.version { + if node.version.as_ref() != Some(version) { + if let Err(e) = store_context + .node_store + .update_node_version(&node.address, version) .await { - error!("Error clearing unhealthy counter: {e}"); + error!("Error updating node version: {e}"); } } - None => { - // We don't have a heartbeat, increment unhealthy counter - if let Err(e) = self - .store_context - .heartbeat_store - .increment_unhealthy_counter(&node.address) - .await - { - error!("Error incrementing unhealthy counter: {e}"); - } + } - match node.status { - NodeStatus::Healthy => { - new_status = NodeStatus::Unhealthy; - status_changed = true; - } - NodeStatus::Unhealthy => { - if unhealthy_counter + 1 >= self.missing_heartbeat_threshold { - new_status = NodeStatus::Dead; - status_changed = true; - } - } - NodeStatus::Discovered => { - if is_node_in_pool { - // We have caught a very interesting edge case here. - // The node is in pool but does not send heartbeats - maybe due to a downtime of the orchestrator? - // Node invites fail now since the node cannot be in pool again. - // We have to eject and re-invite - we can simply do this by setting the status to unhealthy. The node will eventually be ejected. - new_status = NodeStatus::Unhealthy; - status_changed = true; - } else { - // if we've been trying to invite this node for a while, we eventually give up and mark it as dead - // The node will simply be in status discovered again when the discovery svc date > status change date. - if unhealthy_counter + 1 > 360 { - new_status = NodeStatus::Dead; - status_changed = true; - } - } - } - NodeStatus::WaitingForHeartbeat => { - if unhealthy_counter + 1 >= self.missing_heartbeat_threshold { - // Unhealthy counter is reset when node is invited - // usually it starts directly with heartbeat - new_status = NodeStatus::Unhealthy; - status_changed = true; - } - } - _ => (), - } + // Check if the node is in the pool (needed for status transitions) + + // If node is Unhealthy or WaitingForHeartbeat: + if node.status == NodeStatus::Unhealthy + || node.status == NodeStatus::WaitingForHeartbeat + { + if is_node_in_pool { + new_status = NodeStatus::Healthy; + } else { + // Reset to discovered to init re-invite to pool + new_status = NodeStatus::Discovered; + } + status_changed = true; + } + // If node is Discovered or Dead: + else if node.status == NodeStatus::Discovered || node.status == NodeStatus::Dead { + if is_node_in_pool { + new_status = NodeStatus::Healthy; + } else { + new_status = NodeStatus::Discovered; } + status_changed = true; } - if status_changed { - // Clean up metrics when node becomes Dead, Ejected, or Banned - if matches!( - &new_status, - NodeStatus::Dead | NodeStatus::Ejected | NodeStatus::Banned - ) { - let node_metrics = match self - .store_context - .metrics_store - .get_metrics_for_node(node.address) - .await - { - Ok(metrics) => metrics, - Err(e) => { - error!("Error getting metrics for node: {e}"); - Default::default() - } - }; - - for (task_id, task_metrics) in node_metrics { - for (label, _value) in task_metrics { - // Remove from Redis metrics store - if let Err(e) = self - .store_context - .metrics_store - .delete_metric(&task_id, &label, &node.address.to_string()) - .await - { - error!("Error deleting metric: {e}"); - } + // Clear unhealthy counter on heartbeat receipt + if let Err(e) = store_context + .heartbeat_store + .clear_unhealthy_counter(&node.address) + .await + { + error!("Error clearing unhealthy counter: {e}"); + } + } + None => { + // We don't have a heartbeat, increment unhealthy counter + if let Err(e) = store_context + .heartbeat_store + .increment_unhealthy_counter(&node.address) + .await + { + error!("Error incrementing unhealthy counter: {e}"); + } + + match node.status { + NodeStatus::Healthy => { + new_status = NodeStatus::Unhealthy; + status_changed = true; + } + NodeStatus::Unhealthy => { + if unhealthy_counter + 1 >= missing_heartbeat_threshold { + new_status = NodeStatus::Dead; + status_changed = true; + } + } + NodeStatus::Discovered => { + if is_node_in_pool { + // We have caught a very interesting edge case here. + // The node is in pool but does not send heartbeats - maybe due to a downtime of the orchestrator? + // Node invites fail now since the node cannot be in pool again. + // We have to eject and re-invite - we can simply do this by setting the status to unhealthy. The node will eventually be ejected. + new_status = NodeStatus::Unhealthy; + status_changed = true; + } else { + // if we've been trying to invite this node for a while, we eventually give up and mark it as dead + // The node will simply be in status discovered again when the discovery svc date > status change date. + if unhealthy_counter + 1 > 360 { + new_status = NodeStatus::Dead; + status_changed = true; } } } + NodeStatus::WaitingForHeartbeat => { + if unhealthy_counter + 1 >= missing_heartbeat_threshold { + // Unhealthy counter is reset when node is invited + // usually it starts directly with heartbeat + new_status = NodeStatus::Unhealthy; + status_changed = true; + } + } + _ => (), + } + } + } - if let Err(e) = self - .store_context - .node_store - .update_node_status(&node.address, new_status) - .await - { - error!("Error updating node status: {e}"); + if status_changed { + // Clean up metrics when node becomes Dead, Ejected, or Banned + if matches!( + &new_status, + NodeStatus::Dead | NodeStatus::Ejected | NodeStatus::Banned + ) { + let node_metrics = match store_context + .metrics_store + .get_metrics_for_node(node.address) + .await + { + Ok(metrics) => metrics, + Err(e) => { + error!("Error getting metrics for node: {e}"); + Default::default() } + }; - if let Some(updated_node) = self - .store_context - .node_store - .get_node(&node.address) - .await? - { - for plugin in self.plugins.iter() { - if let Err(e) = plugin - .handle_status_change(&updated_node, &old_status) - .await - { - error!("Error handling status change: {e}"); - } + for (task_id, task_metrics) in node_metrics { + for (label, _value) in task_metrics { + // Remove from Redis metrics store + if let Err(e) = store_context + .metrics_store + .delete_metric(&task_id, &label, &node.address.to_string()) + .await + { + error!("Error deleting metric: {e}"); } } } // Record status update execution time let duration = start_time.elapsed(); - self.metrics.record_status_update_execution_time( + metrics.record_status_update_execution_time( &node.address.to_string(), duration.as_secs_f64(), ); } - Ok(()) + + if let Err(e) = store_context + .node_store + .update_node_status(&node.address, new_status) + .await + { + error!("Error updating node status: {e}"); + } + + if let Some(updated_node) = store_context.node_store.get_node(&node.address).await? { + for plugin in plugins.iter() { + if let Err(e) = plugin + .handle_status_change(&updated_node, &old_status) + .await + { + error!("Error handling status change: {e}"); + } + } + } } + Ok(()) +} + +#[cfg(test)] +async fn is_node_in_pool(_: Contracts, _: u32, _: &OrchestratorNode) -> bool { + true +} + +#[cfg(not(test))] +async fn is_node_in_pool( + contracts: Contracts, + pool_id: u32, + node: &OrchestratorNode, +) -> bool { + contracts + .compute_pool + .is_node_in_pool(pool_id, node.address) + .await + .unwrap_or(false) } #[cfg(test)] diff --git a/crates/orchestrator/src/store/domains/task_store.rs b/crates/orchestrator/src/store/domains/task_store.rs index 9dc23f0d..3442dc50 100644 --- a/crates/orchestrator/src/store/domains/task_store.rs +++ b/crates/orchestrator/src/store/domains/task_store.rs @@ -1,5 +1,5 @@ -use crate::events::TaskObserver; use crate::store::core::RedisStore; +use crate::NodeGroupsPlugin; use anyhow::Result; use futures::future; use log::error; @@ -14,7 +14,7 @@ const TASK_NAME_INDEX_KEY: &str = "orchestrator:task_names"; pub struct TaskStore { redis: Arc, - observers: Arc>>>, + observers: Arc>>>, } impl TaskStore { @@ -25,7 +25,7 @@ impl TaskStore { } } - pub async fn add_observer(&self, observer: Arc) { + pub async fn add_observer(&self, observer: Arc) { let mut observers = self.observers.lock().await; observers.push(observer); }