Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 23 additions & 71 deletions FirebaseAI/Tests/TestApp/Tests/Integration/LiveSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ struct LiveSessionTests {
),
]),
]
private let textConfig = LiveGenerationConfig(
responseModalities: [.text]
)
private let audioConfig = LiveGenerationConfig(
private let generationConfig = LiveGenerationConfig(
responseModalities: [.audio],
outputAudioTranscription: AudioTranscriptionConfig()
)
Expand All @@ -76,8 +73,8 @@ struct LiveSessionTests {
role: "system",
parts: """
When you receive a message, if the message is a single word, assume it's the first name of a \
person, and call the getLastName tool to get the last name of said person. Only respond with \
the last name.
person, and call the getLastName tool to get the last name of said person. Once you get the \
response, say the response.
""".trimmingCharacters(in: .whitespacesAndNewlines)
)

Expand All @@ -95,7 +92,7 @@ struct LiveSessionTests {
modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: audioConfig,
generationConfig: generationConfig,
systemInstruction: SystemInstructions.helloGoodbye
)

Expand All @@ -119,15 +116,12 @@ struct LiveSessionTests {
#expect(modelResponse == "goodbye")
}

@Test(
.disabled("Temporarily disabled"),
.bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"),
arguments: arguments
)
func sendVideoRealtime_receiveText(_ config: InstanceConfig, modelName: String) async throws {
@Test(arguments: arguments)
func sendVideoRealtime_receiveAudioOutputTranscripts(_ config: InstanceConfig,
modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: textConfig,
generationConfig: generationConfig,
systemInstruction: SystemInstructions.animalInVideo
)

Expand All @@ -152,7 +146,7 @@ struct LiveSessionTests {
await session.sendAudioRealtime(audioFile.data)
await session.sendAudioRealtime(Data(repeating: 0, count: audioFile.data.count))

let text = try await session.collectNextTextResponse()
let text = try await session.collectNextAudioOutputTranscript()

await session.close()
let modelResponse = text
Expand All @@ -164,15 +158,11 @@ struct LiveSessionTests {
#expect(["kitten", "cat", "kitty"].contains(modelResponse))
}

@Test(
.disabled("Temporarily disabled"),
.bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"),
arguments: arguments
)
@Test(arguments: arguments)
func realtime_functionCalling(_ config: InstanceConfig, modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: textConfig,
generationConfig: generationConfig,
tools: tools,
systemInstruction: SystemInstructions.lastNames
)
Expand Down Expand Up @@ -200,11 +190,9 @@ struct LiveSessionTests {
functionId: functionCall.functionId
),
])

var text = try await session.collectNextTextResponse()
var text = try await session.collectNextAudioOutputTranscript()
if text.isEmpty {
// The model sometimes sends an empty text response first
text = try await session.collectNextTextResponse()
text = try await session.collectNextAudioOutputTranscript()
}

await session.close()
Expand All @@ -217,8 +205,6 @@ struct LiveSessionTests {
}

@Test(
.disabled("Temporarily disabled"),
.bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"),
arguments: arguments.filter {
// TODO: (b/450982184) Remove when Vertex AI adds support for Function IDs and Cancellation
switch $0.0.apiConfig.service {
Expand All @@ -233,7 +219,7 @@ struct LiveSessionTests {
modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: textConfig,
generationConfig: generationConfig,
tools: tools,
systemInstruction: SystemInstructions.lastNames
)
Expand Down Expand Up @@ -266,7 +252,7 @@ struct LiveSessionTests {
func realtime_interruption(_ config: InstanceConfig, modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: audioConfig
generationConfig: generationConfig
)

let audioFile = try #require(
Expand Down Expand Up @@ -295,23 +281,23 @@ struct LiveSessionTests {
}
}

@Test(
.disabled("Temporarily disabled"),
.bug("https://github.com/firebase/firebase-ios-sdk/issues/15640"),
arguments: arguments
)
@Test(arguments: arguments)
func incremental_works(_ config: InstanceConfig, modelName: String) async throws {
let model = FirebaseAI.componentInstance(config).liveModel(
modelName: modelName,
generationConfig: textConfig,
generationConfig: generationConfig,
systemInstruction: SystemInstructions.yesOrNo
)

let session = try await model.connect()
await session.sendContent("Does five plus")
await session.sendContent(" five equal ten?", turnComplete: true)

let text = try await session.collectNextTextResponse()
var text = try await session.collectNextAudioOutputTranscript()
if text.isEmpty {
// The model sometimes sends an empty audio output transcript first
text = try await session.collectNextAudioOutputTranscript()
}

await session.close()
let modelResponse = text
Expand Down Expand Up @@ -339,26 +325,6 @@ struct LiveSessionTests {
}

private extension LiveSession {
/// Gathers the text the model emits during its next turn.
///
/// Iterates over the `LiveServerContent` messages of this session,
/// appending the text of each message's model turn (if it has one) to a
/// running buffer. Iteration stops after the first message that reports
/// the turn as complete; that message's own text is still included.
///
/// - Returns: The concatenation of all text collected for the turn, or an
///   empty string if no message carried any text.
/// - Throws: Any error raised while iterating the response stream.
func collectNextTextResponse() async throws -> String {
  var collected = ""

  for try await message in responsesOf(LiveServerContent.self) {
    // Messages without a model turn contribute nothing to the buffer.
    if let turn = message.modelTurn {
      collected += turn.allText()
    }

    // Stop only after the final message's text has been appended.
    guard !message.isTurnComplete else { break }
  }

  return collected
}

/// Collects the audio output transcripts that the model sends for the next turn.
///
/// Will listen for `LiveServerContent` messages from the model,
Expand Down Expand Up @@ -395,11 +361,7 @@ private extension LiveSession {
case let .toolCall(toolCall):
return toolCall
case let .content(content):
if let text = content.modelTurn?.allText() {
error += text
} else {
error += content.outputAudioText()
}
error += content.outputAudioText()

if content.isTurnComplete {
Issue.record("The model didn't send a tool call. Text received: \(error)")
Expand Down Expand Up @@ -464,16 +426,6 @@ private struct NoInterruptionError: Error,
var description: String { "The model never sent an interrupted message." }
}

private extension ModelContent {
  /// Joins the text of every `TextPart` in `parts`, in order and without
  /// any separator.
  ///
  /// Parts that are not `TextPart`s are skipped. When there are no
  /// `TextPart`s at all, the result is an empty string.
  func allText() -> String {
    var combined = ""
    for part in parts {
      // Only text parts contribute; everything else is ignored.
      if let fragment = (part as? TextPart)?.text {
        combined += fragment
      }
    }
    return combined
  }
}

extension LiveServerContent {
/// Text of the output `LiveAudioTranscript`, or an empty string if it's missing.
func outputAudioText() -> String {
Expand Down
Loading