Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 0 additions & 101 deletions FirebaseAI/Sources/Types/Internal/Imagen/ImagenGCSImage.swift

This file was deleted.

30 changes: 0 additions & 30 deletions FirebaseAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,36 +93,6 @@ public final class ImagenModel {
)
}

/// Generates images using the Imagen model and stores them in Cloud Storage (GCS) for Firebase.
///
/// The generated images are stored in a subdirectory of the requested `gcsURI`, named as a random
/// numeric hash. For example, for the `gcsURI` `"gs://bucket-name/path/"`, the generated images
/// are stored in `"gs://bucket-name/path/1234567890123/"` with the names `sample_0.png`,
/// `sample_1.png`, `sample_2.png`, ..., `sample_N.png`. In this example, `1234567890123` is the
/// hash value and `N` is the number of images that were generated, up to the number requested in
/// ``ImagenGenerationConfig/numberOfImages``. The individual ``ImagenGCSImage/gcsURI`` is
/// provided for each of the generated ``ImagenGenerationResponse/images``.
///
/// > Note: By default, 1 image sample is generated; see ``ImagenGenerationConfig/numberOfImages``
/// to configure the number of images that are generated.
///
/// - Parameters:
/// - prompt: A text prompt describing the image(s) to generate.
/// - gcsURI: The Cloud Storage (GCS) for Firebase URI where the generated images are stored.
/// This is a `"gs://"`-prefixed URI , for example, `"gs://bucket-name/path/"`.
///
func generateImages(prompt: String, gcsURI: String) async throws
-> ImagenGenerationResponse<ImagenGCSImage> {
return try await generateImages(
prompt: prompt,
parameters: ImagenModel.imageGenerationParameters(
storageURI: gcsURI,
generationConfig: generationConfig,
safetySettings: safetySettings
)
)
}

func generateImages<T>(prompt: String,
parameters: ImageGenerationParameters) async throws
-> ImagenGenerationResponse<T> where T: Decodable, T: ImagenImageRepresentable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ import Testing
import UIKit
#endif // canImport(UIKit)

// TODO(#14452): Remove `@testable import` when `generateImages(prompt:gcsURI:)` is public.
@testable import class FirebaseAILogic.ImagenModel

@Suite(
.enabled(
if: ProcessInfo.processInfo.environment["VTXIntegrationImagen"] != nil,
Expand All @@ -44,7 +41,9 @@ struct ImagenIntegrationTests {
storage = Storage.storage()
}

@Test func generateImage_inlineImage() async throws {
@Test
@available(*, deprecated)
func generateImage_inlineImage() async throws {
let generationConfig = ImagenGenerationConfig(
negativePrompt: "snow, frost",
aspectRatio: .portrait3x4,
Expand Down Expand Up @@ -75,45 +74,9 @@ struct ImagenIntegrationTests {
#endif // canImport(UIKit)
}

@Test func generateImages_gcsImages() async throws {
let generationConfig = ImagenGenerationConfig(
numberOfImages: 3,
aspectRatio: .landscape16x9,
imageFormat: .jpeg(compressionQuality: 60),
addWatermark: true
)
let model = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
generationConfig: generationConfig,
safetySettings: ImagenSafetySettings(
safetyFilterLevel: .blockMediumAndAbove,
personFilterLevel: .blockAll
)
)
let prompt = "A dense jungle with light streaming through the treetops"
let storageRef = storage.reference(
withPath: "/vertexai/imagen/authenticated/user/\(userID1)"
)

let response = try await model.generateImages(prompt: prompt, gcsURI: storageRef.gsURI)

#expect(response.filteredReason == nil)
#expect(response.images.count == generationConfig.numberOfImages)
for image in response.images {
#expect(image.mimeType == "image/jpeg")
let imageRef = storage.reference(forURL: image.gcsURI)
let imageData = try await imageRef.data(maxSize: 1_000_000) // ~1MB
#expect(imageData.isEmpty == false)
#if canImport(UIKit)
let uiImage = try #require(UIImage(data: imageData))
#expect(uiImage.size.width == 1408.0)
#expect(uiImage.size.height == 768.0)
#endif // canImport(UIKit)
try await imageRef.delete()
}
}

@Test func generateImage_allImagesFilteredOut() async throws {
@Test
@available(*, deprecated)
func generateImage_allImagesFilteredOut() async throws {
let generationConfig = ImagenGenerationConfig(numberOfImages: 2, imageFormat: .jpeg())
let model = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001",
Expand All @@ -134,8 +97,4 @@ struct ImagenIntegrationTests {
return error.localizedDescription.contains("39322892")
}
}

// TODO(#14221): Add an integration test for the prompt being blocked.

// TODO(#14452): Add integration tests for validating that Storage Rules are enforced.
}
80 changes: 0 additions & 80 deletions FirebaseAI/Tests/Unit/Types/Imagen/ImagenGCSImageTests.swift

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
)
}

func testInitializeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: requestOptions,
instances: [instance],
parameters: parameters
)

XCTAssertEqual(request.model, modelName)
XCTAssertEqual(request.options, requestOptions)
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
try request.getURL(),
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
}

// MARK: - Encoding Tests

func testEncodeRequest_inlineDataImage() throws {
Expand Down Expand Up @@ -117,34 +97,4 @@ final class ImagenGenerationRequestTests: XCTestCase {
}
""")
}

func testEncodeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: RequestOptions(),
instances: [instance],
parameters: parameters
)

let jsonData = try encoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"instances" : [
{
"prompt" : "\(instance.prompt)"
}
],
"parameters" : {
"aspectRatio" : "\(aspectRatio)",
"includeRaiReason" : \(includeResponsibleAIFilterReason),
"includeSafetyAttributes" : \(includeSafetyAttributes),
"safetySetting" : "\(safetyFilterLevel)",
"sampleCount" : \(sampleCount)
}
}
""")
}
}
Loading
Loading