Commit a19c059

Merge pull request #295 from ACEsuit/asp
ASP tutorial
2 parents 152e505 + 3a21474 commit a19c059

3 files changed: +122 -1 lines changed

docs/make.jl

Lines changed: 3 additions & 1 deletion
@@ -28,7 +28,8 @@ Literate.markdown(_tutorial_src * "/dataset_analysis.jl",
2828
Literate.markdown(_tutorial_src * "/descriptor.jl",
2929
_tutorial_out; documenter = true)
3030

31-
31+
Literate.markdown(_tutorial_src * "/asp.jl",
32+
_tutorial_out; documenter = true)
3233
# Literate.markdown(_tutorial_src * "/first_example_model.jl",
3334
# _tutorial_out; documenter = true)
3435

@@ -71,6 +72,7 @@ makedocs(;
7172
"literate_tutorials/dataset_analysis.md",
7273
"tutorials/scripting.md",
7374
"literate_tutorials/descriptor.md",
75+
"literate_tutorials/asp.md",
7476
],
7577
"Additional Topics" => Any[
7678
"gettingstarted/parallel-fitting.md",

docs/src/tutorials/asp.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# # Sparse Solvers
2+
#
3+
# This short tutorial introduces the use of the Lasso Homotopy (ASP) and Orthogonal Matching Pursuit (OMP) solvers.
4+
# These are sparse solvers that compute the entire regularization path,
5+
# providing insight into how the support evolves as the regularization parameter changes.
6+
# For more details on the algorithms and their implementation,
7+
# see [ActiveSetPursuit.jl](https://github.com/MPF-Optimization-Laboratory/ActiveSetPursuit.jl).
8+
9+
# We start by importing `ACEpotentials` and the other packages used in this tutorial.
10+
using ACEpotentials
11+
using Random
12+
using ACEpotentials.Models: fast_evaluator
13+
using SparseArrays
14+
using Plots
15+
16+
17+
# Since sparse solvers automatically select the most relevant features, we usually begin with a model that has a large basis.
18+
# Here, for demonstration purposes, we use a relatively small model.
19+
20+
model = ace1_model(elements = [:Si], order = 3, totaldegree = 12)
21+
P = algebraic_smoothness_prior(model; p = 4)
22+
23+
# Next, we load a dataset. We split the dataset into training, validation, and test sets.
24+
# The training set is used to compute the solution path, the validation set is used to select the best solution, and the test set is used to evaluate the final model.
25+
26+
_train_data, test_data, _ = ACEpotentials.example_dataset("Zuo20_Si")
27+
shuffle!(_train_data);
28+
_train_data = _train_data[1:100] # Limit the dataset size for this tutorial
29+
isplit = floor(Int, 0.8 * length(_train_data))
30+
train_data = _train_data[1:isplit]
31+
val_data = _train_data[isplit+1:end]
32+
33+
# We can now assemble the linear system for the training and validation sets.
34+
35+
At, yt, Wt = ACEpotentials.assemble(train_data, model);
36+
Av, yv, Wv = ACEpotentials.assemble(val_data, model);
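Note on the weighting used later in this file: the solver calls pass `Wt .* At` and `Wt .* yt`. Assuming `Wt` is a vector holding one weight per row of `At` (an assumption about what `assemble` returns, not something this diff confirms), the broadcast scales each row of the system by its weight, which turns a weighted least-squares problem into an ordinary one. A minimal sketch with made-up numbers:

```julia
using LinearAlgebra

# Toy system: 3 observations (rows), 2 basis functions (columns).
# All values here are illustrative, not taken from the tutorial data.
A = [1.0 2.0;
     3.0 4.0;
     5.0 6.0]
y = [1.0, 2.0, 3.0]
W = [1.0, 0.5, 2.0]   # assumed: one weight per row

# Broadcasting a vector against a matrix scales row i by W[i].
Aw = W .* A
yw = W .* y

# The same operation written as a diagonal matrix product.
Aw_diag = Diagonal(W) * A
```

Solving `Aw \ yw` then minimizes the weighted residual `sum(W[i]^2 * (A[i, :]' * x - y[i])^2)`.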
37+
38+
# We can now compute sparse solution paths using the `ASP` and `OMP` solvers.
39+
# These solvers support customizable selection criteria for choosing a solution along the path.
40+
#
41+
# The `select` keyword controls which solution is returned:
42+
# - `:final` selects the final iterate on the path.
43+
# - `(:bysize, n)` selects the solution with exactly `n` active parameters.
44+
# - `(:byerror, ε)` selects the smallest solution whose validation error is within a factor `ε` of the minimum validation error.
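Note on the three `select` rules listed above: the following toy version may help make them concrete. It is a sketch only. `toy_select` and `toy_path` are hypothetical names, not part of ACEfit, and the path entries are assumed to be ordered by growing support and to carry `solution` and `rmse` fields, as the plotting code at the end of this file suggests.

```julia
# Hypothetical helper (not the ACEfit implementation): applies the three
# selection rules to a path of (solution, rmse) entries.
function toy_select(path, criterion)
    criterion === :final && return path[end]
    criterion isa Tuple || error("unknown criterion $criterion")
    kind, value = criterion
    nactive(p) = count(!iszero, p.solution)
    if kind === :bysize
        # First entry with exactly `value` nonzero coefficients.
        idx = findfirst(p -> nactive(p) == value, path)
        idx === nothing && error("no solution with $value active parameters")
        return path[idx]
    elseif kind === :byerror
        # Sparsest entry whose RMSE is within a factor `value` of the minimum.
        best = minimum(p.rmse for p in path)
        candidates = [p for p in path if p.rmse <= value * best]
        return candidates[argmin([nactive(p) for p in candidates])]
    end
    error("unknown criterion $criterion")
end

# Made-up path with 1 to 4 active parameters.
toy_path = [
    (solution = [1.0, 0.0, 0.0, 0.0], rmse = 0.50),
    (solution = [1.0, 0.5, 0.0, 0.0], rmse = 0.12),
    (solution = [1.0, 0.5, 0.2, 0.0], rmse = 0.10),
    (solution = [1.0, 0.5, 0.2, 0.1], rmse = 0.11),
]

sel_final = toy_select(toy_path, :final)
sel_size  = toy_select(toy_path, (:bysize, 2))
sel_error = toy_select(toy_path, (:byerror, 1.3))
```

In this toy path the `(:byerror, 1.3)` rule picks the two-parameter entry: its RMSE of 0.12 is not the minimum, but it lies within a factor 1.3 of it and uses fewer parameters.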
45+
46+
# The `tsvd` keyword controls whether the solution is post-processed using truncated SVD.
47+
# This is often beneficial for `ASP`, as ℓ1-regularization can shrink coefficients toward zero too aggressively.
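Note on the `tsvd` post-processing: one way to picture it is as a least-squares refit restricted to the selected columns, with small singular values discarded. The sketch below illustrates that idea under stated assumptions. `toy_tsvd_refit` and its `rtol` threshold are hypothetical, and ACEfit's actual post-processing may differ in detail.

```julia
using LinearAlgebra

# Hypothetical helper: refit the nonzero entries of `x` by truncated-SVD
# least squares on the corresponding columns of `A`.
function toy_tsvd_refit(A, y, x; rtol = 1e-8)
    support = findall(!iszero, x)
    F = svd(A[:, support])
    # Keep only singular values above a relative threshold.
    keep = F.S .> rtol * maximum(F.S)
    coeffs = F.V[:, keep] * ((F.U[:, keep]' * y) ./ F.S[keep])
    xnew = zeros(length(x))
    xnew[support] = coeffs
    return xnew
end

# Made-up example: the true coefficients are [2, 0, -1], while the
# "shrunken" solution has the right support but values biased toward zero.
A = [1.0 0.0 0.0;
     0.0 1.0 0.0;
     0.0 0.0 1.0;
     1.0 1.0 1.0]
y = A * [2.0, 0.0, -1.0]
x_shrunk = [1.5, 0.0, -0.4]

x_refit = toy_tsvd_refit(A, y, x_shrunk)
```

The refit leaves the support unchanged and only replaces the coefficient values on it.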
48+
49+
# The `actMax` keyword controls the maximum number of active parameters in the solution.
50+
51+
solver_asp = ACEfit.ASP(; P = P, select = :final, tsvd = true, actMax = 100, loglevel = 0);
52+
asp_result = ACEfit.solve(solver_asp, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv);
53+
54+
55+
# We can also compute the OMP path, which is a greedy algorithm that selects the most relevant features iteratively.
56+
57+
solver_omp = ACEfit.OMP(; P = P, select = :final, tsvd = false, actMax = 100, loglevel = 0);
58+
omp_result = ACEfit.solve(solver_omp, Wt .* At, Wt .* yt, Wv .* Av, Wv .* yv);
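Note on the greedy selection: for readers unfamiliar with OMP, the textbook algorithm can be sketched in a few lines. This is an illustration only. `toy_omp` is a hypothetical name, and the solver used in this tutorial adds features this sketch omits (the prior `P`, validation-based selection, path storage).

```julia
using LinearAlgebra

# Textbook orthogonal matching pursuit, illustrative only.
function toy_omp(A::AbstractMatrix, y::AbstractVector, nmax::Integer)
    n = size(A, 2)
    colnorms = [norm(A[:, j]) for j in 1:n]
    support = Int[]
    x = zeros(n)
    r = copy(y)
    for _ in 1:nmax
        # Pick the column most correlated with the current residual.
        corr = abs.(A' * r) ./ colnorms
        corr[support] .= 0          # never reselect an active column
        push!(support, argmax(corr))
        # Refit all active coefficients by least squares.
        x = zeros(n)
        x[support] = A[:, support] \ y
        r = y - A * x
    end
    return x, support
end

# Made-up system whose exact solution uses columns 1 and 3 only.
A = [1.0 0.0 0.0;
     0.0 1.0 0.0;
     0.0 0.0 1.0;
     1.0 1.0 1.0]
y = A * [2.0, 0.0, -1.0]

x_omp, support_omp = toy_omp(A, y, 2)
```

Unlike the ℓ1 homotopy path, each OMP step refits the active coefficients without shrinkage, which is consistent with the tutorial setting `tsvd = false` for the OMP solver.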
59+
60+
61+
# To demonstrate the use of the sparse solvers, we will generate models with different numbers of active parameters.
62+
# We select the final model, a model with 50 active parameters, and a model whose validation error is within a factor of 1.3 of the minimum validation error.
63+
# We can use the `ACEfit.asp_select` function to select the desired models from the result.
64+
65+
asp_final = set_parameters!( deepcopy(model),
66+
ACEfit.asp_select(asp_result, :final)[1]);
67+
asp_size_50 = set_parameters!( deepcopy(model),
68+
ACEfit.asp_select(asp_result, (:bysize, 50))[1]);
69+
asp_error13 = set_parameters!( deepcopy(model),
70+
ACEfit.asp_select(asp_result, (:byerror, 1.3))[1]);
71+
72+
pot_final = fast_evaluator(asp_final; aa_static = false);
73+
pot_50 = fast_evaluator(asp_size_50; aa_static = true);
74+
pot_13 = fast_evaluator(asp_error13; aa_static = true);
75+
76+
err_13 = ACEpotentials.compute_errors(test_data, pot_13);
77+
err_50 = ACEpotentials.compute_errors(test_data, pot_50);
78+
err_fin = ACEpotentials.compute_errors(test_data, pot_final);
79+
80+
81+
# Similarly, we can compute the errors for the OMP models.
82+
83+
omp_final = set_parameters!( deepcopy(model),
84+
ACEfit.asp_select(omp_result, :final)[1]);
85+
omp_50 = set_parameters!( deepcopy(model),
86+
ACEfit.asp_select(omp_result, (:bysize, 50))[1]);
87+
omp_13 = set_parameters!( deepcopy(model),
88+
ACEfit.asp_select(omp_result, (:byerror, 1.3))[1]);
89+
90+
omp_pot_final = fast_evaluator(omp_final; aa_static = false);
91+
omp_pot_50 = fast_evaluator(omp_50; aa_static = true);
92+
omp_pot_13 = fast_evaluator(omp_13; aa_static = true);
93+
94+
omp_err_13 = ACEpotentials.compute_errors(test_data, omp_pot_13);
95+
omp_err_50 = ACEpotentials.compute_errors(test_data, omp_pot_50);
96+
omp_err_fin = ACEpotentials.compute_errors(test_data, omp_pot_final);
97+
98+
99+
# Finally, we can visualize the results along the solution path.
100+
# We plot the validation error as a function of the number of active parameters for both ASP and OMP.
101+
102+
path_asp = asp_result["path"];
103+
path_omp = omp_result["path"];
104+
105+
nz_counts_asp = [nnz(p.solution) for p in path_asp];
106+
nz_counts_omp = [nnz(p.solution) for p in path_omp];
107+
108+
rmses_asp = [p.rmse for p in path_asp];
109+
rmses_omp = [p.rmse for p in path_omp];
110+
111+
plot(nz_counts_asp, rmses_asp;
112+
xlabel = "# Nonzero Coefficients",
113+
ylabel = "RMSE",
114+
title = "RMSE vs Sparsity Level",
115+
marker = :o,
116+
grid = true, yscale = :log10, label = "ASP")
117+
plot!(nz_counts_omp, rmses_omp; marker = :o, label = "OMP")

docs/src/tutorials/index.md

Lines changed: 2 additions & 0 deletions
@@ -6,4 +6,6 @@
66
* [Smoothness Priors](../literate_tutorials/smoothness_priors.md) : brief introduction to smoothness priors
77
* [Basic Dataset Analysis](../literate_tutorials/dataset_analysis.md) : basic techniques to visualize training datasets and correlate such observations to the choice of geometric priors
88
* [Descriptors](../literate_tutorials/descriptor.md) : `ACEpotentials` can be used as descriptors of atomic environments or structures, which is described here.
9+
* [Sparse Solvers](../literate_tutorials/asp.md) : basic tutorial on using the `ASP` and `OMP` solvers.
10+
911
