Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions FirebaseAI/Sources/AutomaticFunction.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A wrapper for a function declaration and its executable logic.
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public struct AutomaticFunction: Sendable {
  /// Describes this function to the model so it knows when and how to call it.
  public let declaration: FunctionDeclaration

  /// The closure invoked with the model-supplied arguments when the model calls this function.
  public let execute: @Sendable ([String: JSONValue]) async throws -> JSONObject

  /// Creates a new `AutomaticFunction` from an existing declaration.
  ///
  /// - Parameters:
  ///   - declaration: The function declaration presented to the model.
  ///   - execute: The logic to run when the model calls the function.
  public init(declaration: FunctionDeclaration,
              execute: @escaping @Sendable ([String: JSONValue]) async throws -> JSONObject) {
    self.declaration = declaration
    self.execute = execute
  }

  /// Creates a new `AutomaticFunction`, building the declaration from its components.
  ///
  /// - Parameters:
  ///   - name: The name of the function.
  ///   - description: A brief description of what the function does.
  ///   - parameters: Schemas describing the parameters this function accepts.
  ///   - optionalParameters: The names of parameters that may be omitted by the model.
  ///   - execute: The logic to run when the model calls the function.
  public init(name: String,
              description: String,
              parameters: [String: Schema] = [:],
              optionalParameters: [String] = [],
              execute: @escaping @Sendable ([String: JSONValue]) async throws -> JSONObject) {
    // Delegate to the designated initializer so stored-property setup lives in one place.
    self.init(
      declaration: FunctionDeclaration(name: name,
                                       description: description,
                                       parameters: parameters,
                                       optionalParameters: optionalParameters),
      execute: execute
    )
  }
}

#if canImport(FoundationModels)
import FoundationModels

