Skip to content

Commit f5879bf

Browse files
committed
tests: add model serialization test
1 parent ce34d16 commit f5879bf

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ dev = [
8080
"pytest>=7.4.4",
8181
"pytest-cov>=6.1.1",
8282
"pytest-randomly>=3.16.0",
83+
"cloudpickle>=3.1.1",
8384
]
8485

8586
[tool.ruff]

tests/test_models.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@
1717

1818
"""Unit tests for ml/dl models."""
1919

20+
import cloudpickle
2021
import pytest
2122

2223
from pysits.models import SITSMachineLearningMethod
23-
from pysits.sits.context import samples_l8_rondonia_2bands
24+
from pysits.sits.classification import sits_classify
25+
from pysits.sits.context import (
26+
point_mt_6bands,
27+
samples_l8_rondonia_2bands,
28+
samples_modis_ndvi,
29+
)
30+
from pysits.sits.data import sits_labels, sits_select
2431
from pysits.sits.ml import (
2532
sits_formula_linear,
2633
sits_formula_logref,
@@ -72,6 +79,37 @@ def test_model_training(model_fn):
7279
pytest.fail(f"Training failed: {str(e)}")
7380

7481

82+
@pytest.mark.parametrize("model_fn", ALL_MODELS)
83+
def test_model_serialization(model_fn, tmp_path):
84+
"""Test model serialization."""
85+
ml_method = model_fn()
86+
model = sits_train(samples_modis_ndvi, ml_method=ml_method)
87+
88+
assert isinstance(model, SITSMachineLearningMethod)
89+
90+
# Serialize model
91+
serialized_model = cloudpickle.dumps(model)
92+
93+
# Save serialized model to file
94+
model_file = tmp_path / "model.pkl"
95+
96+
with model_file.open("wb") as f:
97+
f.write(serialized_model)
98+
99+
# Load serialized model from file
100+
with model_file.open("rb") as f:
101+
loaded_model = cloudpickle.load(f)
102+
103+
assert isinstance(loaded_model, SITSMachineLearningMethod)
104+
105+
# Classify a time-series point
106+
point_ndvi = sits_select(point_mt_6bands, bands=("NDVI"))
107+
point_class = sits_classify(data=point_ndvi, ml_model=loaded_model)
108+
109+
assert point_class.shape[0] == 1 # noqa: PLR2004 - number of points
110+
assert len(sits_labels(point_class)) == 1 # noqa: PLR2004 - number of labels
111+
112+
75113
def test_model_svm_params():
76114
"""Test SVM parameters."""
77115

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)