diff --git a/.github/workflows/cli_desktop_ci.yml b/.github/workflows/cli_desktop_ci.yml index 4fcfd25723..d3a0c058c8 100644 --- a/.github/workflows/cli_desktop_ci.yml +++ b/.github/workflows/cli_desktop_ci.yml @@ -17,6 +17,8 @@ jobs: defaults: run: shell: bash + env: + LIBONNXRUNTIME_NO_PKG_CONFIG: "1" steps: - uses: actions/checkout@v4 - uses: ./.github/actions/install_cli_deps @@ -25,5 +27,5 @@ jobs: components: clippy - uses: Swatinem/rust-cache@v2 - run: cargo check -p cli --features desktop - - run: cargo clippy -p cli --features desktop -- -D warnings + - run: cargo clippy -p cli --features desktop - run: cargo test -p cli --features desktop diff --git a/Cargo.lock b/Cargo.lock index d01feae6c3..37c5bd2452 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10887,6 +10887,8 @@ dependencies = [ "futures-util", "llm-types", "model-manager", + "pico-args", + "reqwest 0.13.2", "serde", "serde_json", "thiserror 2.0.18", @@ -20911,6 +20913,7 @@ dependencies = [ "tower 0.5.3", "tracing", "tracing-subscriber", + "transcribe-core", "url", "urlencoding", "utoipa", diff --git a/apps/cli/src/stt/config.rs b/apps/cli/src/stt/config.rs index c80a1b6a6f..d611790c22 100644 --- a/apps/cli/src/stt/config.rs +++ b/apps/cli/src/stt/config.rs @@ -78,6 +78,7 @@ impl ResolvedSttConfig { num_speakers: None, min_speakers: None, max_speakers: None, + known_speaker_references: vec![], } } } diff --git a/crates/cactus/src/llm/complete.rs b/crates/cactus/src/llm/complete.rs index 18cbf8abc5..215f922d7a 100644 --- a/crates/cactus/src/llm/complete.rs +++ b/crates/cactus/src/llm/complete.rs @@ -152,7 +152,7 @@ impl Model { }; let (rc, buf) = self.call_complete( - &guard, + guard, &request.messages_c, &request.options_c, Some(token_trampoline::), diff --git a/crates/listener-core/src/actors/source/pipeline.rs b/crates/listener-core/src/actors/source/pipeline.rs index eefdbe8d30..a2c579c59a 100644 --- a/crates/listener-core/src/actors/source/pipeline.rs +++ b/crates/listener-core/src/actors/source/pipeline.rs @@ -14,6 +14,7 @@ use hypr_audio_utils::f32_to_i16_bytes; use hypr_vad_masking::VadMask; use super::{ListenerRouting, SourceFrame}; +use hypr_audio::CaptureFrame; const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); const MAX_BUFFER_CHUNKS: usize = 150; diff --git a/crates/listener2-core/src/batch/mod.rs b/crates/listener2-core/src/batch/mod.rs index a351bb80d1..15920b0ba2 100644 --- a/crates/listener2-core/src/batch/mod.rs +++ b/crates/listener2-core/src/batch/mod.rs @@ -55,6 +55,8 @@ pub struct BatchParams { pub min_speakers: Option, #[serde(default)] pub max_speakers: Option, + #[serde(default)] + pub known_speaker_references: Vec, } #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -235,6 +237,7 @@ fn build_listen_params( num_speakers: params.num_speakers, min_speakers: params.min_speakers, max_speakers: params.max_speakers, + known_speaker_references: params.known_speaker_references.clone(), custom_query: None, } } @@ -326,6 +329,7 @@ mod tests { num_speakers: None, min_speakers: None, max_speakers: None, + known_speaker_references: vec![], } } @@ -353,6 +357,20 @@ mod tests { assert!(listen_params.custom_query.is_none()); } + #[test] + fn build_listen_params_preserves_known_speaker_references() { + let mut params = batch_params(BatchProvider::OpenAI, "https://api.openai.com/v1"); + params.known_speaker_references = vec![owhisper_interface::KnownSpeakerReference { + name: "agent".to_string(), + audio_data_url: "data:audio/wav;base64,AAA=".to_string(), + }]; + + let listen_params = build_listen_params(¶ms, 1, 16_000); + + assert_eq!(listen_params.known_speaker_references.len(), 1); + assert_eq!(listen_params.known_speaker_references[0].name, "agent"); + } + #[test] fn am_routes_pyannote_to_direct_batch() { let params = batch_params(BatchProvider::Am, "https://api.pyannote.ai"); diff --git a/crates/openai-transcription/src/batch/response.rs b/crates/openai-transcription/src/batch/response.rs index 0f9f60adab..b757385687 100644 --- a/crates/openai-transcription/src/batch/response.rs +++ b/crates/openai-transcription/src/batch/response.rs @@ -8,6 +8,9 @@ pub enum ParsedTranscriptionStreamEvent { partial_text: String, logprobs: Vec, }, + TextSegment { + segment: DiarizedTranscriptionSegment, + }, TextDone { text: String, logprobs: Vec, @@ -65,6 +68,22 @@ impl TranscriptionStreamEventParser { }) } } + TranscriptionStreamEvent::TextSegment { + id, + end, + speaker, + start, + text, + } => Some(ParsedTranscriptionStreamEvent::TextSegment { + segment: DiarizedTranscriptionSegment { + id, + end, + speaker, + start, + text, + segment_type: TranscriptionDiarizedSegmentType::TranscriptTextSegment, + }, + }), TranscriptionStreamEvent::TextDone { text, logprobs, @@ -266,6 +285,14 @@ pub enum TranscriptionStreamEvent { #[serde(default)] logprobs: Vec, }, + #[serde(rename = "transcript.text.segment")] + TextSegment { + id: String, + end: f64, + speaker: String, + start: f64, + text: String, + }, #[serde(rename = "transcript.text.done")] TextDone { text: String, @@ -366,6 +393,17 @@ mod tests { }"#, ) .expect("parse delta"); + let segment: TranscriptionStreamEvent = serde_json::from_str( + r#"{ + "type": "transcript.text.segment", + "id": "seg_001", + "start": 0.0, + "end": 1.5, + "text": "hello there", + "speaker": "agent" + }"#, + ) + .expect("parse segment"); let done: TranscriptionStreamEvent = serde_json::from_str( r#"{ "type": "transcript.text.done", @@ -382,6 +420,10 @@ mod tests { .expect("parse done"); assert!(matches!(delta, TranscriptionStreamEvent::TextDelta { .. })); + assert!(matches!( + segment, + TranscriptionStreamEvent::TextSegment { .. } + )); assert!(matches!(done, TranscriptionStreamEvent::TextDone { .. })); } @@ -410,4 +452,25 @@ mod tests { ParsedTranscriptionStreamEvent::TextDone { text, .. } if text == "hello" )); } + + #[test] + fn parser_preserves_diarized_segment_events() { + let mut parser = TranscriptionStreamEventParser::new(); + + let segment = parser + .parse_sse_block( + r#"data: {"type":"transcript.text.segment","id":"seg_001","start":0.0,"end":1.5,"text":"hello there","speaker":"agent"}"#, + ) + .expect("parse segment") + .expect("expected segment"); + + assert!(matches!( + segment, + ParsedTranscriptionStreamEvent::TextSegment { segment } + if segment.id == "seg_001" + && segment.speaker == "agent" + && segment.text == "hello there" + )); + assert_eq!(parser.partial_text(), ""); + } } diff --git a/crates/owhisper-client/src/adapter/openai/batch.rs b/crates/owhisper-client/src/adapter/openai/batch.rs index 29601d02b9..d1e40d5b65 100644 --- a/crates/owhisper-client/src/adapter/openai/batch.rs +++ b/crates/owhisper-client/src/adapter/openai/batch.rs @@ -3,11 +3,13 @@ use std::path::{Path, PathBuf}; use futures_util::StreamExt; use openai_transcription::batch::{ CreateTranscriptionOptions, CreateTranscriptionResponse, DiarizedTranscriptionResponse, - ParsedTranscriptionStreamEvent, TranscriptionStreamEventParser, TranscriptionUsage, + DiarizedTranscriptionSegment, ParsedTranscriptionStreamEvent, TranscriptionStreamEventParser, + TranscriptionUsage, }; use owhisper_interface::ListenParams; use owhisper_interface::batch::{Alternatives, Channel, Response as BatchResponse, Results, Word}; use owhisper_interface::batch_stream::BatchStreamEvent; +use owhisper_interface::stream; use reqwest::multipart::{Form, Part}; use crate::adapter::{ @@ -103,6 +105,9 @@ impl OpenAIAdapter { return Some((event, state)); } } + if let Some(event) = state.finish_stream() { + return Some((event, state)); + } return None; } } @@ -184,6 +189,16 @@ fn build_transcription_options( options.push_language(lang.iso639().code().to_string()); } + if matches!(options, CreateTranscriptionOptions::Diarize(_)) { + let common = options.common_mut(); + for speaker in ¶ms.known_speaker_references { + common.known_speaker_names.push(speaker.name.clone()); + common + .known_speaker_references + .push(speaker.audio_data_url.clone()); + } + } + options } @@ -223,6 +238,10 @@ struct OpenAISseParserState { pending_events: std::collections::VecDeque>, parser: TranscriptionStreamEventParser, progress: OpenAISyntheticProgress, + speaker_labels: Vec, + saw_transcript_event: bool, + saw_terminal_result: bool, + emitted_incomplete_stream_error: bool, } impl OpenAISseParserState { @@ -233,6 +252,10 @@ impl OpenAISseParserState { pending_events: std::collections::VecDeque::new(), parser: TranscriptionStreamEventParser::new(), progress: OpenAISyntheticProgress::default(), + speaker_labels: Vec::new(), + saw_transcript_event: false, + saw_terminal_result: false, + emitted_incomplete_stream_error: false, } } @@ -251,6 +274,33 @@ impl OpenAISseParserState { } } + fn finish_stream(&mut self) -> Option> { + if self.emitted_incomplete_stream_error + || !self.saw_transcript_event + || self.saw_terminal_result + { + return None; + } + + self.emitted_incomplete_stream_error = true; + Some(Err(Error::WebSocket( + "OpenAI stream ended before transcript.text.done".to_string(), + ))) + } + + fn speaker_index_for(&mut self, speaker: &str) -> i32 { + if let Some(index) = self + .speaker_labels + .iter() + .position(|label| label == speaker) + { + return index as i32; + } + + self.speaker_labels.push(speaker.to_string()); + (self.speaker_labels.len() - 1) as i32 + } + fn parse_sse_block(&mut self, block: &str) -> Option> { let event = match self.parser.parse_sse_block(block) { Ok(Some(event)) => event, @@ -264,12 +314,23 @@ impl OpenAISseParserState { match event { ParsedTranscriptionStreamEvent::TextDelta { partial_text, .. } => { + self.saw_transcript_event = true; Some(Ok(BatchStreamEvent::Progress { percentage: self.progress.observe_delta(&partial_text), partial_text: Some(partial_text), })) } + ParsedTranscriptionStreamEvent::TextSegment { segment } => { + self.saw_transcript_event = true; + let speaker_index = self.speaker_index_for(&segment.speaker); + Some(Ok(BatchStreamEvent::Segment { + percentage: self.progress.observe_delta(&segment.text), + response: build_stream_response_for_segment(&segment, speaker_index), + })) + } ParsedTranscriptionStreamEvent::TextDone { text, usage, .. } => { + self.saw_transcript_event = true; + self.saw_terminal_result = true; Some(Ok(BatchStreamEvent::Result { response: build_batch_response( text.trim().to_string(), @@ -452,6 +513,70 @@ fn convert_diarized_words(response: &DiarizedTranscriptionResponse) -> (Vec Vec { + let tokens = segment.text.split_whitespace().collect::>(); + if tokens.is_empty() { + return Vec::new(); + } + + let segment_duration = (segment.end - segment.start).max(0.0); + let word_duration = segment_duration / tokens.len() as f64; + + tokens + .iter() + .enumerate() + .map(|(index, token)| { + let normalized = strip_punctuation(token); + let start = segment.start + word_duration * index as f64; + let end = if index + 1 == tokens.len() { + segment.end + } else { + segment.start + word_duration * (index + 1) as f64 + }; + + stream::Word { + word: if normalized.is_empty() { + (*token).to_string() + } else { + normalized + }, + start, + end, + confidence: 1.0, + speaker: Some(speaker_index), + punctuated_word: Some((*token).to_string()), + language: None, + } + }) + .collect() +} + +fn build_stream_response_for_segment( + segment: &DiarizedTranscriptionSegment, + speaker_index: i32, +) -> stream::StreamResponse { + stream::StreamResponse::TranscriptResponse { + start: segment.start, + duration: (segment.end - segment.start).max(0.0), + is_final: true, + speech_final: false, + from_finalize: false, + channel: stream::Channel { + alternatives: vec![stream::Alternatives { + transcript: segment.text.clone(), + words: convert_diarized_segment_words(segment, speaker_index), + confidence: 1.0, + languages: vec![], + }], + }, + metadata: stream::Metadata::default(), + channel_index: vec![0, 1], + } +} + fn build_batch_response( transcript: String, words: Vec, @@ -514,6 +639,63 @@ mod tests { assert!(!fields.iter().any(|field| field.name == "stream")); } + #[test] + fn build_transcription_options_includes_known_speaker_references_for_diarize_model() { + let options = build_transcription_options( + &ListenParams::default() + .with_known_speaker_reference("agent", "data:audio/wav;base64,AAA="), + true, + false, + ); + + let fields = options + .multipart_text_fields() + .expect("serialize multipart"); + + assert!(matches!(options, CreateTranscriptionOptions::Diarize(_))); + assert!( + fields + .iter() + .any(|field| { field.name == "known_speaker_names[]" && field.value == "agent" }) + ); + assert!(fields.iter().any(|field| { + field.name == "known_speaker_references[]" + && field.value == "data:audio/wav;base64,AAA=" + })); + } + + #[test] + fn build_transcription_options_omits_known_speaker_references_for_non_diarize_model() { + let options = build_transcription_options( + &ListenParams { + model: Some("gpt-4o-transcribe".to_string()), + known_speaker_references: vec![owhisper_interface::KnownSpeakerReference { + name: "agent".to_string(), + audio_data_url: "data:audio/wav;base64,AAA=".to_string(), + }], + ..Default::default() + }, + false, + false, + ); + + let fields = options + .multipart_text_fields() + .expect("serialize multipart"); + + assert!(matches!(options, CreateTranscriptionOptions::Gpt(_))); + assert!( + !fields + .iter() + .any(|field| field.name == "known_speaker_names[]") + ); + assert!( + !fields + .iter() + .any(|field| field.name == "known_speaker_references[]") + ); + } + #[test] fn parse_sse_delta_accumulates_partial_text() { let mut state = OpenAISseParserState::new(()); @@ -559,6 +741,28 @@ mod tests { assert_eq!(response.metadata["usage"]["type"], "tokens"); } + #[test] + fn parse_sse_segment_emits_stream_segment() { + let mut state = OpenAISseParserState::new(()); + let event = state + .parse_sse_block( + r#"data: {"type":"transcript.text.segment","id":"seg_001","start":0.0,"end":1.5,"text":"hello there","speaker":"agent"}"#, + ) + .expect("expected segment event") + .expect("expected valid segment event"); + + let BatchStreamEvent::Segment { response, .. } = event else { + panic!("expected segment event"); + }; + + let stream::StreamResponse::TranscriptResponse { channel, .. } = response else { + panic!("expected transcript response"); + }; + + assert_eq!(channel.alternatives[0].transcript, "hello there"); + assert_eq!(channel.alternatives[0].words[0].speaker, Some(0)); + } + #[test] fn parse_buffer_handles_crlf_delimited_sse_blocks() { let mut state = OpenAISseParserState::new(()); @@ -596,6 +800,24 @@ mod tests { assert!(capped <= OPENAI_PROGRESS_CAP); } + #[test] + fn finish_stream_after_progress_returns_error() { + let mut state = OpenAISseParserState::new(()); + let _ = state + .parse_sse_block(r#"data: {"type":"transcript.text.delta","delta":"hello"}"#) + .expect("expected progress event"); + + let event = state + .finish_stream() + .expect("expected incomplete stream error"); + let Err(Error::WebSocket(message)) = event else { + panic!("expected websocket error"); + }; + + assert!(message.contains("transcript.text.done")); + assert!(state.finish_stream().is_none()); + } + #[test] fn convert_diarized_response_preserves_speaker_segments() { let response: CreateTranscriptionResponse = serde_json::from_str( diff --git a/crates/owhisper-interface/src/batch_sse.rs b/crates/owhisper-interface/src/batch_sse.rs index 3b7cb27fdc..ad78c70364 100644 --- a/crates/owhisper-interface/src/batch_sse.rs +++ b/crates/owhisper-interface/src/batch_sse.rs @@ -34,3 +34,128 @@ impl From for batch_stream::BatchStreamEvent { } } } + +impl BatchSseMessage { + pub fn from_batch_stream_event( + event: batch_stream::BatchStreamEvent, + fallback_provider: &str, + ) -> Option { + match event { + batch_stream::BatchStreamEvent::Progress { + percentage, + partial_text, + } => Some(Self::Progress { + progress: InferenceProgress { + percentage, + partial_text, + phase: crate::progress::InferencePhase::Transcribing, + }, + }), + batch_stream::BatchStreamEvent::Segment { .. } + | batch_stream::BatchStreamEvent::Terminal { .. } => None, + batch_stream::BatchStreamEvent::Result { response } => Some(Self::Result { response }), + batch_stream::BatchStreamEvent::Error { + error_message, + provider, + .. + } => Some(Self::Error { + error: if provider.is_empty() { + fallback_provider.to_string() + } else { + provider + }, + detail: error_message, + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::batch; + + #[test] + fn from_batch_stream_event_maps_progress() { + let message = BatchSseMessage::from_batch_stream_event( + batch_stream::BatchStreamEvent::Progress { + percentage: 0.5, + partial_text: Some("hello".to_string()), + }, + "openai", + ) + .expect("expected progress message"); + + let BatchSseMessage::Progress { progress } = message else { + panic!("expected progress"); + }; + + assert_eq!(progress.percentage, 0.5); + assert_eq!(progress.partial_text.as_deref(), Some("hello")); + assert_eq!( + progress.phase, + crate::progress::InferencePhase::Transcribing + ); + } + + #[test] + fn from_batch_stream_event_maps_result() { + let response = batch::Response { + metadata: serde_json::json!({}), + results: batch::Results { channels: vec![] }, + }; + + let message = BatchSseMessage::from_batch_stream_event( + batch_stream::BatchStreamEvent::Result { + response: response.clone(), + }, + "openai", + ) + .expect("expected result message"); + + let BatchSseMessage::Result { + response: mapped_response, + } = message + else { + panic!("expected result"); + }; + + assert_eq!(mapped_response, response); + } + + #[test] + fn from_batch_stream_event_uses_fallback_provider_for_errors() { + let message = BatchSseMessage::from_batch_stream_event( + batch_stream::BatchStreamEvent::Error { + error_code: None, + error_message: "boom".to_string(), + provider: String::new(), + }, + "openai", + ) + .expect("expected error message"); + + let BatchSseMessage::Error { error, detail } = message else { + panic!("expected error"); + }; + + assert_eq!(error, "openai"); + assert_eq!(detail, "boom"); + } + + #[test] + fn from_batch_stream_event_ignores_non_batch_sse_events() { + assert!( + BatchSseMessage::from_batch_stream_event( + batch_stream::BatchStreamEvent::Terminal { + request_id: "req_123".to_string(), + created: "now".to_string(), + duration: 1.0, + channels: 1, + }, + "openai", + ) + .is_none() + ); + } +} diff --git a/crates/owhisper-interface/src/lib.rs b/crates/owhisper-interface/src/lib.rs index 618940fa76..308486a4e1 100644 --- a/crates/owhisper-interface/src/lib.rs +++ b/crates/owhisper-interface/src/lib.rs @@ -87,6 +87,13 @@ common_derives! { } } +common_derives! { + pub struct KnownSpeakerReference { + pub name: String, + pub audio_data_url: String, + } +} + common_derives! { #[serde(tag = "type", content = "value")] pub enum ListenInputChunk { @@ -153,6 +160,8 @@ common_derives! { #[serde(default)] pub max_speakers: Option, #[serde(default)] + pub known_speaker_references: Vec, + #[serde(default)] #[cfg_attr(feature = "openapi", schema(value_type = Option))] pub custom_query: Option>, } @@ -169,6 +178,7 @@ impl Default for ListenParams { num_speakers: None, min_speakers: None, max_speakers: None, + known_speaker_references: Vec::new(), custom_query: None, } } @@ -182,4 +192,43 @@ impl ListenParams { fn default_sample_rate() -> u32 { 16000 } + + pub fn add_known_speaker_reference( + &mut self, + name: impl Into, + audio_data_url: impl Into, + ) -> &mut Self { + self.known_speaker_references.push(KnownSpeakerReference { + name: name.into(), + audio_data_url: audio_data_url.into(), + }); + self + } + + pub fn with_known_speaker_reference( + mut self, + name: impl Into, + audio_data_url: impl Into, + ) -> Self { + self.add_known_speaker_reference(name, audio_data_url); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn listen_params_builder_adds_known_speaker_reference() { + let params = ListenParams::default() + .with_known_speaker_reference("agent", "data:audio/wav;base64,AAA="); + + assert_eq!(params.known_speaker_references.len(), 1); + assert_eq!(params.known_speaker_references[0].name, "agent"); + assert_eq!( + params.known_speaker_references[0].audio_data_url, + "data:audio/wav;base64,AAA=" + ); + } } diff --git a/crates/transcribe-proxy/Cargo.toml b/crates/transcribe-proxy/Cargo.toml index 18d3cfa4f2..7306db1d11 100644 --- a/crates/transcribe-proxy/Cargo.toml +++ b/crates/transcribe-proxy/Cargo.toml @@ -11,6 +11,7 @@ hypr-audio-mime = { workspace = true } hypr-language = { workspace = true } hypr-observability = { workspace = true } hypr-supabase-storage = { workspace = true } +hypr-transcribe-core = { workspace = true } owhisper-client = { workspace = true } owhisper-interface = { workspace = true, features = ["openapi"] } diff --git a/crates/transcribe-proxy/src/provider_selector.rs b/crates/transcribe-proxy/src/provider_selector.rs index 165e7bc950..c30319435c 100644 --- a/crates/transcribe-proxy/src/provider_selector.rs +++ b/crates/transcribe-proxy/src/provider_selector.rs @@ -5,6 +5,7 @@ use owhisper_client::Provider; use crate::error::SelectionError; +#[derive(Clone)] pub struct SelectedProvider { provider: Provider, api_key: String, diff --git a/crates/transcribe-proxy/src/routes/batch/mod.rs b/crates/transcribe-proxy/src/routes/batch/mod.rs index 85a0605144..aebdde277f 100644 --- a/crates/transcribe-proxy/src/routes/batch/mod.rs +++ b/crates/transcribe-proxy/src/routes/batch/mod.rs @@ -6,8 +6,8 @@ use std::io::Write; use axum::{ Json, body::Bytes, - extract::State, - http::{HeaderMap, StatusCode}, + extract::{FromRequestParts, State}, + http::{HeaderMap, StatusCode, request::Parts}, response::{IntoResponse, Response}, }; use hypr_api_auth::AuthContext; @@ -21,9 +21,29 @@ use crate::query_params::QueryParams; use super::AppState; +pub(crate) struct WantsBatchSse(pub bool); + +impl FromRequestParts for WantsBatchSse +where + S: Send + Sync, +{ + type Rejection = std::convert::Infallible; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + Ok(Self( + parts + .headers + .get("accept") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.contains("text/event-stream")), + )) + } +} + pub async fn handler( State(state): State, auth: Option>, + WantsBatchSse(wants_batch_sse): WantsBatchSse, headers: HeaderMap, mut params: QueryParams, body: Bytes, @@ -56,7 +76,17 @@ pub async fn handler( let use_hyprnote_routing = should_use_hyprnote_routing(provider_param.as_deref()); if use_hyprnote_routing { - return sync::handle_hyprnote_batch(&state, ¶ms, listen_params, body, content_type) + if wants_batch_sse { + return sync::handle_routed_batch_sse( + &state, + ¶ms, + listen_params, + body, + content_type, + ); + } + + return sync::handle_routed_batch_json(&state, ¶ms, listen_params, body, content_type) .await; } @@ -78,6 +108,16 @@ pub async fn handler( .map(|r| r.retry_config().clone()) .unwrap_or_default(); + if wants_batch_sse { + return sync::handle_direct_batch_sse( + selected, + listen_params, + body, + content_type, + retry_config, + ); + } + match sync::transcribe_with_retry(&selected, listen_params, body, content_type, &retry_config) .await { @@ -129,6 +169,7 @@ fn write_to_temp_file( mod tests { use super::*; use crate::query_params::QueryValue; + use axum::http::Request; use hypr_language::ISO639; #[test] @@ -152,4 +193,31 @@ mod tests { assert_eq!(listen_params.languages[1].iso639(), ISO639::Ko); assert_eq!(listen_params.languages[1].region(), Some("KR")); } + + #[tokio::test] + async fn wants_batch_sse_detects_event_stream_accept_header() { + let request = Request::builder() + .header("accept", "application/json, text/event-stream") + .body(()) + .expect("request"); + let (mut parts, _) = request.into_parts(); + + let wants_sse = WantsBatchSse::from_request_parts(&mut parts, &()) + .await + .expect("extractor"); + + assert!(wants_sse.0); + } + + #[tokio::test] + async fn wants_batch_sse_defaults_to_false() { + let request = Request::builder().body(()).expect("request"); + let (mut parts, _) = request.into_parts(); + + let wants_sse = WantsBatchSse::from_request_parts(&mut parts, &()) + .await + .expect("extractor"); + + assert!(!wants_sse.0); + } } diff --git a/crates/transcribe-proxy/src/routes/batch/sync.rs b/crates/transcribe-proxy/src/routes/batch/sync.rs index b6644c9254..0cb70b9989 100644 --- a/crates/transcribe-proxy/src/routes/batch/sync.rs +++ b/crates/transcribe-proxy/src/routes/batch/sync.rs @@ -7,12 +7,15 @@ use axum::{ response::{IntoResponse, Response}, }; use backon::{ExponentialBuilder, Retryable}; +use futures_util::StreamExt; +use hypr_transcribe_core::batch_sse_response; use owhisper_client::{ AssemblyAIAdapter, BatchClient, DeepgramAdapter, ElevenLabsAdapter, FireworksAdapter, GladiaAdapter, MistralAdapter, OpenAIAdapter, Provider, PyannoteAdapter, SonioxAdapter, }; use owhisper_interface::ListenParams; use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::batch_sse::BatchSseMessage; use crate::hyprnote_routing::{RetryConfig, RoutingMode}; use crate::provider_selector::SelectedProvider; @@ -97,24 +100,41 @@ fn resolve_listen_params_for_provider( resolved_params } -pub(super) async fn handle_hyprnote_batch( +struct PreparedBatchRoute { + provider_chain: Vec, + retry_config: RetryConfig, + trace: BatchRoutingTrace, +} + +struct BatchChainFailure { + last_error: Option, + providers_tried: Vec, + terminal_provider: Option, +} + +enum AttemptOutcome { + Success { + value: T, + retries: usize, + }, + Failure { + error: BatchAttemptError, + retries: usize, + can_fallback: bool, + }, +} + +fn prepare_routed_batch( state: &AppState, params: &QueryParams, - listen_params: ListenParams, - body: Bytes, + listen_params: &ListenParams, + body: &Bytes, content_type: &str, -) -> Response { +) -> Result { let provider_chain = state.resolve_hyprnote_provider_chain_for_mode(RoutingMode::Batch, params); if provider_chain.is_empty() { - return ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "error": "no_providers_available", - "detail": "No providers available for the requested language(s)" - })), - ) - .into_response(); + return Err("No providers available for the requested language(s)".to_string()); } let retry_config = state @@ -130,39 +150,68 @@ pub(super) async fn handle_hyprnote_batch( "hyprnote_batch_transcription_request" ); + Ok(PreparedBatchRoute { + trace: BatchRoutingTrace { + request_model: listen_params.model.clone(), + request_languages: listen_params + .languages + .iter() + .map(|lang| lang.iso639().code().to_string()) + .collect(), + provider_chain: provider_chain + .iter() + .map(|selected| selected.provider().to_string()) + .collect(), + attempts: Vec::new(), + outcome: "in_progress".to_string(), + }, + provider_chain, + retry_config, + }) +} + +fn no_providers_available_response(detail: impl Into) -> Response { + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "no_providers_available", + "detail": detail.into() + })), + ) + .into_response() +} + +async fn run_routed_batch_chain( + prepared: PreparedBatchRoute, + listen_params: &ListenParams, + body: &Bytes, + content_type: &str, + mut attempt_provider: F, +) -> Result +where + F: FnMut(SelectedProvider, ListenParams, Bytes, String, RetryConfig) -> Fut, + Fut: std::future::Future>, +{ let mut last_error: Option = None; let mut providers_tried = Vec::new(); - let mut trace = BatchRoutingTrace { - request_model: listen_params.model.clone(), - request_languages: listen_params - .languages - .iter() - .map(|lang| lang.iso639().code().to_string()) - .collect(), - provider_chain: provider_chain - .iter() - .map(|selected| selected.provider().to_string()) - .collect(), - attempts: Vec::new(), - outcome: "in_progress".to_string(), - }; + let mut trace = prepared.trace; - for (attempt, selected) in provider_chain.iter().enumerate() { + for (attempt, selected) in prepared.provider_chain.iter().enumerate() { let provider = selected.provider(); - let provider_listen_params = resolve_listen_params_for_provider(provider, &listen_params); + let provider_listen_params = resolve_listen_params_for_provider(provider, listen_params); let resolved_model = provider_listen_params.model.clone(); providers_tried.push(provider); - match transcribe_with_retry( - selected, + match attempt_provider( + selected.clone(), provider_listen_params, body.clone(), - content_type, - &retry_config, + content_type.to_string(), + prepared.retry_config.clone(), ) .await { - Ok((response, retries)) => { + AttemptOutcome::Success { value, retries } => { tracing::info!( hyprnote.stt.provider.name = ?provider, hyprnote.attempt.number = attempt + 1, @@ -177,14 +226,19 @@ pub(super) async fn handle_hyprnote_batch( trace.outcome = "success".to_string(); log_batch_routing_trace(&trace, true); - return Json(response).into_response(); + return Ok(value); } - Err((e, retries)) => { + AttemptOutcome::Failure { + error: e, + retries, + can_fallback, + .. + } => { tracing::warn!( hyprnote.stt.provider.name = ?provider, error = %e, hyprnote.attempt.number = attempt + 1, - hyprnote.remaining_provider_count = provider_chain.len() - attempt - 1, + hyprnote.remaining_provider_count = prepared.provider_chain.len() - attempt - 1, "provider_failed_trying_next" ); trace.attempts.push(BatchRoutingAttempt { @@ -194,6 +248,17 @@ pub(super) async fn handle_hyprnote_batch( result: format!("{}: {}", e.kind(), e.message()), }); last_error = Some(e.message().to_string()); + + if !can_fallback { + trace.outcome = "terminal_failure".to_string(); + log_batch_routing_trace(&trace, false); + + return Err(BatchChainFailure { + last_error, + providers_tried, + terminal_provider: Some(provider), + }); + } } } } @@ -201,12 +266,328 @@ pub(super) async fn handle_hyprnote_batch( trace.outcome = "all_providers_failed".to_string(); log_batch_routing_trace(&trace, false); + Err(BatchChainFailure { + last_error, + providers_tried, + terminal_provider: None, + }) +} + +pub(super) fn handle_routed_batch_sse( + state: &AppState, + params: &QueryParams, + listen_params: ListenParams, + body: Bytes, + content_type: &str, +) -> Response { + let prepared = match prepare_routed_batch(state, params, &listen_params, &body, content_type) { + Ok(prepared) => prepared, + Err(detail) => return no_providers_available_response(detail), + }; + let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + + let content_type = content_type.to_string(); + tokio::spawn(async move { + let result = run_routed_batch_chain( + prepared, + &listen_params, + &body, + &content_type, + |selected, params, body, content_type, retry_config| { + let event_tx = event_tx.clone(); + async move { + stream_batch_attempt_as_sse( + &selected, + params, + &body, + &content_type, + &retry_config, + &event_tx, + ) + .await + } + }, + ) + .await; + + if let Err(failure) = result { + let detail = failure + .last_error + .unwrap_or_else(|| "Unknown error".to_string()); + let _ = event_tx.send(BatchSseMessage::Error { + error: failure + .terminal_provider + .map(|provider| provider.to_string()) + .unwrap_or_else(|| "all_providers_failed".to_string()), + detail, + }); + } + }); + + batch_sse_response(event_rx) +} + +pub(super) fn handle_direct_batch_sse( + selected: SelectedProvider, + listen_params: ListenParams, + body: Bytes, + content_type: &str, + retry_config: RetryConfig, +) -> Response { + let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + let content_type = content_type.to_string(); + + tokio::spawn(async move { + let outcome = stream_batch_attempt_as_sse( + &selected, + listen_params, + &body, + &content_type, + &retry_config, + &event_tx, + ) + .await; + + if let AttemptOutcome::Failure { error, .. } = outcome { + let _ = event_tx.send(BatchSseMessage::Error { + error: selected.provider().to_string(), + detail: error.to_string(), + }); + } + }); + + batch_sse_response(event_rx) +} + +pub(super) async fn handle_routed_batch_json( + state: &AppState, + params: &QueryParams, + listen_params: ListenParams, + body: Bytes, + content_type: &str, +) -> Response { + let prepared = match prepare_routed_batch(state, params, &listen_params, &body, content_type) { + Ok(prepared) => prepared, + Err(detail) => return no_providers_available_response(detail), + }; + + match run_routed_batch_chain( + prepared, + &listen_params, + &body, + content_type, + |selected, params, body, content_type, retry_config| async move { + match transcribe_with_retry(&selected, params, body, &content_type, &retry_config).await + { + Ok((response, retries)) => AttemptOutcome::Success { + value: response, + retries, + }, + Err((error, retries)) => AttemptOutcome::Failure { + error, + retries, + can_fallback: true, + }, + } + }, + ) + .await + { + Ok(response) => Json(response).into_response(), + Err(failure) => all_providers_failed_response(failure), + } +} + +async fn stream_batch_attempt_as_sse( + selected: &SelectedProvider, + params: ListenParams, + audio_bytes: &Bytes, + content_type: &str, + retry_config: &RetryConfig, + event_tx: &tokio::sync::mpsc::UnboundedSender, +) -> AttemptOutcome<()> { + if selected.provider() == Provider::OpenAI + && OpenAIAdapter::supports_progressive_batch_model(params.model.as_deref()) + { + return stream_openai_batch_sse_with_retry( + selected, + params, + audio_bytes, + content_type, + retry_config, + event_tx, + ) + .await; + } + + match transcribe_with_retry( + selected, + params, + audio_bytes.clone(), + content_type, + retry_config, + ) + .await + { + Ok((response, retries)) => { + let _ = event_tx.send(BatchSseMessage::Result { response }); + AttemptOutcome::Success { value: (), retries } + } + Err((error, retries)) => AttemptOutcome::Failure { + error, + retries, + can_fallback: true, + }, + } +} + +struct StreamingAttemptFailure { + error: BatchAttemptError, + emitted_output: bool, +} + +async fn stream_openai_batch_sse_with_retry( + selected: &SelectedProvider, + params: ListenParams, + audio_bytes: &Bytes, + content_type: &str, + retry_config: &RetryConfig, + event_tx: &tokio::sync::mpsc::UnboundedSender, +) -> AttemptOutcome<()> { + let mut retries = 0usize; + + loop { + match stream_openai_batch_sse_once(selected, ¶ms, audio_bytes, content_type, event_tx) + .await + { + Ok(()) => { + return AttemptOutcome::Success { value: (), retries }; + } + Err(failure) + if !failure.emitted_output + && failure.error.is_retryable() + && retries < retry_config.num_retries => + { + retries += 1; + let delay_secs = (1u64 << retries).min(retry_config.max_delay_secs); + let delay = Duration::from_secs(delay_secs); + tracing::warn!( + hyprnote.stt.provider.name = ?selected.provider(), + error = %failure.error, + hyprnote.retry.delay_ms = delay.as_millis(), + "retrying_transcription" + ); + tokio::time::sleep(delay).await; + } + Err(failure) => { + return AttemptOutcome::Failure { + error: failure.error, + retries, + can_fallback: !failure.emitted_output, + }; + } + } + } +} + +async fn stream_openai_batch_sse_once( + selected: &SelectedProvider, + params: &ListenParams, + audio_bytes: &Bytes, + content_type: &str, + event_tx: &tokio::sync::mpsc::UnboundedSender, +) -> Result<(), StreamingAttemptFailure> { + let temp_file = + write_to_temp_file(audio_bytes, content_type).map_err(|e| StreamingAttemptFailure { + error: BatchAttemptError::Client(format!("failed to create temp file: {e}")), + emitted_output: false, + })?; + + let file_path = temp_file.path().to_path_buf(); + let provider_name = selected.provider().to_string(); + let api_base = selected + .upstream_url() + .unwrap_or(selected.provider().default_api_base()); + let api_key = selected.api_key(); + + let mut stream = + OpenAIAdapter::transcribe_file_streaming(api_base, api_key, params, &file_path) + .await + .map_err(|error| StreamingAttemptFailure { + error: map_provider_error(error), + emitted_output: false, + })?; + + let mut emitted_output = false; + + while let Some(event) = stream.next().await { + match event { + Ok(owhisper_interface::batch_stream::BatchStreamEvent::Progress { + percentage, + partial_text, + }) => { + emitted_output = true; + let _ = event_tx.send(BatchSseMessage::Progress { + progress: owhisper_interface::InferenceProgress { + percentage, + partial_text, + phase: owhisper_interface::progress::InferencePhase::Transcribing, + }, + }); + } + Ok(owhisper_interface::batch_stream::BatchStreamEvent::Segment { + response, .. + }) => { + emitted_output = true; + let _ = event_tx.send(BatchSseMessage::Segment { response }); + } + Ok(owhisper_interface::batch_stream::BatchStreamEvent::Result { response }) => { + emitted_output = true; + let _ = event_tx.send(BatchSseMessage::Result { response }); + } + Ok(owhisper_interface::batch_stream::BatchStreamEvent::Error { + error_message, + provider, + .. + }) => { + return Err(StreamingAttemptFailure { + error: classify_audio_processing_message(if provider.is_empty() { + error_message + } else { + format!("{provider}: {error_message}") + }), + emitted_output, + }); + } + Ok(owhisper_interface::batch_stream::BatchStreamEvent::Terminal { .. }) => {} + Err(error) => { + return Err(StreamingAttemptFailure { + error: map_provider_error(error), + emitted_output, + }); + } + } + } + + if !emitted_output { + return Err(StreamingAttemptFailure { + error: BatchAttemptError::Retryable(format!( + "{provider_name} stream ended before emitting any events" + )), + emitted_output: false, + }); + } + + Ok(()) +} + +fn all_providers_failed_response(failure: BatchChainFailure) -> Response { ( StatusCode::BAD_GATEWAY, Json(serde_json::json!({ "error": "all_providers_failed", - "detail": last_error.unwrap_or_else(|| "Unknown error".to_string()), - "providers_tried": providers_tried.iter().map(|p| format!("{:?}", p)).collect::>() + "detail": failure.last_error.unwrap_or_else(|| "Unknown error".to_string()), + "providers_tried": failure.providers_tried.iter().map(|p| format!("{:?}", p)).collect::>() })), ) .into_response() diff --git a/crates/transcribe-proxy/tests/openai_batch_sse_mock.rs b/crates/transcribe-proxy/tests/openai_batch_sse_mock.rs new file mode 100644 index 0000000000..89a9cbd7ad --- /dev/null +++ b/crates/transcribe-proxy/tests/openai_batch_sse_mock.rs @@ -0,0 +1,608 @@ +mod common; + +use std::net::SocketAddr; + +use axum::{ + Router, + body::Body, + extract::RawQuery, + http::StatusCode, + http::header, + response::{IntoResponse, Response}, + routing::post, +}; +use common::{env_with_provider, start_server}; +use owhisper_client::Provider; +use owhisper_interface::batch_sse::BatchSseMessage; +use transcribe_proxy::{HyprnoteRoutingConfig, SttProxyConfig}; + +#[tokio::test] +async fn hyprnote_batch_sse_streams_openai_events_with_cactus_contract() { + let upstream_addr = start_openai_sse_upstream().await; + + let env = env_with_provider(Provider::OpenAI, "mock-api-key".to_string()); + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::OpenAI) + .with_upstream_url(Provider::OpenAI, &format!("http://{upstream_addr}/v1")) + .with_hyprnote_routing(HyprnoteRoutingConfig::default()); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=hyprnote&model=gpt-4o-transcribe&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.contains("text/event-stream")) + ); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + + let progress = events + .iter() + .find_map(|event| match event { + BatchSseMessage::Progress { progress } => Some(progress), + _ => None, + }) + .expect("expected progress event"); + assert!(progress.percentage > 0.0); + assert_eq!(progress.partial_text.as_deref(), Some("hello ")); + + let result = events + .iter() + .find_map(|event| match event { + BatchSseMessage::Result { response } => Some(response), + _ => None, + }) + .expect("expected result event"); + + assert_eq!( + result.results.channels[0].alternatives[0].transcript, + "hello world" + ); + assert_eq!(result.metadata["usage"]["total_tokens"], 3); +} + +#[tokio::test] +async fn hyprnote_batch_sse_returns_sse_for_non_streaming_first_provider() { + let deepgram_addr = start_json_batch_upstream("deepgram ok").await; + + let mut env = transcribe_proxy::Env::default(); + env.stt.deepgram_api_key = Some("deepgram-key".to_string()); + env.stt.openai_api_key = Some("openai-key".to_string()); + + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::Deepgram) + .with_upstream_url(Provider::Deepgram, &format!("http://{deepgram_addr}/v1")) + .with_hyprnote_routing(HyprnoteRoutingConfig::default()); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=hyprnote&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.contains("text/event-stream")) + ); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + let result = events + .iter() + .find_map(|event| match event { + BatchSseMessage::Result { response } => Some(response), + _ => None, + }) + .expect("expected result event"); + + assert_eq!( + result.results.channels[0].alternatives[0].transcript, + "deepgram ok" + ); +} + +#[tokio::test] +async fn direct_batch_sse_returns_sse_for_non_streaming_provider() { + let deepgram_addr = start_json_batch_upstream("deepgram direct").await; + + let env = env_with_provider(Provider::Deepgram, "deepgram-key".to_string()); + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::Deepgram) + .with_upstream_url(Provider::Deepgram, &format!("http://{deepgram_addr}/v1")); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=deepgram&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.contains("text/event-stream")) + ); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + let result = events + .iter() + .find_map(|event| match event { + BatchSseMessage::Result { response } => Some(response), + _ => None, + }) + .expect("expected result event"); + + assert_eq!( + result.results.channels[0].alternatives[0].transcript, + "deepgram direct" + ); +} + +#[tokio::test] +async fn hyprnote_batch_sse_falls_back_before_openai_emits_progress() { + let openai_addr = start_openai_error_upstream(StatusCode::TOO_MANY_REQUESTS).await; + let mistral_addr = start_mistral_batch_upstream("mistral fallback").await; + + let mut env = transcribe_proxy::Env::default(); + env.stt.mistral_api_key = Some("mistral-key".to_string()); + env.stt.openai_api_key = Some("openai-key".to_string()); + + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::OpenAI) + .with_upstream_url(Provider::OpenAI, &format!("http://{openai_addr}/v1")) + .with_upstream_url(Provider::Mistral, &format!("http://{mistral_addr}/v1")) + .with_hyprnote_routing(HyprnoteRoutingConfig { + priorities: vec![Provider::OpenAI, Provider::Mistral], + retry_config: Default::default(), + }); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=hyprnote&model=gpt-4o-transcribe&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + let result = events + .iter() + .find_map(|event| match event { + BatchSseMessage::Result { response } => Some(response), + _ => None, + }) + .expect("expected result event"); + + assert_eq!( + result.results.channels[0].alternatives[0].transcript, + "mistral fallback" + ); + assert!( + !events + .iter() + .any(|event| matches!(event, BatchSseMessage::Progress { .. })), + "should not emit OpenAI progress before fallback" + ); +} + +#[tokio::test] +async fn direct_openai_batch_sse_emits_single_error_when_stream_ends_before_done() { + let openai_addr = start_openai_incomplete_sse_upstream().await; + + let env = env_with_provider(Provider::OpenAI, "openai-key".to_string()); + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::OpenAI) + .with_upstream_url(Provider::OpenAI, &format!("http://{openai_addr}/v1")); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=openai&model=gpt-4o-transcribe&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + + assert_eq!(count_errors(&events), 1); + assert!( + events + .iter() + .any(|event| matches!(event, BatchSseMessage::Progress { .. })) + ); + assert!( + !events + .iter() + .any(|event| matches!(event, BatchSseMessage::Result { .. })) + ); +} + +#[tokio::test] +async fn hyprnote_batch_sse_does_not_fallback_after_openai_progress() { + let openai_addr = start_openai_incomplete_sse_upstream().await; + let mistral_addr = start_mistral_batch_upstream("mistral fallback").await; + + let mut env = transcribe_proxy::Env::default(); + env.stt.mistral_api_key = Some("mistral-key".to_string()); + env.stt.openai_api_key = Some("openai-key".to_string()); + + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_default_provider(Provider::OpenAI) + .with_upstream_url(Provider::OpenAI, &format!("http://{openai_addr}/v1")) + .with_upstream_url(Provider::Mistral, &format!("http://{mistral_addr}/v1")) + .with_hyprnote_routing(HyprnoteRoutingConfig { + priorities: vec![Provider::OpenAI, Provider::Mistral], + retry_config: Default::default(), + }); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=hyprnote&model=gpt-4o-transcribe&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::OK); + + let body = response.text().await.expect("failed to read SSE response"); + let events = parse_batch_sse_messages(&body); + + assert_eq!(count_errors(&events), 1); + assert!( + events + .iter() + .any(|event| matches!(event, BatchSseMessage::Progress { .. })) + ); + assert!( + !events + .iter() + .any(|event| matches!(event, BatchSseMessage::Result { .. })) + ); +} + +#[tokio::test] +async fn hyprnote_batch_sse_preflight_failures_return_http_json() { + let env = transcribe_proxy::Env::default(); + let supabase_env = hypr_api_env::SupabaseEnv { + supabase_url: String::new(), + supabase_anon_key: String::new(), + supabase_service_role_key: String::new(), + }; + + let config = SttProxyConfig::new(&env, &supabase_env) + .with_hyprnote_routing(HyprnoteRoutingConfig::default()); + let addr = start_server(config).await; + + let audio_bytes = + std::fs::read(hypr_data::english_1::AUDIO_PATH).expect("failed to read test audio"); + + let response = reqwest::Client::new() + .post(format!( + "http://{addr}/listen?provider=hyprnote&language=en" + )) + .header("content-type", "audio/wav") + .header("accept", "text/event-stream") + .body(audio_bytes) + .send() + .await + .expect("request failed"); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert!( + response + .headers() + .get("content-type") + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.contains("application/json")) + ); + + let body: serde_json::Value = response.json().await.expect("json body"); + assert_eq!(body["error"], "no_providers_available"); +} + +async fn start_openai_sse_upstream() -> SocketAddr { + let app = Router::new().route( + "/v1/audio/transcriptions", + post(|| async { openai_sse_response() }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind upstream listener"); + let addr = listener.local_addr().expect("upstream local addr"); + + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve upstream"); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr +} + +async fn start_openai_incomplete_sse_upstream() -> SocketAddr { + let app = Router::new().route( + "/v1/audio/transcriptions", + post(|| async { openai_incomplete_sse_response() }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind upstream listener"); + let addr = listener.local_addr().expect("upstream local addr"); + + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve upstream"); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr +} + +async fn start_openai_error_upstream(status: StatusCode) -> SocketAddr { + let app = Router::new().route( + "/v1/audio/transcriptions", + post(move || async move { + ( + status, + [(header::CONTENT_TYPE, "application/json")], + Body::from(r#"{"error":"rate_limited"}"#), + ) + .into_response() + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind upstream listener"); + let addr = listener.local_addr().expect("upstream local addr"); + + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve upstream"); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr +} + +async fn start_json_batch_upstream(transcript: &'static str) -> SocketAddr { + let app = Router::new().route( + "/v1/listen", + post(move |_query: RawQuery| async move { + ( + [(header::CONTENT_TYPE, "application/json")], + Body::from( + serde_json::json!({ + "metadata": {}, + "results": { + "channels": [{ + "alternatives": [{ + "transcript": transcript, + "confidence": 1.0, + "words": [] + }] + }] + } + }) + .to_string(), + ), + ) + .into_response() + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind upstream listener"); + let addr = listener.local_addr().expect("upstream local addr"); + + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve upstream"); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr +} + +async fn start_mistral_batch_upstream(transcript: &'static str) -> SocketAddr { + let app = Router::new().route( + "/v1/audio/transcriptions", + post(move || async move { + ( + [(header::CONTENT_TYPE, "application/json")], + Body::from( + serde_json::json!({ + "text": transcript, + "words": [], + "segments": [], + }) + .to_string(), + ), + ) + .into_response() + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind upstream listener"); + let addr = listener.local_addr().expect("upstream local addr"); + + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve upstream"); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + addr +} + +fn openai_sse_response() -> Response { + let delta = serde_json::json!({ + "type": "transcript.text.delta", + "delta": "hello ", + }); + let done = serde_json::json!({ + "type": "transcript.text.done", + "text": "hello world", + "usage": { + "type": "tokens", + "input_tokens": 1, + "output_tokens": 2, + "total_tokens": 3, + }, + }); + + ( + [(header::CONTENT_TYPE, "text/event-stream")], + Body::from(format!("data: {delta}\n\ndata: {done}\n\n")), + ) + .into_response() +} + +fn openai_incomplete_sse_response() -> Response { + let delta = serde_json::json!({ + "type": "transcript.text.delta", + "delta": "hello ", + }); + + ( + [(header::CONTENT_TYPE, "text/event-stream")], + Body::from(format!("data: {delta}\n\n")), + ) + .into_response() +} + +fn parse_batch_sse_messages(body: &str) -> Vec { + body.split("\n\n") + .filter_map(|block| { + let mut data = String::new(); + + for line in block.lines() { + if let Some(rest) = line.strip_prefix("data:") { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(rest.trim()); + } + } + + if data.is_empty() { + return None; + } + + serde_json::from_str(&data).ok() + }) + .collect() +} + +fn count_errors(events: &[BatchSseMessage]) -> usize { + events + .iter() + .filter(|event| matches!(event, BatchSseMessage::Error { .. })) + .count() +} diff --git a/crates/transcript/src/processor.rs b/crates/transcript/src/processor.rs index 388802ecfc..7b1d98c8db 100644 --- a/crates/transcript/src/processor.rs +++ b/crates/transcript/src/processor.rs @@ -296,8 +296,38 @@ impl PartialSnapshot { #[cfg(test)] mod tests { use super::*; + use owhisper_interface::stream::{Alternatives, Channel, Metadata, StreamResponse, Word}; + use crate::types::RawWord; + fn transcript_response(transcript: &str, word: &str, start: f64, end: f64) -> StreamResponse { + StreamResponse::TranscriptResponse { + is_final: true, + speech_final: false, + from_finalize: false, + start, + duration: end - start, + channel: Channel { + alternatives: vec![Alternatives { + transcript: transcript.to_string(), + words: vec![Word { + word: word.to_string(), + start, + end, + confidence: 1.0, + speaker: None, + punctuated_word: Some(word.to_string()), + language: None, + }], + confidence: 1.0, + languages: vec![], + }], + }, + metadata: Metadata::default(), + channel_index: vec![0, 1], + } + } + #[test] fn partial_snapshot_carries_speaker_index_on_words() { let mut processor = TranscriptProcessor::new(); @@ -342,4 +372,22 @@ mod tests { assert_eq!(snapshot.partials[1].speaker_index, None); assert_eq!(snapshot.partials[2].speaker_index, Some(7)); } + + #[test] + fn final_chunks_with_leading_space_do_not_stitch_into_previous_word() { + let mut processor = TranscriptProcessor::new(); + + let first = transcript_response(" Maybe", "Maybe", 0.0, 1.0); + let second = transcript_response(" this", "this", 1.0, 2.0); + + assert!(processor.process(&first).is_none()); + + let second_delta = processor.process(&second).expect("second delta"); + assert_eq!(second_delta.new_words.len(), 1); + assert_eq!(second_delta.new_words[0].text, " Maybe"); + + let flushed = processor.flush(); + assert_eq!(flushed.new_words.len(), 1); + assert_eq!(flushed.new_words[0].text, " this"); + } } diff --git a/plugins/transcription/src/api.rs b/plugins/transcription/src/api.rs index 27ae606f6d..65be19a447 100644 --- a/plugins/transcription/src/api.rs +++ b/plugins/transcription/src/api.rs @@ -135,6 +135,8 @@ pub struct TranscriptionParams { pub min_speakers: Option, #[serde(default)] pub max_speakers: Option, + #[serde(default)] + pub known_speaker_references: Vec, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] @@ -282,6 +284,7 @@ impl From for listener2::BatchParams { num_speakers: value.num_speakers, min_speakers: value.min_speakers, max_speakers: value.max_speakers, + known_speaker_references: value.known_speaker_references, } } } @@ -314,7 +317,7 @@ mod tests { #[test] fn defaults_openai_capture_to_batch_mode() { - let params = capture_params("https://api.openai.com/v1", "gpt-4o-transcribe"); + let params = capture_params("https://api.openai.com/v1", "gpt-4o-transcribe-diarize"); assert_eq!( params.default_transcription_mode(),