Skip to content

Commit e68382b

Browse files
committed
tests: add models functions
1 parent 4de3724 commit e68382b

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/test_models.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#
2+
# Copyright (C) 2025 sits developers.
3+
#
4+
# This program is free software; you can redistribute it and/or modify it
5+
# under the terms of the GNU General Public License as published by
6+
# the Free Software Foundation; either version 2 of the License, or
7+
# (at your option) any later version.
8+
#
9+
# This program is distributed in the hope that it will be useful,
10+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
# GNU General Public License for more details.
13+
#
14+
# You should have received a copy of the GNU General Public License
15+
# along with this program; if not, see <https://www.gnu.org/licenses/>.
16+
#
17+
18+
"""Unit tests for ml/dl models."""
19+
20+
import pytest
21+
22+
from pysits.models import SITSMachineLearningMethod
23+
from pysits.sits.context import samples_l8_rondonia_2bands
24+
from pysits.sits.ml import (
25+
sits_lighttae,
26+
sits_mlp,
27+
sits_model_export,
28+
sits_resnet,
29+
sits_rfor,
30+
sits_svm,
31+
sits_tae,
32+
sits_tempcnn,
33+
sits_train,
34+
sits_xgboost,
35+
)
36+
37+
#
38+
# Models available to test
39+
#
40+
ALL_MODELS = [
41+
sits_tae,
42+
sits_tempcnn,
43+
sits_lighttae,
44+
sits_mlp,
45+
sits_resnet,
46+
sits_rfor,
47+
sits_svm,
48+
sits_xgboost,
49+
]
50+
51+
52+
#
53+
# Test training for all available models
54+
#
55+
@pytest.mark.parametrize("model_fn", ALL_MODELS)
56+
def test_model_training(model_fn):
57+
"""Test training for all available models."""
58+
try:
59+
# Create model instance with parameters
60+
ml_method = model_fn()
61+
62+
# Train model
63+
model = sits_train(samples_l8_rondonia_2bands, ml_method=ml_method)
64+
65+
# Basic assertions to verify the model was trained
66+
assert model is not None
67+
assert isinstance(model, SITSMachineLearningMethod)
68+
69+
except Exception as e:
70+
pytest.fail(f"Training failed: {str(e)}")
71+
72+
73+
#
74+
# Test model export
75+
#
76+
def test_model_export():
77+
"""Test model export."""
78+
# Train model
79+
model = sits_train(samples_l8_rondonia_2bands, ml_method=sits_svm())
80+
81+
# Try to export model
82+
with pytest.raises(NotImplementedError):
83+
sits_model_export(model)

0 commit comments

Comments
 (0)