Skip to content

Commit 331dd9d

Browse files
authored
Refactoring: move SurrogateBuilder and mixed integer moe from ego to moe (#382)
* feat: add SurrogateBuilder trait and XType enum for surrogate model configuration - Introduced the `SurrogateBuilder` trait in `surrogate_builder.rs` to define methods for configuring and training surrogate models used by the Egor optimizer. - Implemented the `SurrogateBuilder` trait for `GpMixtureParams<f64>`, providing functionality for setting regression specifications, correlation models, clustering, and training methods. - Added the `XType` enum in `xtypes.rs` to represent different variable types (Float, Int, Ord, Enum) and a helper function `discrete` to check for discrete types in a given list of `XType`. - Included unit tests for the `discrete` function to ensure correct identification of discrete variable types. * Deprecate re-export from moe * Linting * Documention * fix: improve documentation for LHS design and correlation model trait
1 parent 8d9c64c commit 331dd9d

File tree

23 files changed

+224
-189
lines changed

23 files changed

+224
-189
lines changed

crates/doe/src/lhs.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ pub enum LhsKind {
3434

3535
type RngRef<R> = Arc<RwLock<R>>;
3636

37-
/// The LHS design is built as follows: each dimension space is divided into ns sections
38-
/// where ns is the number of sampling points, and one point in selected in each section.
39-
/// The selection method gives different kind of LHS (see [LhsKind])
37+
/// The LHS design is built as follows: each dimension space is divided into `ns` sections
38+
/// where `ns` is the number of sampling points, and one point is selected in each section.
39+
/// The selection method corresponds to different kinds of LHS: see [`LhsKind`]
4040
#[derive(Clone, Debug)]
4141
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
4242
pub struct Lhs<F: Float, R: Rng> {

crates/ego/src/criteria/ei.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,9 @@ pub const LOG_EI: LogExpectedImprovement = LogExpectedImprovement {};
176176
#[cfg(test)]
177177
mod tests {
178178
use super::*;
179-
use crate::{
180-
gpmix::mixint::{MixintContext, MoeBuilder},
181-
types::*,
182-
};
179+
use crate::types::*;
183180
use approx::assert_abs_diff_eq;
181+
use egobox_moe::{MixintContext, MoeBuilder};
184182
// use egobox_moe::GpSurrogate;
185183
use finitediff::vec;
186184
use linfa::Dataset;

crates/ego/src/egor.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ use crate::EgorConfig;
105105
use crate::EgorState;
106106
use crate::HotStartMode;
107107
use crate::errors::Result;
108-
use crate::gpmix::mixint::*;
109108
use crate::types::*;
110109
use crate::{CHECKPOINT_FILE, CheckpointingFrequency, HotStartCheckpoint};
111110
use crate::{EgorSolver, to_xtypes};
111+
use egobox_moe::{MixintGpMixtureParams, to_discrete_space};
112112

113113
use argmin::core::observers::ObserverMode;
114114

@@ -117,7 +117,7 @@ use log::info;
117117
use ndarray::{Array2, ArrayBase, Axis, Data, Ix2, concatenate};
118118

119119
use argmin::core::{Error, Executor, KV, State, observers::Observe};
120-
use serde::de::DeserializeOwned;
120+
use serde::{Serialize, de::DeserializeOwned};
121121

122122
use ndarray_npy::write_npy;
123123
use std::path::PathBuf;
@@ -219,14 +219,14 @@ impl<O: GroupFunc, C: CstrFn> EgorFactory<O, C> {
219219
pub struct Egor<
220220
O: GroupFunc,
221221
C: CstrFn = Cstr,
222-
SB: SurrogateBuilder + DeserializeOwned = GpMixtureParams<f64>,
222+
SB: SurrogateBuilder + Serialize + DeserializeOwned = GpMixtureParams<f64>,
223223
> {
224224
fobj: ObjFunc<O, C>,
225225
solver: EgorSolver<SB, C>,
226226
run_info: Option<RunInfo>,
227227
}
228228

229-
impl<O: GroupFunc, C: CstrFn, SB: SurrogateBuilder + DeserializeOwned> Egor<O, C, SB> {
229+
impl<O: GroupFunc, C: CstrFn, SB: SurrogateBuilder + Serialize + DeserializeOwned> Egor<O, C, SB> {
230230
/// Runs the (constrained) optimization of the objective function.
231231
pub fn run(&self) -> Result<OptimResult<f64>> {
232232
let xtypes = self.solver.config.xtypes.clone();
@@ -422,7 +422,7 @@ mod tests {
422422
use argmin::core::{TerminationReason, TerminationStatus};
423423
use argmin_testfunctions::rosenbrock;
424424
use egobox_doe::{Lhs, SamplingMethod};
425-
use egobox_moe::NbClusters;
425+
use egobox_moe::{NbClusters, as_continuous_limits};
426426
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Ix1, Zip, array, s};
427427
use ndarray_rand::rand::SeedableRng;
428428
use rand_xoshiro::Xoshiro256Plus;
@@ -432,7 +432,8 @@ mod tests {
432432
use serial_test::serial;
433433
use std::time::Instant;
434434

435-
use crate::{CoegoStatus, DOE_FILE, DOE_INITIAL_FILE, gpmix::spec::*, utils::EGOBOX_LOG};
435+
use crate::{CoegoStatus, DOE_FILE, DOE_INITIAL_FILE, utils::EGOBOX_LOG};
436+
use egobox_moe::{CorrelationSpec, RegressionSpec};
436437

437438
#[cfg(not(feature = "blas"))]
438439
use linfa_linalg::norm::*;

crates/ego/src/gpmix/spec.rs

Lines changed: 0 additions & 7 deletions
This file was deleted.

crates/ego/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,6 @@
314314
#![warn(rustdoc::broken_intra_doc_links)]
315315

316316
pub mod criteria;
317-
pub mod gpmix;
318317

319318
mod egor;
320319
mod errors;
@@ -323,14 +322,14 @@ mod types;
323322

324323
pub use crate::egor::*;
325324
pub use crate::errors::*;
326-
pub use crate::gpmix::spec::{CorrelationSpec, RegressionSpec};
327325
pub use crate::solver::*;
328326
pub use crate::types::*;
329327
pub use crate::utils::{
330328
CHECKPOINT_FILE, Checkpoint, CheckpointingFrequency, EGOBOX_LOG, EGOR_GP_FILENAME,
331329
EGOR_INITIAL_GP_FILENAME, EGOR_USE_GP_RECORDER, EGOR_USE_GP_VAR_PORTFOLIO,
332330
EGOR_USE_MAX_PROBA_OF_FEASIBILITY, HotStartCheckpoint, HotStartMode, find_best_result_index,
333331
};
332+
pub use egobox_moe::{CorrelationSpec, RegressionSpec};
334333

335334
mod optimizers;
336335
mod utils;

crates/ego/src/solver/coego.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use egobox_gp::ThetaTuning;
77
use egobox_moe::MixtureGpSurrogate;
88
use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, RemoveAxis, s};
99
use rand_xoshiro::Xoshiro256Plus;
10-
use serde::de::DeserializeOwned;
10+
use serde::{Serialize, de::DeserializeOwned};
1111

1212
use ndarray_rand::rand::seq::SliceRandom;
1313

@@ -45,7 +45,7 @@ where
4545

4646
impl<SB, C> EgorSolver<SB, C>
4747
where
48-
SB: SurrogateBuilder + DeserializeOwned,
48+
SB: SurrogateBuilder + Serialize + DeserializeOwned,
4949
C: CstrFn,
5050
{
5151
/// Compute array of components indices each row being used as

crates/ego/src/solver/egor_service.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,11 @@
4444
//!
4545
use std::marker::PhantomData;
4646

47-
use crate::{EgorConfig, EgorSolver, errors::Result, gpmix::mixint::*, to_xtypes, types::*};
48-
49-
use egobox_moe::GpMixtureParams;
47+
use crate::{EgorConfig, EgorSolver, errors::Result, to_xtypes, types::*};
48+
use egobox_moe::{GpMixtureParams, MixintGpMixtureParams, to_continuous_space, to_discrete_space};
5049
use ndarray::{Array2, ArrayBase, Data, Ix2};
5150

52-
use serde::de::DeserializeOwned;
51+
use serde::{Serialize, de::DeserializeOwned};
5352

5453
/// EGO optimizer service builder allowing to use Egor optimizer
5554
/// as a service.
@@ -109,11 +108,11 @@ impl<C: CstrFn> EgorServiceFactory<C> {
109108

110109
/// Egor optimizer service API.
111110
#[derive(Clone)]
112-
pub struct EgorServiceApi<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn = Cstr> {
111+
pub struct EgorServiceApi<SB: SurrogateBuilder + Serialize + DeserializeOwned, C: CstrFn = Cstr> {
113112
solver: EgorSolver<SB, C>,
114113
}
115114

116-
impl<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn> EgorServiceApi<SB, C> {
115+
impl<SB: SurrogateBuilder + Serialize + DeserializeOwned, C: CstrFn> EgorServiceApi<SB, C> {
117116
/// Given an evaluated doe (x, y) data, return the next promising x point
118117
/// where optimum may be located with regard to the infill criterion.
119118
/// This function inverses the control of the optimization and can be used
@@ -136,8 +135,8 @@ pub type EgorServiceBuilder = EgorServiceFactory<Cstr>;
136135
#[cfg(test)]
137136
mod tests {
138137
use super::*;
139-
use crate::gpmix::spec::*;
140138
use approx::assert_abs_diff_eq;
139+
use egobox_moe::{CorrelationSpec, RegressionSpec};
141140
use ndarray::{ArrayView2, Axis, array, concatenate};
142141

143142
use ndarray_stats::QuantileExt;

crates/ego/src/solver/egor_solver.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ impl<O, SB, C> Solver<O, EgorState<f64>> for EgorSolver<SB, C>
173173
where
174174
O: CostFunction<Param = Array2<f64>, Output = Array2<f64>> + DomainConstraints<C>,
175175
C: CstrFn,
176-
SB: SurrogateBuilder + DeserializeOwned,
176+
SB: SurrogateBuilder + Serialize + DeserializeOwned,
177177
{
178178
fn name(&self) -> &str {
179179
"Egor"
@@ -446,7 +446,7 @@ where
446446

447447
impl<SB, C: CstrFn> EgorSolver<SB, C>
448448
where
449-
SB: SurrogateBuilder + DeserializeOwned,
449+
SB: SurrogateBuilder + Serialize + DeserializeOwned,
450450
{
451451
/// Iteration of EGO algorithm
452452
fn ego_iteration<

crates/ego/src/solver/solver_computations.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::errors::Result;
2-
use crate::gpmix::mixint::to_discrete_space;
32
use crate::{types::*, utils};
3+
use egobox_moe::to_discrete_space;
44

55
use crate::utils::{
66
EGOR_DO_NOT_USE_MIDDLEPICKER_MULTISTARTER, compute_cstr_scales, logpofs, logpofs_grad, pofs,

crates/ego/src/solver/solver_impl.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ use std::marker::PhantomData;
22

33
use crate::errors::{EgoError, Result};
44
use crate::find_best_result_index;
5-
use crate::gpmix::mixint::{as_continuous_limits, to_discrete_space};
65
use crate::solver::solver_computations::MiddlePickerMultiStarter;
76
use crate::solver::solver_infill_optim::InfillOptProblem;
87
use crate::utils::{
@@ -11,6 +10,7 @@ use crate::utils::{
1110
};
1211
use crate::{DEFAULT_CSTR_TOL, EgorSolver, MAX_POINT_ADDITION_RETRY, ValidEgorConfig};
1312
use crate::{EgorState, types::*};
13+
use egobox_moe::{as_continuous_limits, to_discrete_space};
1414

1515
use argmin::argmin_error_closure;
1616
use argmin::core::{CostFunction, Problem, State};
@@ -25,11 +25,11 @@ use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2, Zip, concatenate,
2525
use ndarray_rand::rand::{Rng, SeedableRng};
2626
use rand_xoshiro::Xoshiro256Plus;
2727
use rayon::prelude::*;
28-
use serde::de::DeserializeOwned;
28+
use serde::{Serialize, de::DeserializeOwned};
2929

3030
use super::coego::COEGO_IMPROVEMENT_CHECK;
3131

32-
impl<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn> EgorSolver<SB, C> {
32+
impl<SB: SurrogateBuilder + Serialize + DeserializeOwned, C: CstrFn> EgorSolver<SB, C> {
3333
/// Constructor of the optimization of the function `f` with specified random generator
3434
/// to get reproducibility.
3535
///
@@ -104,7 +104,7 @@ impl<SB: SurrogateBuilder + DeserializeOwned, C: CstrFn> EgorSolver<SB, C> {
104104

105105
impl<SB, C> EgorSolver<SB, C>
106106
where
107-
SB: SurrogateBuilder + DeserializeOwned,
107+
SB: SurrogateBuilder + Serialize + DeserializeOwned,
108108
C: CstrFn,
109109
{
110110
/// Whether we have to recluster the data

0 commit comments

Comments
 (0)