@available(iOS 26.0, macOS 26.0, *)
@available(tvOS, unavailable)
@available(watchOS, unavailable)
public extension AutomaticFunction {
  /// Creates an `AutomaticFunction` from a `FoundationModels.Tool`.
  ///
  /// The tool's `GenerationSchema` is bridged to a `FirebaseAI.Schema` by round-tripping it
  /// through JSON, and the tool's `call(arguments:)` is wrapped as the execution closure.
  ///
  /// - Parameter tool: The `FoundationModels.Tool` instance to wrap.
  /// - Throws: Any encoding/decoding error from the schema bridge, or from `asSchema()`.
  init<T: FoundationModels.Tool>(_ tool: T) throws {
    // Convert FoundationModels.GenerationSchema to FirebaseAI.Schema (via JSONSchema):
    // encode the tool's GenerationSchema to JSON, then decode it as our JSONSchema type.
    // NOTE(review): assumes GenerationSchema's encoded form is compatible with JSONSchema's
    // decoder — TODO confirm against FoundationModels' documented encoding.
    let data = try JSONEncoder().encode(tool.parameters)
    let jsonSchema = try JSONDecoder().decode(JSONSchema.self, from: data)
    let firebaseSchema = try jsonSchema.asSchema()

    // Extract parameter properties; a schema with no properties yields an empty parameter list.
    let properties = firebaseSchema.properties ?? [:]
    let required = firebaseSchema.requiredProperties ?? []
    let requiredSet = Set(required)

    self.init(
      name: tool.name,
      description: tool.description,
      parameters: properties,
      // Any property not listed as required is treated as optional for the model.
      optionalParameters: properties.keys.filter { !requiredSet.contains($0) }
    ) { args in
      // Convert [String: JSONValue] -> JSONObject (ModelOutput) -> GeneratedContent ->
      // T.Arguments, so the model-supplied arguments can be handed to the tool.
      let modelOutput = ModelOutput(jsonValue: .object(args))

      let generatedContent = modelOutput.generatedContent
      let toolArgs = try T.Arguments(generatedContent)

      // Execute the tool with the reconstructed, strongly-typed arguments.
      let result = try await tool.call(arguments: toolArgs)

      // Convert result -> JSON.
      // We assume the output is Encodable (common for Generable/PromptRepresentable types that
      // are data). If the encoded value is not a JSON object (e.g. a bare string or number),
      // it is wrapped under a "result" key so a JSONObject can always be returned.
      if let encodableResult = result as? Encodable {
        let encoder = JSONEncoder()
        let data = try encoder.encode(encodableResult)
        let jsonValue = try JSONDecoder().decode(JSONValue.self, from: data)
        if case let .object(jsonObject) = jsonValue {
          return jsonObject
        } else {
          return ["result": jsonValue]
        }
      }

      // Fallback for non-Encodable or other types: String description.
      return ["result": .string(String(describing: result))]
    }
  }
}
#endif
165 changes: 130 additions & 35 deletions FirebaseAI/Sources/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ public final class Chat: Sendable {
private let model: GenerativeModel
private let _history: History

/// A safeguard to prevent infinite loops when the model keeps requesting function calls:
/// the maximum number of automatic function-calling turns per `sendMessage` invocation.
private static let maxFunctionCalls = 10

/// The error thrown once a single message exchange exceeds ``maxFunctionCalls`` automatic
/// function-calling turns.
///
/// Uses the `com.google.generative-ai` error domain for consistency with the other internal
/// errors thrown from this class.
private static var maxTurnsError: Error {
  GenerateContentError.internalError(underlying: NSError(
    domain: "com.google.generative-ai",
    code: -1,
    userInfo: [
      NSLocalizedDescriptionKey: "Max automatic function calling turns reached.",
    ]
  ))
}

init(model: GenerativeModel, history: [ModelContent]) {
self.model = model
_history = History(history: history)
Expand Down Expand Up @@ -57,25 +70,43 @@ public final class Chat: Sendable {
// Ensure that the new content has the role set.
let newContent = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
let result = try await model.generateContent(request)
guard let reply = result.candidates.first?.content else {
let error = NSError(domain: "com.google.generative-ai",
code: -1,
userInfo: [
NSLocalizedDescriptionKey: "No candidates with content available.",
])
throw GenerateContentError.internalError(underlying: error)
}
var response: GenerateContentResponse
var functionCallCount = 0
var userContentCommitted = false

while true {
// Send the history as context.
// If we haven't committed the user content yet, send it as part of the request
// but don't add it to the official history until we get a successful response.
let requestContent = userContentCommitted ? history : history + newContent
response = try await model.generateContent(requestContent)

guard let candidate = response.candidates.first else {
let error = NSError(domain: "com.google.generative-ai",
code: -1,
userInfo: [
NSLocalizedDescriptionKey: "No candidates with content available.",
])
throw GenerateContentError.internalError(underlying: error)
}

// Make sure we inject the role into the content received.
let toAdd = ModelContent(role: "model", parts: reply.parts)
// Commit user content if not yet done.
if !userContentCommitted {
_history.append(contentsOf: newContent)
userContentCommitted = true
}

let modelContent = ModelContent(role: "model", parts: candidate.content.parts)
let shouldContinue = try await handleFunctionCallingTurn(
modelContent,
functionCallCount: &functionCallCount
)
if !shouldContinue {
break
}
}

// Append the request and successful result to history, then return the value.
_history.append(contentsOf: newContent)
_history.append(toAdd)
return result
return response
}

/// Sends a message using the existing history of this chat as context. If successful, the message
Expand All @@ -98,36 +129,52 @@ public final class Chat: Sendable {
// Ensure that the new content has the role set.
let newContent: [ModelContent] = content.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
let stream = try model.generateContentStream(request)
return AsyncThrowingStream { continuation in
Task {
var aggregatedContent: [ModelContent] = []
var functionCallCount = 0
var userContentCommitted = false

do {
for try await chunk in stream {
// Capture any content that's streaming. This should be populated if there's no error.
if let chunkContent = chunk.candidates.first?.content {
aggregatedContent.append(chunkContent)
while true {
// If we haven't committed the user content yet, send it as part of the request
// but don't add it to the official history until we get a successful response start.
let requestContent = userContentCommitted ? history : history + newContent
let stream = try model.generateContentStream(requestContent)
var aggregatedContent: [ModelContent] = []

for try await chunk in stream {
// Capture any content that's streaming. This should be populated if there's no error.
if let chunkContent = chunk.candidates.first?.content {
aggregatedContent.append(chunkContent)
}

// Pass along the chunk.
continuation.yield(chunk)
}

// Stream finished successfully.
// Commit user content if not yet done.
if !userContentCommitted {
_history.append(contentsOf: newContent)
userContentCommitted = true
}

// Pass along the chunk.
continuation.yield(chunk)
// Aggregate the content to add it to the history.
let aggregated = _history.aggregatedChunks(aggregatedContent)
let shouldContinue = try await self.handleFunctionCallingTurn(
aggregated,
functionCallCount: &functionCallCount
)
if !shouldContinue {
break
}
}
continuation.finish()
} catch {
// Rethrow the error that the underlying stream threw. Don't add anything to history.
continuation.finish(throwing: error)
return
}

// Save the request.
_history.append(contentsOf: newContent)

// Aggregate the content to add it to the history before we finish.
let aggregated = self._history.aggregatedChunks(aggregatedContent)
self._history.append(aggregated)
continuation.finish()
}
}
}
Expand All @@ -140,4 +187,52 @@ public final class Chat: Sendable {
return ModelContent(role: "user", parts: content.parts)
}
}

/// Appends the model's reply to history and, if it contains handled function calls, executes
/// them and appends their responses.
///
/// - Parameters:
///   - modelContent: The model's reply for this turn; always appended to history.
///   - functionCallCount: Running count of automatic function-calling turns for this message.
/// - Returns: `true` if function responses were appended and another model turn is needed;
///   `false` if no handled function calls were present and the loop should stop.
/// - Throws: `Chat.maxTurnsError` when `functionCallCount` reaches `Chat.maxFunctionCalls`.
///   NOTE(review): the limit is checked *after* the responses for this turn were already
///   appended to history, so the final turn's work is committed but the overall call throws —
///   confirm this is the intended behavior rather than returning the last response.
private func handleFunctionCallingTurn(_ modelContent: ModelContent,
                                       functionCallCount: inout Int) async throws -> Bool {
  _history.append(modelContent)

  if let responseContent = try await executeFunctionCalls(from: modelContent) {
    _history.append(responseContent)
    functionCallCount += 1
    if functionCallCount >= Chat.maxFunctionCalls {
      throw Chat.maxTurnsError
    }
    return true
  }

  return false
}

/// Executes, concurrently, every function call in `content` that has a registered handler.
///
/// - Parameter content: Model content whose `FunctionCallPart`s should be executed.
/// - Returns: A `ModelContent` with role `"function"` containing one response per handled
///   call, or `nil` if `content` contains no calls with registered handlers.
/// - Throws: Rethrows any error thrown by a handler (the task group cancels remaining tasks).
///
/// NOTE(review): calls whose name has no entry in `model.functionHandlers` are silently
/// dropped — confirm unhandled calls are meant to be left to the caller rather than answered
/// here. Also, responses are collected in task-*completion* order, which may differ from the
/// order of the calls in `content` — verify the backend does not require matching order.
private func executeFunctionCalls(from content: ModelContent) async throws -> ModelContent? {
  // Extract the function calls the model requested in this turn.
  let functionCalls = content.parts.compactMap { ($0 as? FunctionCallPart)?.functionCall }
  let handlers = model.functionHandlers
  // Pair each call with its registered handler; calls without a handler are skipped.
  let callsToHandle = functionCalls.compactMap { call in
    handlers[call.name].map { (call, $0) }
  }
  guard !callsToHandle.isEmpty else {
    return nil
  }

  // Run all handlers concurrently and gather their responses.
  let functionResponses = try await withThrowingTaskGroup(
    of: FunctionResponsePart.self,
    returning: [FunctionResponsePart].self
  ) { group in
    for (call, handler) in callsToHandle {
      group.addTask {
        let result = try await handler(call.args)
        return FunctionResponsePart(name: call.name, response: result)
      }
    }

    var responses: [FunctionResponsePart] = []
    responses.reserveCapacity(callsToHandle.count)
    for try await part in group {
      responses.append(part)
    }
    return responses
  }

  return ModelContent(role: "function", parts: functionResponses)
}
}
8 changes: 7 additions & 1 deletion FirebaseAI/Sources/FirebaseAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ public final class FirebaseAI: Sendable {
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
/// - automaticFunctionTools: A list of ``AutomaticFunction``s that the model may use to
/// generate the next response. The model will automatically call these functions when the
/// corresponding function call is generated.
/// - toolConfig: Tool configuration for any `Tool` specified in the request.
/// - systemInstruction: Instructions that direct the model to behave a certain way; currently
/// only text content is supported.
Expand All @@ -78,6 +81,7 @@ public final class FirebaseAI: Sendable {
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
automaticFunctionTools: [AutomaticFunction]? = nil,
toolConfig: ToolConfig? = nil,
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions())
Expand All @@ -98,9 +102,11 @@ public final class FirebaseAI: Sendable {
generationConfig: generationConfig,
safetySettings: safetySettings,
tools: tools,
automaticFunctionTools: automaticFunctionTools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
requestOptions: requestOptions
requestOptions: requestOptions,
urlSession: GenAIURLSession.default
)
}

Expand Down
Loading
Loading