Open
Description
Targets assume alignment on some key, often time. However, in some applications the data are just a scalar number intended to match model output. I ended up creating a very simple "MetricAlignment" for this purpose, but I expect there's a better way.
from dataclasses import dataclass

import polars as pl

# AlignmentStrategy, AlignedData, SUFFIX_OBS, SUFFIX_SIM and REPLICATE_COL
# are provided by the library and must be imported from it as well.


@dataclass
class MetricAlignment(AlignmentStrategy):
    """Alignment strategy for selecting a specific metric from a metrics DataFrame."""

    metric_name: str

    @property
    def on(self) -> list[str]:
        return ["key"]  # Use dummy key for grouping like ScalarAlignment

    def align(
        self,
        observed: pl.DataFrame,
        simulated: list[pl.DataFrame],
    ) -> AlignedData:
        # Extract the specific metric from observed data
        obs_filtered = observed.filter(pl.col("metric") == self.metric_name)
        if obs_filtered.height == 0:
            raise ValueError(f"Metric '{self.metric_name}' not found in observed data")
        obs_value = obs_filtered["value"][0]

        aligned_frames = []
        for i, sim_df in enumerate(simulated):
            # Extract the specific metric from simulated data
            sim_filtered = sim_df.filter(pl.col("metric") == self.metric_name)
            if sim_filtered.height == 0:
                raise ValueError(f"Metric '{self.metric_name}' not found in simulation data")
            sim_value = sim_filtered["value"][0]

            # Create aligned row similar to ScalarAlignment
            aligned_row = pl.DataFrame({
                "key": [1],  # Dummy key for grouping
                f"value{SUFFIX_OBS}": [obs_value],
                f"value{SUFFIX_SIM}": [sim_value],
                REPLICATE_COL: [i],
            })
            aligned_frames.append(aligned_row)

        return AlignedData(
            data=pl.concat(aligned_frames),
            on_cols=["key"],
            replicate_col=REPLICATE_COL,
        )