|
17 | 17 |
|
18 | 18 | """Unit tests for ml/dl models.""" |
19 | 19 |
|
| 20 | +import cloudpickle |
20 | 21 | import pytest |
21 | 22 |
|
22 | 23 | from pysits.models import SITSMachineLearningMethod |
23 | | -from pysits.sits.context import samples_l8_rondonia_2bands |
| 24 | +from pysits.sits.classification import sits_classify |
| 25 | +from pysits.sits.context import ( |
| 26 | + point_mt_6bands, |
| 27 | + samples_l8_rondonia_2bands, |
| 28 | + samples_modis_ndvi, |
| 29 | +) |
| 30 | +from pysits.sits.data import sits_labels, sits_select |
24 | 31 | from pysits.sits.ml import ( |
25 | 32 | sits_formula_linear, |
26 | 33 | sits_formula_logref, |
@@ -72,6 +79,37 @@ def test_model_training(model_fn): |
72 | 79 | pytest.fail(f"Training failed: {str(e)}") |
73 | 80 |
|
74 | 81 |
|
| 82 | +@pytest.mark.parametrize("model_fn", ALL_MODELS) |
| 83 | +def test_model_serialization(model_fn, tmp_path): |
| 84 | + """Test model serialization.""" |
| 85 | + ml_method = model_fn() |
| 86 | + model = sits_train(samples_modis_ndvi, ml_method=ml_method) |
| 87 | + |
| 88 | + assert isinstance(model, SITSMachineLearningMethod) |
| 89 | + |
| 90 | + # Serialize model |
| 91 | + serialized_model = cloudpickle.dumps(model) |
| 92 | + |
| 93 | + # Save serialized model to file |
| 94 | + model_file = tmp_path / "model.pkl" |
| 95 | + |
| 96 | + with model_file.open("wb") as f: |
| 97 | + f.write(serialized_model) |
| 98 | + |
| 99 | + # Load serialized model from file |
| 100 | + with model_file.open("rb") as f: |
| 101 | + loaded_model = cloudpickle.load(f) |
| 102 | + |
| 103 | + assert isinstance(loaded_model, SITSMachineLearningMethod) |
| 104 | + |
| 105 | + # Classify a time-series point |
| 106 | + point_ndvi = sits_select(point_mt_6bands, bands=("NDVI")) |
| 107 | + point_class = sits_classify(data=point_ndvi, ml_model=loaded_model) |
| 108 | + |
| 109 | + assert point_class.shape[0] == 1 # noqa: PLR2004 - number of points |
| 110 | + assert len(sits_labels(point_class)) == 1 # noqa: PLR2004 - number of labels |
| 111 | + |
| 112 | + |
75 | 113 | def test_model_svm_params(): |
76 | 114 | """Test SVM parameters.""" |
77 | 115 |
|
|
0 commit comments