Skip to content

Commit 7543b9f

Browse files
committed
conversions: enhance py2r conversions
1 parent 1e01b10 commit 7543b9f

File tree

8 files changed

+206
-27
lines changed

8 files changed

+206
-27
lines changed

pysits/conversions/common.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ def convert_to_r(obj):
162162

163163
# Handle ``SITSBase`` objects
164164
if getattr(obj, "_instance", None):
165+
# Sync instance with R
166+
if getattr(obj, "_sync_instance", None):
167+
obj._sync_instance()
168+
165169
return obj._instance
166170

167171
# Handle ``raw R`` / Expressions objects
@@ -263,7 +267,7 @@ def fix_reserved_words_parameters(**kwargs) -> dict:
263267
keys_to_remove.append(key)
264268

265269
# Save new value
266-
new_values[key[:-1]] = kwargs.pop(key)
270+
new_values[key[:-1]] = kwargs[key]
267271

268272
# Remove keys that were converted
269273
for key in keys_to_remove:

pysits/conversions/tibble_arrow.py

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from collections.abc import Callable
2323

2424
from pandas import DataFrame as PandasDataFrame
25+
from pandas.core.generic import NDFrame as PandasNDFrame
2526
from pyarrow import feather
27+
from rpy2.rinterface_lib.sexp import NULLType
2628
from rpy2.robjects import StrVector
2729
from rpy2.robjects import globalenv as rpy2_globalenv
2830
from rpy2.robjects import r as rpy2_r_interface
@@ -32,6 +34,9 @@
3234
from pysits.backend.pkgs import r_pkg_arrow, r_pkg_base, r_pkg_sits
3335

3436

