Skip to content

Commit 89cfdde

Browse files
speedstorm1copybara-github
authored andcommitted
feat: support hyperparameters in distillation tuning
PiperOrigin-RevId: 890622141
1 parent 30bf7bb commit 89cfdde

File tree

1 file changed

+53
-4
lines changed

1 file changed

+53
-4
lines changed

GeneratedFirebaseAI/Sources/Types.swift

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20329,7 +20329,7 @@ extension PreferenceOptimizationSpec: Codable {
2032920329
}
2033020330
}
2033120331

20332-
/// Hyperparameters for Distillation. This data type is not supported in Gemini API.
20332+
/// Hyperparameters for distillation.
2033320333
@available(iOS 15.0, macOS 13.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
2033420334
public struct DistillationHyperParameters: Sendable {
2033520335
/// Optional. Adapter size for distillation.
@@ -20342,15 +20342,26 @@ public struct DistillationHyperParameters: Sendable {
2034220342
/// Optional. Multiplier for adjusting the default learning rate.
2034320343
public let learningRateMultiplier: Double?
2034420344

20345+
/// The batch size hyperparameter for tuning.
20346+
/// This is only supported for OSS models in Vertex.
20347+
public let batchSize: Int32?
20348+
20349+
/// The learning rate for tuning. OSS models only.
20350+
public let learningRate: Float?
20351+
2034520352
/// Default initializer.
2034620353
public init(
2034720354
adapterSize: AdapterSize? = nil,
2034820355
epochCount: Int64? = nil,
20349-
learningRateMultiplier: Double? = nil
20356+
learningRateMultiplier: Double? = nil,
20357+
batchSize: Int32? = nil,
20358+
learningRate: Float? = nil
2035020359
) {
2035120360
self.adapterSize = adapterSize
2035220361
self.epochCount = epochCount
2035320362
self.learningRateMultiplier = learningRateMultiplier
20363+
self.batchSize = batchSize
20364+
self.learningRate = learningRate
2035420365
}
2035520366
}
2035620367

@@ -20362,6 +20373,8 @@ extension DistillationHyperParameters: Codable {
2036220373
case adapterSize = "adapterSize"
2036320374
case epochCount = "epochCount"
2036420375
case learningRateMultiplier = "learningRateMultiplier"
20376+
case batchSize = "batchSize"
20377+
case learningRate = "learningRate"
2036520378
}
2036620379

2036720380
public init(from decoder: any Decoder) throws {
@@ -20382,6 +20395,16 @@ extension DistillationHyperParameters: Codable {
2038220395
Double.self,
2038320396
forKey: .learningRateMultiplier
2038420397
)
20398+
20399+
batchSize = try VertexKeysContainer.decodeIfPresent(
20400+
Int32.self,
20401+
forKey: .batchSize
20402+
)
20403+
20404+
learningRate = try VertexKeysContainer.decodeIfPresent(
20405+
Float.self,
20406+
forKey: .learningRate
20407+
)
2038520408
}
2038620409

2038720410
public func encode(to encoder: any Encoder) throws {
@@ -20405,6 +20428,16 @@ extension DistillationHyperParameters: Codable {
2040520428
forKey: .learningRateMultiplier
2040620429
)
2040720430

20431+
try VertexKeysContainer.encodeIfPresent(
20432+
batchSize,
20433+
forKey: .batchSize
20434+
)
20435+
20436+
try VertexKeysContainer.encodeIfPresent(
20437+
learningRate,
20438+
forKey: .learningRate
20439+
)
20440+
2040820441
}
2040920442
}
2041020443
}
@@ -20444,6 +20477,9 @@ public struct DistillationSpec: Sendable {
2044420477
/// The dataset must be formatted as a JSONL file.
2044520478
public let validationDatasetUri: String?
2044620479

20480+
/// Tuning mode for tuning.
20481+
public let tuningMode: TuningMode?
20482+
2044720483
/// Default initializer.
2044820484
public init(
2044920485
promptDatasetUri: String? = nil,
@@ -20453,7 +20489,8 @@ public struct DistillationSpec: Sendable {
2045320489
studentModel: String? = nil,
2045420490
trainingDatasetUri: String? = nil,
2045520491
tunedTeacherModelSource: String? = nil,
20456-
validationDatasetUri: String? = nil
20492+
validationDatasetUri: String? = nil,
20493+
tuningMode: TuningMode? = nil
2045720494
) {
2045820495
self.promptDatasetUri = promptDatasetUri
2045920496
self.baseTeacherModel = baseTeacherModel
@@ -20463,6 +20500,7 @@ public struct DistillationSpec: Sendable {
2046320500
self.trainingDatasetUri = trainingDatasetUri
2046420501
self.tunedTeacherModelSource = tunedTeacherModelSource
2046520502
self.validationDatasetUri = validationDatasetUri
20503+
self.tuningMode = tuningMode
2046620504
}
2046720505
}
2046820506

@@ -20479,6 +20517,7 @@ extension DistillationSpec: Codable {
2047920517
case trainingDatasetUri = "trainingDatasetUri"
2048020518
case tunedTeacherModelSource = "tunedTeacherModelSource"
2048120519
case validationDatasetUri = "validationDatasetUri"
20520+
case tuningMode = "tuningMode"
2048220521
}
2048320522

2048420523
public init(from decoder: any Decoder) throws {
@@ -20524,6 +20563,11 @@ extension DistillationSpec: Codable {
2052420563
String.self,
2052520564
forKey: .validationDatasetUri
2052620565
)
20566+
20567+
tuningMode = try VertexKeysContainer.decodeIfPresent(
20568+
TuningMode.self,
20569+
forKey: .tuningMode
20570+
)
2052720571
}
2052820572

2052920573
public func encode(to encoder: any Encoder) throws {
@@ -20572,6 +20616,11 @@ extension DistillationSpec: Codable {
2057220616
forKey: .validationDatasetUri
2057320617
)
2057420618

20619+
try VertexKeysContainer.encodeIfPresent(
20620+
tuningMode,
20621+
forKey: .tuningMode
20622+
)
20623+
2057520624
}
2057620625
}
2057720626
}
@@ -24834,7 +24883,7 @@ public struct CreateTuningJobConfig: Sendable {
2483424883
/// Adapter size for tuning.
2483524884
public let adapterSize: AdapterSize?
2483624885

24837-
/// Tuning mode for SFT tuning.
24886+
/// Tuning mode for tuning.
2483824887
public let tuningMode: TuningMode?
2483924888

2484024889
/// Custom base model for tuning. This is only supported for OSS models in Vertex.

0 commit comments

Comments
 (0)