Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import polars as pl
import polars.selectors as cs
from meds import DataSchema
from meds import CodeMetadataSchema, DataSchema
from omegaconf import DictConfig, ListConfig, OmegaConf

from .. import Stage
Expand Down Expand Up @@ -1031,5 +1031,9 @@ def aggregation_schema_updates(stage_cfg: dict | None = None) -> dict[str, pl.Da


stage = Stage.register(
map_fn=mapper_fntr, reduce_fn=reducer_fntr, output_schema_updates=aggregation_schema_updates
map_fn=mapper_fntr,
reduce_fn=reducer_fntr,
output_schema_updates=aggregation_schema_updates,
input_schema=DataSchema,
metadata_output_schema=CodeMetadataSchema,
)
Comment thread
mmcdermott marked this conversation as resolved.
44 changes: 44 additions & 0 deletions src/MEDS_transforms/stages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@

if TYPE_CHECKING:
import polars as pl
from flexible_schema import Schema

from ..compute_modes import ANY_COMPUTE_FN_T
from ..dataframe import READ_FN_T, WRITE_FN_T

SchemaLike = type[Schema] | Schema

logger = logging.getLogger(__name__)

MAIN_FN_T = Callable[[DictConfig], None]
Expand Down Expand Up @@ -482,6 +485,10 @@ class Stage:
output_schema_updates: dict[str, pl.DataType] | Callable[[dict | None], dict[str, pl.DataType]] | None = (
None
)
input_schema: SchemaLike | None = None
output_schema: SchemaLike | None = None
metadata_input_schema: SchemaLike | None = None
metadata_output_schema: SchemaLike | None = None
is_metadata: bool | None = None

__mimic_fn: Callable | None = None
Expand Down Expand Up @@ -546,6 +553,10 @@ def __init__(
output_schema_updates: (
dict[str, pl.DataType] | Callable[[dict | None], dict[str, pl.DataType]] | None
) = None,
input_schema: SchemaLike | None = None,
output_schema: SchemaLike | None = None,
metadata_input_schema: SchemaLike | None = None,
metadata_output_schema: SchemaLike | None = None,
examples_dir: Path | None = None,
default_config: dict[str, Any] | DictConfig | Path | str | None = None,
is_metadata: bool | None = None,
Expand Down Expand Up @@ -616,6 +627,11 @@ def __init__(
else:
self.output_schema_updates = copy.deepcopy(output_schema_updates)

self.input_schema = input_schema
self.output_schema = output_schema
self.metadata_input_schema = metadata_input_schema
self.metadata_output_schema = metadata_output_schema

self.example_class = example_class if example_class is not None else StageExample

def __infer_stage_dir(self, stage_definition_file: Path | None) -> Path | None:
Expand Down Expand Up @@ -815,6 +831,34 @@ def _resolve_output_schema_updates(

return dict(self.output_schema_updates(resolved_cfg))

@property
def declared_schemas(self) -> dict[str, SchemaLike | None]:
"""Return all declared ``flexible_schema`` schemas on this stage, keyed by role.

The four roles are ``input``, ``output``, ``metadata_input``, ``metadata_output``. Values
are ``None`` when the stage has not declared that particular schema. Intended to feed
pipeline-load-time schema validation (see #324) and composer schema checks (see #56).

Examples:
>>> def compute(cfg):
... '''docstring'''
... return 0
>>> from meds import DataSchema
>>> stage = Stage(map_fn=compute, input_schema=DataSchema, output_schema=DataSchema)
>>> sorted(stage.declared_schemas.keys())
['input', 'metadata_input', 'metadata_output', 'output']
>>> stage.declared_schemas["input"] is DataSchema
True
>>> stage.declared_schemas["metadata_input"] is None
True
"""
return {
"input": self.input_schema,
"output": self.output_schema,
"metadata_input": self.metadata_input_schema,
"metadata_output": self.metadata_output_schema,
}

@property
def test_cases(self) -> dict[str, StageExample]:
if self.examples_dir is None:
Expand Down
Loading