37+
#
38+
# Helper functions
39+
#
3540
def _load_arrow_table_reader_function() -> Callable[[str, list[str]], RDataFrame]:
3641
"""Load and return an R function for reading Arrow tables with nested columns.
3742
@@ -53,15 +58,40 @@ def _load_arrow_table_reader_function() -> Callable[[str, list[str]], RDataFrame
5358
5459
for (col in nested_cols) {
5560
row_nested <- row_data[[col]]
56-
row_nested <- list(tidyr::unnest(
57-
row_nested,
58-
cols = dplyr::everything()
59-
))
60-
row_data[[col]] <- NULL
61-
row_data <- tibble::tibble(
62-
row_data,
63-
!!col := row_nested
64-
)
61+
62+
# Handle arrow_list class
63+
if (inherits(row_nested, "arrow_list")) {
64+
row_nested <- lapply(row_nested, function(v) {
65+
if (is.null(v)) return(NULL)
66+
# Try to parse as JSON first
67+
tryCatch({
68+
parsed <- jsonlite::fromJSON(v)
69+
setNames(as.character(parsed), names(parsed))
70+
}, error = function(e) {
71+
# If JSON parsing fails, return NULL
72+
NULL
73+
})
74+
})
75+
# If any values in row_nested are NULL, set the whole
76+
# thing to NULL
77+
if (any(sapply(row_nested, is.null))) {
78+
row_nested <- NULL
79+
}
80+
} else {
81+
row_nested <- list(tidyr::unnest(
82+
row_nested,
83+
cols = dplyr::everything()
84+
))
85+
}
86+
87+
# Only create tibble if row_nested is not NULL
88+
if (!is.null(row_nested)) {
89+
row_data[[col]] <- NULL
90+
row_data <- tibble::tibble(
91+
row_data,
92+
!!col := row_nested
93+
)
94+
}
6595
}
6696
row_data
6797
})
@@ -71,8 +101,41 @@ def _load_arrow_table_reader_function() -> Callable[[str, list[str]], RDataFrame
71101
return rpy2_globalenv["load_arrow_table"]
72102

73103

104+
def _named_vector_to_json(x: RDataFrame, colname: str) -> RDataFrame:
    """Serialize the named vectors held in one column of an R data frame to JSON.

    Every element of column ``colname`` that carries names is replaced by a
    JSON object string built from its names and its values coerced to
    character. Elements without names are replaced by ``NULL``.

    Args:
        x (RDataFrame): R DataFrame containing a column with named vectors.

        colname (str): Name of the column containing named vectors.

    Returns:
        RDataFrame: ``x`` with column ``colname`` replaced by the list of
            converted elements.
    """
    # The column name is handed to R as a regular argument and looked up with
    # ``[[`` rather than being spliced into the R source text. This keeps the
    # helper valid for column names that are not syntactic R identifiers and
    # makes the R function independent of the column it was last defined for.
    rpy2_r_interface("""
        named_vector_to_json <- function(x, colname) {
            vec_list <- lapply(x[[colname]], function(v) {
                if (is.null(names(v))) return(NULL)
                class(v) <- NULL
                json <- jsonlite::toJSON(as.list(setNames(as.character(v), names(v))),
                                        auto_unbox=TRUE)
                class(json) <- NULL
                json
            })
            x[[colname]] <- vec_list
            x
        }
    """)

    # Call the R function and return result
    return rpy2_globalenv["named_vector_to_json"](x, colname)
133+
134+
74135
def _tibble_to_pandas_arrow(
75-
instance: RDataFrame, nested_columns: list[str] | None = None
136+
instance: RDataFrame,
137+
nested_columns: list[str] | None = None,
138+
table_processor: Callable[[RDataFrame], RDataFrame] | None = None,
76139
) -> PandasDataFrame:
77140
"""Convert an R DataFrame (tibble) to a Pandas DataFrame using Arrow format.
78141
@@ -113,6 +176,10 @@ def _tibble_to_pandas_arrow(
113176
# Select regular columns (using ``[]``) and convert to Pandas
114177
rdf_data = instance.rx(StrVector(data_columns_valid))
115178

179+
# Process table
180+
if table_processor:
181+
rdf_data = table_processor(rdf_data)
182+
116183
# Write to Feather format
117184
r_pkg_arrow.write_feather(rdf_data, tmp_path)
118185

@@ -171,7 +238,11 @@ def _pandas_to_tibble_arrow(
171238
# Convert nested columns to R DataFrame
172239
for nested_column in nested_columns:
173240
instance[nested_column] = instance[nested_column].apply(
174-
lambda arr: arr.to_dict(orient="list")
241+
lambda arr: (
242+
arr.to_dict(orient="list")
243+
if isinstance(arr, PandasNDFrame)
244+
else arr
245+
)
175246
)
176247

177248
# Write to Feather
@@ -188,7 +259,9 @@ def _pandas_to_tibble_arrow(
188259
# General conversions
189260
#
190261
def tibble_nested_to_pandas_arrow(
191-
data: RDataFrame, nested_columns: list[str]
262+
data: RDataFrame,
263+
nested_columns: list[str],
264+
table_processor: Callable[[RDataFrame], RDataFrame] | None = None,
192265
) -> PandasDataFrame:
193266
"""Convert any tibble to Pandas DataFrame.
194267
@@ -198,7 +271,7 @@ def tibble_nested_to_pandas_arrow(
198271
Returns:
199272
pandas.DataFrame: R Data Frame as Pandas.
200273
"""
201-
return _tibble_to_pandas_arrow(data, nested_columns)
274+
return _tibble_to_pandas_arrow(data, nested_columns, table_processor)
202275

203276

204277
def pandas_to_tibble_arrow(
@@ -233,6 +306,7 @@ def tibble_sits_to_pandas_arrow(data: RDataFrame) -> PandasDataFrame:
233306
"label",
234307
"cube",
235308
"time_series",
309+
"predicted",
236310
"cluster",
237311
]
238312

@@ -258,11 +332,17 @@ def pandas_sits_to_tibble_arrow(data: PandasDataFrame) -> RDataFrame:
258332
# Define nested columns
259333
nested_columns = ["time_series", "predicted"]
260334

335+
# Define data classes
336+
data_classes = ["sits", "tbl_df", "tbl", "data.frame"]
337+
338+
if "predicted" in data.columns:
339+
data_classes.append("predicted")
340+
261341
# Convert to R DataFrame
262342
data = pandas_to_tibble_arrow(data, nested_columns)
263343

264344
# Set class
265-
data.rclass = StrVector(["sits", "tbl_df", "tbl", "data.frame"])
345+
data.rclass = StrVector(data_classes)
266346

267347
# Convert to R DataFrame
268348
return data
@@ -297,8 +377,30 @@ def tibble_cube_to_pandas_arrow(data: RDataFrame) -> PandasDataFrame:
297377
# Define nested columns
298378
nested_columns = ["file_info", "vector_info"]
299379

380+
# Define table processor
381+
def table_processor(x: RDataFrame) -> RDataFrame:
382+
"""Process table."""
383+
384+
# Process ``labels`` column
385+
if "labels" in x.colnames:
386+
# Get labels column
387+
labels = x.rx2("labels")
388+
389+
# Check if labels have names
390+
labels_has_names = all(
391+
not isinstance(label.names, NULLType) for label in labels
392+
)
393+
394+
# If labels have names, convert to JSON
395+
if labels_has_names:
396+
x = _named_vector_to_json(x, "labels")
397+
398+
return x
399+
300400
# Convert to Pandas DataFrame
301-
data_converted = tibble_nested_to_pandas_arrow(data, nested_columns)
401+
data_converted = tibble_nested_to_pandas_arrow(
402+
data, nested_columns, table_processor
403+
)
302404

303405
# Select columns
304406
columns_available = [v for v in column_order if v in data_converted.columns]
@@ -314,7 +416,7 @@ def pandas_cube_to_tibble_arrow(data: PandasDataFrame) -> RDataFrame:
314416
data (pandas.DataFrame): The pandas DataFrame to convert to R.
315417
"""
316418
# Define nested columns
317-
nested_columns = ["file_info", "vector_info"]
419+
nested_columns = ["labels", "file_info", "vector_info"]
318420

319421
# Convert to R DataFrame
320422
data = pandas_to_tibble_arrow(data, nested_columns)

pysits/models/data/cube.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, data: PandasSeries, **kwargs) -> None:
7676
to an R DataFrame instance.
7777
"""
7878
# Create a cube from a Pandas Series
79-
if isinstance(data, PandasSeries):
79+
if isinstance(data, PandasSeries) and getattr(data, "_instance", None) is None:
8080
# Check if required columns are present
8181
has_required_columns = all(
8282
col in data.index for col in self.required_columns
@@ -162,6 +162,26 @@ def _convert_from_r(self, instance: RDataFrame) -> PandasDataFrame:
162162
"""
163163
return tibble_cube_to_pandas_arrow(instance)
164164

165+
#
166+
# Data management
167+
#
168+
def _sync_instance(self):
169+
"""Sync instance with R."""
170+
if not self._is_updated:
171+
return
172+
173+
# Save current classes
174+
classes = self._instance.rclass
175+
176+
# Update instance
177+
self._instance = pandas_cube_to_tibble_arrow(self)
178+
179+
# Restore classes
180+
self._instance.rclass = classes
181+
182+
#
183+
# Representation
184+
#
165185
def _repr_html_(self) -> str:
166186
"""Create an HTML representation of the cube."""
167187
from pysits.jinja import get_template

pysits/models/data/frame.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,19 @@
2121
from pandas import DataFrame as PandasDataFrame
2222
from rpy2.robjects.vectors import DataFrame as RDataFrame
2323

24-
from pysits.conversions.tibble import tibble_nested_to_pandas, tibble_to_pandas
24+
from pysits.conversions.tibble import (
25+
tibble_nested_to_pandas,
26+
tibble_to_pandas,
27+
)
2528
from pysits.models.data.base import SITSData
2629

2730

2831
class SITSFrameBase(SITSData):
2932
"""Base class for SITS Data."""
3033

34+
_is_updated = False
35+
"""Whether the instance is updated."""
36+
3137
def __finalize__(self, other, method=None, **kwargs):
3238
"""Propagate metadata from another object to the current one.
3339
@@ -49,6 +55,11 @@ def _constructor(self):
4955
# Always return the current subclass
5056
return self.__class__
5157

58+
def __setitem__(self, key, value):
59+
"""Set item."""
60+
super().__setitem__(key, value)
61+
self._is_updated = True
62+
5263
#
5364
# Conversions
5465
#

pysits/models/data/ts.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,23 @@ def _convert_from_r(self, instance: RDataFrame, **kwargs) -> PandasDataFrame:
108108
"""
109109
return tibble_sits_to_pandas_arrow(instance)
110110

111+
#
112+
# Data management
113+
#
114+
def _sync_instance(self):
115+
"""Sync instance with R."""
116+
if not self._is_updated:
117+
return
118+
119+
# Save current classes
120+
classes = self._instance.rclass
121+
122+
# Update instance
123+
self._instance = pandas_sits_to_tibble_arrow(self)
124+
125+
# Restore classes
126+
self._instance.rclass = classes
127+
111128

112129
class SITSTimeSeriesSFModel(SITSFrameSF):
113130
"""SITS time-series model as sf."""

pysits/models/resolver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pysits.models.data.base import SITSBase, SITSData, SITStructureData
2727
from pysits.models.data.cube import SITSCubeModel
2828
from pysits.models.data.frame import SITSFrame, SITSFrameSF
29-
from pysits.models.data.matrix import SITSConfusionMatrix
29+
from pysits.models.data.matrix import SITSConfusionMatrix, SITSMatrix
3030
from pysits.models.data.ts import (
3131
SITSTimeSeriesClassificationModel,
3232
SITSTimeSeriesModel,
@@ -80,9 +80,13 @@ def content_class_resolver(data: Any) -> type[SITSBase]:
8080
content_class = SITSFrameSF
8181

8282
# Data frame
83-
case class_ if "tbl_df" in class_:
83+
case class_ if "tbl_df" in class_ or "data.frame" in class_:
8484
content_class = SITSFrame
8585

86+
# Matrix
87+
case class_ if "matrix" in class_:
88+
content_class = SITSMatrix
89+
8690
# ML model (any `sits_model`, including `random forest`, `svm`, `ltae`, etc.)
8791
case class_ if "sits_model" in class_:
8892
content_class = SITSMachineLearningMethod

pysits/sits/ml.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from pysits.conversions.common import convert_dict_like_as_list_to_r
2626
from pysits.conversions.decorators import function_call
2727
from pysits.docs import attach_doc
28-
from pysits.models.data.base import SITStructureData
2928
from pysits.models.ml import SITSMachineLearningMethod
29+
from pysits.models.resolver import resolve_and_invoke_accuracy_class
3030

3131

3232
#
@@ -90,10 +90,10 @@ def sits_train(*args, **kwargs) -> SITSMachineLearningMethod:
9090
"""Train a machine learning model."""
9191

9292

93-
@function_call(r_pkg_sits.sits_kfold_validate, SITStructureData)
93+
@function_call(r_pkg_sits.sits_kfold_validate, resolve_and_invoke_accuracy_class)
9494
@attach_doc("sits_kfold_validate")
95-
def sits_kfold_validate(*args, **kwargs) -> SITStructureData:
96-
"""Train a machine learning model."""
95+
def sits_kfold_validate(*args, **kwargs) -> resolve_and_invoke_accuracy_class:
96+
"""Cross-validate time series samples."""
9797

9898

9999
#

0 commit comments

Comments
 (0)