diff --git a/desktop/Backend-Rust/src/llm/client.rs b/desktop/Backend-Rust/src/llm/client.rs index 118f497871..d721402e62 100644 --- a/desktop/Backend-Rust/src/llm/client.rs +++ b/desktop/Backend-Rust/src/llm/client.rs @@ -211,17 +211,16 @@ struct GeminiPartResponse { } impl LlmClient { - /// Create a new Gemini client + /// Create a new Gemini client with the QoS-configured default model. pub fn new(api_key: String) -> Self { Self { client: Client::new(), api_key, - model: "gemini-3-flash-preview".to_string(), + model: super::model_qos::gemini_default().to_string(), } } /// Set the model to use - #[allow(dead_code)] pub fn with_model(mut self, model: &str) -> Self { self.model = model.to_string(); self @@ -1162,3 +1161,29 @@ Return relationships as source -> relationship -> target triples."#, Ok(result) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_uses_qos_default_model() { + let client = LlmClient::new("test-key".to_string()); + assert_eq!(client.model, super::super::model_qos::gemini_default()); + } + + #[test] + fn with_model_overrides_default() { + let client = LlmClient::new("test-key".to_string()) + .with_model("gemini-pro-latest"); + assert_eq!(client.model, "gemini-pro-latest"); + } + + #[test] + fn with_model_extraction_uses_extraction_accessor() { + let client = LlmClient::new("test-key".to_string()) + .with_model(super::super::model_qos::gemini_extraction()); + // In test env (premium tier), extraction == default == flash + assert_eq!(client.model, "gemini-3-flash-preview"); + } +} diff --git a/desktop/Backend-Rust/src/llm/mod.rs b/desktop/Backend-Rust/src/llm/mod.rs index c444e5f24c..2407d115ef 100644 --- a/desktop/Backend-Rust/src/llm/mod.rs +++ b/desktop/Backend-Rust/src/llm/mod.rs @@ -1,6 +1,7 @@ // LLM module pub mod client; +pub mod model_qos; pub mod persona; pub mod prompts; diff --git a/desktop/Backend-Rust/src/llm/model_qos.rs b/desktop/Backend-Rust/src/llm/model_qos.rs new file mode 100644 index 0000000000..548fb14da8 
--- /dev/null +++ b/desktop/Backend-Rust/src/llm/model_qos.rs @@ -0,0 +1,215 @@ +// Model QoS Tier System for Rust Backend +// +// Central model configuration with switchable tiers, mirroring the Swift ModelQoS. +// All LlmClient call sites should use these accessors instead of hardcoded model strings. +// +// Tier is read from OMI_MODEL_TIER env var at startup (default: "premium"). + +use std::sync::OnceLock; + +/// Active tier, resolved once from OMI_MODEL_TIER env var. +static ACTIVE_TIER: OnceLock<ModelTier> = OnceLock::new(); + +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ModelTier { + /// Cost-optimized: Flash for all Gemini workloads, lower rate limits + Premium, + /// Quality-optimized: same models, higher rate limits + Max, +} + +impl ModelTier { + fn from_env() -> Self { + match std::env::var("OMI_MODEL_TIER").as_deref() { + Ok("max") => ModelTier::Max, + _ => ModelTier::Premium, + } + } +} + +/// Get the active model tier (resolved once from env). +pub fn active_tier() -> ModelTier { + *ACTIVE_TIER.get_or_init(ModelTier::from_env) +} + +// MARK: - Gemini Models + +/// Default model for LlmClient (used by chat, conversations, personas, knowledge graph). +pub fn gemini_default() -> &'static str { + gemini_default_for(active_tier()) +} + +fn gemini_default_for(tier: ModelTier) -> &'static str { + match tier { + ModelTier::Premium => "gemini-3-flash-preview", + ModelTier::Max => "gemini-3-flash-preview", + } +} + +/// Model for structured extraction tasks (conversations, knowledge graph). +pub fn gemini_extraction() -> &'static str { + gemini_extraction_for(active_tier()) +} + +fn gemini_extraction_for(_tier: ModelTier) -> &'static str { + "gemini-3-flash-preview" +} + +/// Allowed models for the Gemini proxy (passthrough from Swift app). +/// These are the models the desktop app is allowed to request. 
+pub fn gemini_proxy_allowed() -> &'static [&'static str] { + &[ + "gemini-3-flash-preview", + "gemini-embedding-001", + ] +} + +/// Model that rate-limited Pro requests degrade to. +pub fn gemini_degrade_target() -> &'static str { + "gemini-3-flash-preview" +} + +// MARK: - Rate Limit Thresholds (tier-aware) + +/// Daily soft limit — at or above this, Pro requests degrade to Flash. +/// Premium: aggressive (30) since premium already sends Flash. +/// Max: generous (300) to allow Pro usage. +pub fn daily_soft_limit() -> u32 { + daily_soft_limit_for(active_tier()) +} + +fn daily_soft_limit_for(tier: ModelTier) -> u32 { + match tier { + ModelTier::Premium => 30, + ModelTier::Max => 300, + } +} + +/// Daily hard limit — at or above this, all requests are rejected (429). +pub fn daily_hard_limit() -> u32 { + daily_hard_limit_for(active_tier()) +} + +fn daily_hard_limit_for(_tier: ModelTier) -> u32 { + 1500 +} + +/// Tier description for logging. +pub fn tier_description() -> &'static str { + tier_description_for(active_tier()) +} + +fn tier_description_for(tier: ModelTier) -> &'static str { + match tier { + ModelTier::Premium => "Premium (cost-optimized)", + ModelTier::Max => "Max (quality-optimized)", + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Mutex; + + /// Serialize env-var-mutating tests to avoid races under parallel execution. 
+ static ENV_LOCK: Mutex<()> = Mutex::new(()); + + // --- ModelTier::from_env (serialized — shares process env) --- + + #[test] + fn from_env_all_cases() { + let _guard = ENV_LOCK.lock().unwrap(); + + // Default (unset) → Premium + std::env::remove_var("OMI_MODEL_TIER"); + assert_eq!(ModelTier::from_env(), ModelTier::Premium); + + // Explicit max → Max + std::env::set_var("OMI_MODEL_TIER", "max"); + assert_eq!(ModelTier::from_env(), ModelTier::Max); + + // Invalid value → Premium fallback + std::env::set_var("OMI_MODEL_TIER", "garbage"); + assert_eq!(ModelTier::from_env(), ModelTier::Premium); + + // Empty string → Premium fallback + std::env::set_var("OMI_MODEL_TIER", ""); + assert_eq!(ModelTier::from_env(), ModelTier::Premium); + + std::env::remove_var("OMI_MODEL_TIER"); + } + + // --- gemini_default_for (both tiers) --- + + #[test] + fn gemini_default_premium_is_flash() { + assert_eq!(gemini_default_for(ModelTier::Premium), "gemini-3-flash-preview"); + } + + #[test] + fn gemini_default_max_is_flash() { + // Default model is Flash for both tiers (cheap baseline) + assert_eq!(gemini_default_for(ModelTier::Max), "gemini-3-flash-preview"); + } + + // --- gemini_extraction_for (the tier-dependent branch) --- + + #[test] + fn gemini_extraction_is_flash_for_both_tiers() { + assert_eq!(gemini_extraction_for(ModelTier::Premium), "gemini-3-flash-preview"); + assert_eq!(gemini_extraction_for(ModelTier::Max), "gemini-3-flash-preview"); + } + + // --- tier_description_for --- + + #[test] + fn tier_description_premium() { + assert!(tier_description_for(ModelTier::Premium).contains("Premium")); + } + + #[test] + fn tier_description_max() { + assert!(tier_description_for(ModelTier::Max).contains("Max")); + } + + // --- Static accessors (pinned models) --- + + #[test] + fn proxy_allowed_contains_expected_models() { + let allowed = gemini_proxy_allowed(); + assert!(allowed.contains(&"gemini-3-flash-preview")); + assert!(allowed.contains(&"gemini-embedding-001")); + 
assert!(!allowed.contains(&"gemini-pro-latest"), "pro removed from allowlist"); + assert!(!allowed.contains(&"gemini-ultra")); + } + + #[test] + fn degrade_target_is_flash() { + assert_eq!(gemini_degrade_target(), "gemini-3-flash-preview"); + } + + // --- Rate limit thresholds --- + + #[test] + fn daily_soft_limit_premium_is_lower() { + assert_eq!(daily_soft_limit_for(ModelTier::Premium), 30); + } + + #[test] + fn daily_soft_limit_max_is_higher() { + assert_eq!(daily_soft_limit_for(ModelTier::Max), 300); + } + + #[test] + fn daily_hard_limit_same_for_both_tiers() { + assert_eq!(daily_hard_limit_for(ModelTier::Premium), 1500); + assert_eq!(daily_hard_limit_for(ModelTier::Max), 1500); + } + + #[test] + fn soft_limit_always_below_hard_limit() { + for tier in [ModelTier::Premium, ModelTier::Max] { + assert!(daily_soft_limit_for(tier) < daily_hard_limit_for(tier)); + } + } +} diff --git a/desktop/Backend-Rust/src/main.rs b/desktop/Backend-Rust/src/main.rs index 9ff6701345..fa8d779a67 100644 --- a/desktop/Backend-Rust/src/main.rs +++ b/desktop/Backend-Rust/src/main.rs @@ -89,6 +89,13 @@ async fn main() { // Load environment variables dotenvy::dotenv().ok(); + // Log active QoS tier + tracing::info!("Model QoS tier: {} | rate limits: soft={}, hard={}", + llm::model_qos::tier_description(), + llm::model_qos::daily_soft_limit(), + llm::model_qos::daily_hard_limit(), + ); + // Load and validate config let config = Config::from_env(); if let Err(e) = config.validate() { diff --git a/desktop/Backend-Rust/src/routes/conversations.rs b/desktop/Backend-Rust/src/routes/conversations.rs index f8d4078d2e..a808394968 100644 --- a/desktop/Backend-Rust/src/routes/conversations.rs +++ b/desktop/Backend-Rust/src/routes/conversations.rs @@ -191,6 +191,7 @@ async fn create_conversation_from_segments( // Get LLM client (Gemini) let llm_client = if let Some(api_key) = &state.config.gemini_api_key { LlmClient::new(api_key.clone()) + .with_model(crate::llm::model_qos::gemini_extraction()) } 
else { return Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -441,6 +442,7 @@ async fn reprocess_conversation( // Get LLM client (Gemini) let llm_client = if let Some(api_key) = &state.config.gemini_api_key { LlmClient::new(api_key.clone()) + .with_model(crate::llm::model_qos::gemini_extraction()) } else { return Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -841,7 +843,8 @@ async fn merge_conversations( // If reprocessing is requested and we have an LLM client, process the merged conversation if request.reprocess { if let Some(api_key) = &state.config.gemini_api_key { - let llm = LlmClient::new(api_key.clone()); + let llm = LlmClient::new(api_key.clone()) + .with_model(crate::llm::model_qos::gemini_extraction()); // Get existing data for deduplication let existing_memories = state diff --git a/desktop/Backend-Rust/src/routes/knowledge_graph.rs b/desktop/Backend-Rust/src/routes/knowledge_graph.rs index c6e910ec06..9c25d63ecb 100644 --- a/desktop/Backend-Rust/src/routes/knowledge_graph.rs +++ b/desktop/Backend-Rust/src/routes/knowledge_graph.rs @@ -94,7 +94,8 @@ async fn rebuild_knowledge_graph( tracing::info!("Processing {} memories for knowledge graph", memories.len()); // Create LLM client - let llm = LlmClient::new(api_key); + let llm = LlmClient::new(api_key) + .with_model(crate::llm::model_qos::gemini_extraction()); // Track nodes by lowercase label for deduplication let mut node_map: HashMap = HashMap::new(); diff --git a/desktop/Backend-Rust/src/routes/proxy.rs b/desktop/Backend-Rust/src/routes/proxy.rs index 3cee76e660..d345d3cffa 100644 --- a/desktop/Backend-Rust/src/routes/proxy.rs +++ b/desktop/Backend-Rust/src/routes/proxy.rs @@ -27,14 +27,9 @@ const GEMINI_ALLOWED_ACTIONS: &[&str] = &[ "batchEmbedContents", ]; -// Allowed Gemini models — only these can be requested through the proxy. -// Desktop app uses: gemini-3-flash-preview (default), gemini-pro-latest (tasks/insights), -// gemini-embedding-001 (embeddings). Rate limiting may rewrite pro → flash. 
-const GEMINI_ALLOWED_MODELS: &[&str] = &[ - "gemini-3-flash-preview", - "gemini-pro-latest", - "gemini-embedding-001", -]; +// Allowed Gemini models — driven by model_qos (issue #6834). +// Desktop app uses: gemini-3-flash-preview (all features), gemini-embedding-001 (embeddings). +// Rate limiting may degrade requests above soft limit. /// Maximum request body size for Gemini proxy routes (5 MB). /// Normal app payloads are 300-600 KB (base64 JPEG + prompt); 5 MB gives ~8x headroom. @@ -458,9 +453,9 @@ fn is_gemini_action_allowed(action: &str) -> bool { GEMINI_ALLOWED_ACTIONS.contains(&action) } -/// Check if a Gemini model is in the allowlist (issue #6624) +/// Check if a Gemini model is in the allowlist (issue #6624, #6834) fn is_gemini_model_allowed(model: &str) -> bool { - GEMINI_ALLOWED_MODELS.contains(&model) + crate::llm::model_qos::gemini_proxy_allowed().contains(&model) } /// Sanitize a Gemini request body (issue #6624). @@ -722,12 +717,12 @@ mod tests { #[test] fn model_allowlist_permits_valid_models() { assert!(is_gemini_model_allowed("gemini-3-flash-preview")); - assert!(is_gemini_model_allowed("gemini-pro-latest")); assert!(is_gemini_model_allowed("gemini-embedding-001")); } #[test] fn model_allowlist_blocks_unknown() { + assert!(!is_gemini_model_allowed("gemini-pro-latest"), "pro removed from allowlist"); assert!(!is_gemini_model_allowed("gemini-2.5-pro")); assert!(!is_gemini_model_allowed("gemini-1.5-pro")); assert!(!is_gemini_model_allowed("gemini-ultra")); diff --git a/desktop/Backend-Rust/src/routes/rate_limit.rs b/desktop/Backend-Rust/src/routes/rate_limit.rs index c11be87f3d..e571260060 100644 --- a/desktop/Backend-Rust/src/routes/rate_limit.rs +++ b/desktop/Backend-Rust/src/routes/rate_limit.rs @@ -23,11 +23,8 @@ use tokio::sync::Mutex; use crate::services::RedisService; -/// Daily soft limit — at or above this, Pro requests are degraded to Flash. 
-const DAILY_SOFT_LIMIT: u32 = 300; - -/// Daily hard limit — at or above this, all requests are rejected with 429. -const DAILY_HARD_LIMIT: u32 = 1500; +// Daily soft/hard limits are tier-aware — see crate::llm::model_qos. +use crate::llm::model_qos; /// Burst cap — max requests per rolling 60-second window. const BURST_PER_MINUTE: usize = 30; @@ -60,9 +57,9 @@ impl RateSnapshot { fn to_decision(&self) -> RateDecision { if self.burst_count > BURST_PER_MINUTE { RateDecision::Reject - } else if self.daily_count >= DAILY_HARD_LIMIT { + } else if self.daily_count >= model_qos::daily_hard_limit() { RateDecision::Reject - } else if self.daily_count >= DAILY_SOFT_LIMIT { + } else if self.daily_count >= model_qos::daily_soft_limit() { RateDecision::DegradeToFlash } else { RateDecision::Allow @@ -197,7 +194,7 @@ pub fn maybe_rewrite_model_path(path: &str, decision: &RateDecision, action: &st return path.to_string(); } if let Some(rest) = path.strip_prefix("models/gemini-pro-latest:") { - return format!("models/gemini-3-flash-preview:{}", rest); + return format!("models/{}:{}", crate::llm::model_qos::gemini_degrade_target(), rest); } path.to_string() } @@ -218,23 +215,25 @@ pub fn rate_limit_error_json(message: &str) -> String { mod tests { use super::*; - // --- Decision from snapshot --- + // --- Decision from snapshot (uses QoS tier — Premium in test env: soft=30, hard=1500) --- #[test] fn snapshot_allow() { - let s = RateSnapshot { daily_count: 100, burst_count: 5 }; + let s = RateSnapshot { daily_count: 10, burst_count: 5 }; assert_eq!(s.to_decision(), RateDecision::Allow); } #[test] fn snapshot_degrade_at_soft_limit() { - let s = RateSnapshot { daily_count: 300, burst_count: 5 }; + let soft = model_qos::daily_soft_limit(); + let s = RateSnapshot { daily_count: soft, burst_count: 5 }; assert_eq!(s.to_decision(), RateDecision::DegradeToFlash); } #[test] fn snapshot_reject_at_hard_limit() { - let s = RateSnapshot { daily_count: 1500, burst_count: 5 }; + let hard = 
model_qos::daily_hard_limit(); + let s = RateSnapshot { daily_count: hard, burst_count: 5 }; assert_eq!(s.to_decision(), RateDecision::Reject); } @@ -251,6 +250,22 @@ mod tests { assert_eq!(s.to_decision(), RateDecision::Allow); } + // --- Boundary: just below thresholds --- + + #[test] + fn snapshot_allow_just_below_soft_limit() { + let soft = model_qos::daily_soft_limit(); + let s = RateSnapshot { daily_count: soft - 1, burst_count: 5 }; + assert_eq!(s.to_decision(), RateDecision::Allow); + } + + #[test] + fn snapshot_degrade_just_below_hard_limit() { + let hard = model_qos::daily_hard_limit(); + let s = RateSnapshot { daily_count: hard - 1, burst_count: 5 }; + assert_eq!(s.to_decision(), RateDecision::DegradeToFlash); + } + // --- No Redis → unmetered (cache bypassed entirely) --- #[tokio::test] diff --git a/desktop/Desktop/Sources/AppleNotesReaderService.swift b/desktop/Desktop/Sources/AppleNotesReaderService.swift index 3a075c800a..afae47ba03 100644 --- a/desktop/Desktop/Sources/AppleNotesReaderService.swift +++ b/desktop/Desktop/Sources/AppleNotesReaderService.swift @@ -142,7 +142,7 @@ actor AppleNotesReaderService { prompt: synthesisPrompt, systemPrompt: "You extract high-signal user facts from Apple Notes. Output only valid JSON.", - model: "claude-opus-4-6", + model: ModelQoS.Claude.synthesis, onTextDelta: { @Sendable _ in }, onToolCall: { @Sendable _, _, _ in "" }, onToolActivity: { @Sendable _, _, _, _ in } diff --git a/desktop/Desktop/Sources/CalendarReaderService.swift b/desktop/Desktop/Sources/CalendarReaderService.swift index 4d18c66fa2..b96b068ffd 100644 --- a/desktop/Desktop/Sources/CalendarReaderService.swift +++ b/desktop/Desktop/Sources/CalendarReaderService.swift @@ -191,7 +191,7 @@ actor CalendarReaderService { prompt: synthesisPrompt, systemPrompt: "You are a profile extraction assistant. Analyze calendar events and output structured JSON. 
Be concise and factual.", - model: "claude-opus-4-6", + model: ModelQoS.Claude.synthesis, onTextDelta: { @Sendable _ in }, onToolCall: { @Sendable _, _, _ in return "" }, onToolActivity: { @Sendable _, _, _, _ in } diff --git a/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarState.swift b/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarState.swift index 697d185cd7..8608cb771e 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarState.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarState.swift @@ -103,13 +103,10 @@ class FloatingControlBarState: NSObject, ObservableObject { @Published var currentQueryFromVoice: Bool = false // Model selection - @Published var selectedModel: String = "claude-sonnet-4-6" + @Published var selectedModel: String = ModelQoS.Claude.defaultSelection - /// Available models for the floating bar picker - static let availableModels: [(id: String, label: String)] = [ - ("claude-sonnet-4-6", "Sonnet"), - ("claude-opus-4-6", "Opus"), - ] + /// Available models for the floating bar picker (driven by QoS tier) + static var availableModels: [(id: String, label: String)] { ModelQoS.Claude.availableModels } var isShowingNotification: Bool { currentNotification != nil diff --git a/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarWindow.swift b/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarWindow.swift index 4805028f4a..9805de74ed 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarWindow.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/FloatingControlBarWindow.swift @@ -1549,7 +1549,7 @@ class FloatingControlBarManager { } let floatingModel = ShortcutSettings.shared.selectedModel.isEmpty - ? "claude-sonnet-4-6" + ? 
ModelQoS.Claude.defaultSelection : ShortcutSettings.shared.selectedModel let notificationContextSuffix = notificationContextSuffixIfNeeded(for: message) await provider.sendMessage( diff --git a/desktop/Desktop/Sources/FloatingControlBar/ShortcutSettings.swift b/desktop/Desktop/Sources/FloatingControlBar/ShortcutSettings.swift index 12ff2206d7..39f45d7669 100644 --- a/desktop/Desktop/Sources/FloatingControlBar/ShortcutSettings.swift +++ b/desktop/Desktop/Sources/FloatingControlBar/ShortcutSettings.swift @@ -324,11 +324,8 @@ class ShortcutSettings: ObservableObject { didSet { UserDefaults.standard.set(selectedModel, forKey: "shortcut_selectedModel") } } - /// Available models for Ask Omi. - static let availableModels: [(id: String, label: String)] = [ - ("claude-sonnet-4-6", "Sonnet"), - ("claude-opus-4-6", "Opus"), - ] + /// Available models for Ask Omi (driven by QoS tier). + static var availableModels: [(id: String, label: String)] { ModelQoS.Claude.availableModels } /// Push-to-talk transcription mode. enum PTTTranscriptionMode: String, CaseIterable { @@ -471,7 +468,9 @@ class ShortcutSettings: ObservableObject { self.doubleTapForLock = UserDefaults.standard.object(forKey: "shortcut_doubleTapForLock") as? Bool ?? true self.solidBackground = UserDefaults.standard.object(forKey: "shortcut_solidBackground") as? Bool ?? false self.pttSoundsEnabled = UserDefaults.standard.object(forKey: "shortcut_pttSoundsEnabled") as? Bool ?? true - self.selectedModel = UserDefaults.standard.string(forKey: "shortcut_selectedModel") ?? "claude-sonnet-4-6" + self.selectedModel = ModelQoS.Claude.sanitizedSelection( + UserDefaults.standard.string(forKey: "shortcut_selectedModel") + ) if let saved = UserDefaults.standard.string(forKey: "shortcut_pttTranscriptionMode"), let mode = PTTTranscriptionMode(rawValue: saved) { self.pttTranscriptionMode = mode @@ -488,6 +487,11 @@ class ShortcutSettings: ObservableObject { ? 
storedVoiceID : Self.defaultVoiceID self.selectedVoiceID = validVoiceID + + NotificationCenter.default.addObserver(forName: .modelTierDidChange, object: nil, queue: .main) { [weak self] _ in + guard let self else { return } + self.selectedModel = ModelQoS.Claude.sanitizedSelection(self.selectedModel) + } } private func persistShortcut(_ shortcut: KeyboardShortcut, forKey key: String) { diff --git a/desktop/Desktop/Sources/GmailReaderService.swift b/desktop/Desktop/Sources/GmailReaderService.swift index 02cd1a3f29..ebf1f6be0b 100644 --- a/desktop/Desktop/Sources/GmailReaderService.swift +++ b/desktop/Desktop/Sources/GmailReaderService.swift @@ -267,7 +267,7 @@ actor GmailReaderService { prompt: synthesisPrompt, systemPrompt: "You are a profile extraction assistant. Output ONLY valid JSON. No markdown, no code fences, no explanation.", - model: "claude-opus-4-6", + model: ModelQoS.Claude.synthesis, onTextDelta: { @Sendable _ in }, onToolCall: { @Sendable _, _, _ in return "" }, onToolActivity: { @Sendable _, _, _, _ in } diff --git a/desktop/Desktop/Sources/MainWindow/Pages/ChatLabView.swift b/desktop/Desktop/Sources/MainWindow/Pages/ChatLabView.swift index d7044da880..dd93714b91 100644 --- a/desktop/Desktop/Sources/MainWindow/Pages/ChatLabView.swift +++ b/desktop/Desktop/Sources/MainWindow/Pages/ChatLabView.swift @@ -436,7 +436,7 @@ class ChatLabViewModel: ObservableObject { request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version") let body: [String: Any] = [ - "model": "claude-sonnet-4-20250514", + "model": ModelQoS.Claude.chatLabQuery, "max_tokens": 1024, "system": systemPrompt.prefix(50000), "messages": [["role": "user", "content": userMessage]], @@ -481,7 +481,7 @@ class ChatLabViewModel: ObservableObject { """ let body: [String: Any] = [ - "model": "claude-haiku-4-5-20251001", + "model": ModelQoS.Claude.chatLabGrade, "max_tokens": 200, "messages": [["role": "user", "content": gradePrompt]], ] diff --git a/desktop/Desktop/Sources/ModelQoS.swift 
b/desktop/Desktop/Sources/ModelQoS.swift new file mode 100644 index 0000000000..f34da5b1f1 --- /dev/null +++ b/desktop/Desktop/Sources/ModelQoS.swift @@ -0,0 +1,96 @@ +import Foundation + +// MARK: - Model QoS Tier System +// +// Central model configuration with switchable tiers. +// Change `activeTier` to switch all models at once. +// Individual workloads can also override their tier. + +enum ModelTier: String, CaseIterable { + case premium // Cost-optimized: Sonnet + Haiku for Claude, Flash for Gemini + case max // Quality-optimized: higher rate limits, same models +} + +struct ModelQoS { + // MARK: - Active Tier (single switch) + + private static let tierKey = "modelQoS_activeTier" + + static var activeTier: ModelTier { + get { + guard let raw = UserDefaults.standard.string(forKey: tierKey), + let tier = ModelTier(rawValue: raw) else { + return .premium + } + return tier + } + set { + UserDefaults.standard.set(newValue.rawValue, forKey: tierKey) + NotificationCenter.default.post(name: .modelTierDidChange, object: nil) + } + } + + // MARK: - Claude Models + + struct Claude { + /// Main chat session model (user-facing conversations) + static var chat: String { "claude-sonnet-4-6" } + + /// Floating bar responses + static var floatingBar: String { "claude-sonnet-4-6" } + + /// Synthesis extraction tasks (calendar, gmail, notes, memory import) + static var synthesis: String { "claude-haiku-4-5-20251001" } + + /// ChatLab test queries + static var chatLabQuery: String { "claude-sonnet-4-20250514" } + + /// ChatLab grading (always cheap) + static var chatLabGrade: String { "claude-haiku-4-5-20251001" } + + /// Available models shown in the UI picker + static var availableModels: [(id: String, label: String)] { + [("claude-sonnet-4-6", "Sonnet")] + } + + /// Default model for user selection (floating bar / shortcut picker) + static var defaultSelection: String { "claude-sonnet-4-6" } + + /// Sanitize a persisted model ID against the current tier's allowed list. 
+ /// Returns the saved model if it's still available, otherwise falls back to defaultSelection. + static func sanitizedSelection(_ savedModel: String?) -> String { + let model = savedModel ?? defaultSelection + let allowedIDs = availableModels.map(\.id) + return allowedIDs.contains(model) ? model : defaultSelection + } + } + + // MARK: - Gemini Models + + struct Gemini { + /// Proactive assistants (screenshot analysis, context detection) + static var proactive: String { "gemini-3-flash-preview" } + + /// Task extraction + static var taskExtraction: String { "gemini-3-flash-preview" } + + /// Insight generation + static var insight: String { "gemini-3-flash-preview" } + + /// Embeddings (not tier-dependent, kept separate) + static var embedding: String { "gemini-embedding-001" } + } + + // MARK: - Tier Info (for UI / debugging) + + static var tierDescription: String { + switch activeTier { + case .premium: return "Premium (cost-optimized)" + case .max: return "Max (quality-optimized)" + } + } +} + +extension Notification.Name { + static let modelTierDidChange = Notification.Name("modelTierDidChange") +} diff --git a/desktop/Desktop/Sources/OnboardingChatView.swift b/desktop/Desktop/Sources/OnboardingChatView.swift index 67e8233987..0aa41d0cbc 100644 --- a/desktop/Desktop/Sources/OnboardingChatView.swift +++ b/desktop/Desktop/Sources/OnboardingChatView.swift @@ -1383,7 +1383,7 @@ struct OnboardingChatView: View { prompt: "Begin exploration. 
\(fileCount) files have been indexed in the indexed_files table.", systemPrompt: systemPrompt, - model: "claude-opus-4-6", + model: ModelQoS.Claude.chat, onTextDelta: { @Sendable delta in Task { @MainActor in explorationText += delta diff --git a/desktop/Desktop/Sources/OnboardingMemoryLogImportService.swift b/desktop/Desktop/Sources/OnboardingMemoryLogImportService.swift index 0fe2c71c9e..d708fb783a 100644 --- a/desktop/Desktop/Sources/OnboardingMemoryLogImportService.swift +++ b/desktop/Desktop/Sources/OnboardingMemoryLogImportService.swift @@ -93,7 +93,7 @@ actor OnboardingMemoryLogImportService { prompt: importPrompt, systemPrompt: "You convert memory-log exports into concise durable user memories. Output only valid JSON.", - model: "claude-opus-4-6", + model: ModelQoS.Claude.synthesis, onTextDelta: { @Sendable _ in }, onToolCall: { @Sendable _, _, _ in "" }, onToolActivity: { @Sendable _, _, _, _ in } diff --git a/desktop/Desktop/Sources/OnboardingPagedIntroCoordinator.swift b/desktop/Desktop/Sources/OnboardingPagedIntroCoordinator.swift index 6ea3cebff0..4e5edf612e 100644 --- a/desktop/Desktop/Sources/OnboardingPagedIntroCoordinator.swift +++ b/desktop/Desktop/Sources/OnboardingPagedIntroCoordinator.swift @@ -996,7 +996,7 @@ final class OnboardingPagedIntroCoordinator: ObservableObject { prompt: prompt, systemPrompt: "You are a structured onboarding research assistant. 
Output only valid JSON.", - model: "claude-opus-4-6", + model: ModelQoS.Claude.chat, onTextDelta: { @Sendable _ in }, onToolCall: { @Sendable _, _, _ in return "" }, onToolActivity: { @Sendable _, _, _, _ in } diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift index 02b1b3d419..f7a9414f9c 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/Insight/InsightAssistant.swift @@ -63,8 +63,7 @@ actor InsightAssistant: ProactiveAssistant { // MARK: - Initialization init(apiKey: String? = nil) throws { - // Use Gemini 3.1 Pro for better insight quality (3-pro-preview retires March 9, 2026) - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + self.geminiClient = try GeminiClient(apiKey: apiKey, model: ModelQoS.Gemini.insight) let (stream, continuation) = AsyncStream.makeStream(of: Void.self, bufferingPolicy: .bufferingNewest(1)) self.frameSignal = stream diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift index 55969ca029..3ced48065e 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Assistants/TaskExtraction/TaskAssistant.swift @@ -144,8 +144,7 @@ actor TaskAssistant: ProactiveAssistant { // MARK: - Initialization init(apiKey: String? 
= nil) throws { - // Use Gemini 3 Pro for better task extraction quality - self.geminiClient = try GeminiClient(apiKey: apiKey, model: "gemini-pro-latest") + self.geminiClient = try GeminiClient(apiKey: apiKey, model: ModelQoS.Gemini.taskExtraction) let (stream, continuation) = AsyncStream.makeStream(of: TriggerEvent.self, bufferingPolicy: .bufferingNewest(1)) self.triggerStream = stream diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift index 0acdbafef8..5caa6f6c02 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Core/GeminiClient.swift @@ -226,7 +226,7 @@ actor GeminiClient { } } - init(apiKey: String? = nil, model: String = "gemini-3-flash-preview") throws { + init(apiKey: String? = nil, model: String = ModelQoS.Gemini.proactive) throws { // BREAKING CHANGE (issue #5861): apiKey parameter is ignored. // All Gemini requests now route through the backend proxy which supplies // the key server-side. Requires OMI_API_URL to be set (standard dev flow via run.sh). 
diff --git a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift index 0c54886fef..5787938edf 100644 --- a/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift +++ b/desktop/Desktop/Sources/ProactiveAssistants/Services/EmbeddingService.swift @@ -7,7 +7,7 @@ actor EmbeddingService { /// Gemini embedding-001 outputs 3072 dimensions by default static let embeddingDimension = 3072 - static let modelName = "gemini-embedding-001" + static var modelName: String { ModelQoS.Gemini.embedding } /// In-memory index: action_item.id -> normalized embedding private var index: [Int64: [Float]] = [:] diff --git a/desktop/Desktop/Sources/Providers/ChatProvider.swift b/desktop/Desktop/Sources/Providers/ChatProvider.swift index d58b5bef82..9a2ca79cfc 100644 --- a/desktop/Desktop/Sources/Providers/ChatProvider.swift +++ b/desktop/Desktop/Sources/Providers/ChatProvider.swift @@ -772,11 +772,11 @@ A screenshot may be attached — use it silently only if relevant. Never mention let mainSystemPrompt = buildSystemPrompt(contextString: formatMemoriesSection()) let floatingSystemPrompt = Self.floatingBarSystemPromptPrefix + "\n\n" + mainSystemPrompt let floatingModel = ShortcutSettings.shared.selectedModel.isEmpty - ? "claude-sonnet-4-6" + ? ModelQoS.Claude.defaultSelection : ShortcutSettings.shared.selectedModel cachedMainSystemPrompt = mainSystemPrompt await acpBridge.warmupSession(cwd: workingDirectory, sessions: [ - .init(key: "main", model: "claude-opus-4-6", systemPrompt: mainSystemPrompt), + .init(key: "main", model: ModelQoS.Claude.chat, systemPrompt: mainSystemPrompt), .init(key: "floating", model: floatingModel, systemPrompt: floatingSystemPrompt) ]) return true @@ -1619,7 +1619,7 @@ A screenshot may be attached — use it silently only if relevant. 
Never mention prompt: question, systemPrompt: systemPrompt, sessionKey: sessionKey, - model: "claude-sonnet-4-20250514", + model: ModelQoS.Claude.chatLabQuery, onTextDelta: { _ in }, onToolCall: { callId, name, input in let toolCall = ToolCall(name: name, arguments: input, thoughtSignature: nil) diff --git a/desktop/Desktop/Tests/ModelQoSTests.swift b/desktop/Desktop/Tests/ModelQoSTests.swift new file mode 100644 index 0000000000..266c27c460 --- /dev/null +++ b/desktop/Desktop/Tests/ModelQoSTests.swift @@ -0,0 +1,139 @@ +import XCTest +@testable import Omi_Computer + +final class ModelQoSTests: XCTestCase { + private let tierKey = "modelQoS_activeTier" + + override func setUp() { + super.setUp() + UserDefaults.standard.removeObject(forKey: tierKey) + } + + override func tearDown() { + UserDefaults.standard.removeObject(forKey: tierKey) + super.tearDown() + } + + // MARK: - Default tier + + func testDefaultTierIsPremium() { + XCTAssertEqual(ModelQoS.activeTier, .premium) + } + + // MARK: - Tier persistence + + func testSetTierPersistsToUserDefaults() { + ModelQoS.activeTier = .max + XCTAssertEqual(UserDefaults.standard.string(forKey: tierKey), "max") + + ModelQoS.activeTier = .premium + XCTAssertEqual(UserDefaults.standard.string(forKey: tierKey), "premium") + } + + func testInvalidUserDefaultsFallsBackToPremium() { + UserDefaults.standard.set("invalid_tier", forKey: tierKey) + XCTAssertEqual(ModelQoS.activeTier, .premium) + } + + // MARK: - Claude models are tier-independent + + func testClaudeModelsIdenticalAcrossTiers() { + for tier in ModelTier.allCases { + ModelQoS.activeTier = tier + XCTAssertEqual(ModelQoS.Claude.chat, "claude-sonnet-4-6") + XCTAssertEqual(ModelQoS.Claude.floatingBar, "claude-sonnet-4-6") + XCTAssertEqual(ModelQoS.Claude.synthesis, "claude-haiku-4-5-20251001") + XCTAssertEqual(ModelQoS.Claude.chatLabQuery, "claude-sonnet-4-20250514") + XCTAssertEqual(ModelQoS.Claude.chatLabGrade, "claude-haiku-4-5-20251001") + 
XCTAssertEqual(ModelQoS.Claude.defaultSelection, "claude-sonnet-4-6") + } + } + + // MARK: - Synthesis uses Haiku (extraction workloads) + + func testSynthesisUsesHaiku() { + XCTAssertEqual(ModelQoS.Claude.synthesis, "claude-haiku-4-5-20251001") + } + + // MARK: - Chat uses Sonnet (user-facing) + + func testChatUsesSonnet() { + XCTAssertEqual(ModelQoS.Claude.chat, "claude-sonnet-4-6") + } + + // MARK: - Available models (Sonnet only, both tiers) + + func testAvailableModelsSonnetOnlyBothTiers() { + for tier in ModelTier.allCases { + ModelQoS.activeTier = tier + let ids = ModelQoS.Claude.availableModels.map(\.id) + XCTAssertEqual(ids, ["claude-sonnet-4-6"]) + } + } + + // MARK: - Gemini models are tier-independent + + func testGeminiModelsIdenticalAcrossTiers() { + for tier in ModelTier.allCases { + ModelQoS.activeTier = tier + XCTAssertEqual(ModelQoS.Gemini.proactive, "gemini-3-flash-preview") + XCTAssertEqual(ModelQoS.Gemini.taskExtraction, "gemini-3-flash-preview") + XCTAssertEqual(ModelQoS.Gemini.insight, "gemini-3-flash-preview") + XCTAssertEqual(ModelQoS.Gemini.embedding, "gemini-embedding-001") + } + } + + // MARK: - Tier description + + func testTierDescription() { + ModelQoS.activeTier = .premium + XCTAssertEqual(ModelQoS.tierDescription, "Premium (cost-optimized)") + + ModelQoS.activeTier = .max + XCTAssertEqual(ModelQoS.tierDescription, "Max (quality-optimized)") + } + + // MARK: - Sanitized selection + + func testSanitizedSelectionAllowsValidModel() { + XCTAssertEqual(ModelQoS.Claude.sanitizedSelection("claude-sonnet-4-6"), "claude-sonnet-4-6") + } + + func testSanitizedSelectionFallsBackForUnknownModel() { + XCTAssertEqual(ModelQoS.Claude.sanitizedSelection("claude-opus-4-6"), "claude-sonnet-4-6") + } + + func testSanitizedSelectionHandlesNil() { + XCTAssertEqual(ModelQoS.Claude.sanitizedSelection(nil), "claude-sonnet-4-6") + } + + func testSanitizedSelectionHandlesUnknownModel() { + XCTAssertEqual(ModelQoS.Claude.sanitizedSelection("gpt-4o"), 
"claude-sonnet-4-6") + } + + // MARK: - Tier change notification + + func testTierChangePostsNotification() { + let expectation = expectation(forNotification: .modelTierDidChange, object: nil) + ModelQoS.activeTier = .max + wait(for: [expectation], timeout: 1.0) + } + + // MARK: - Model count (5 unique model IDs) + + func testOnlyFiveUniqueModelIDs() { + let allModels: Set = [ + ModelQoS.Claude.chat, + ModelQoS.Claude.floatingBar, + ModelQoS.Claude.synthesis, + ModelQoS.Claude.chatLabQuery, + ModelQoS.Claude.chatLabGrade, + ModelQoS.Claude.defaultSelection, + ModelQoS.Gemini.proactive, + ModelQoS.Gemini.taskExtraction, + ModelQoS.Gemini.insight, + ModelQoS.Gemini.embedding, + ] + XCTAssertEqual(allModels.count, 5, "Expected exactly 5 unique model IDs: \(allModels)") + } +}