diff --git a/README.md b/README.md index 234f26fa..24b38d84 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,7 @@ unimplemented for the last 15 years of publications including: | Surrogate Outcomes | [Tikka and Karvanen, 2018](https://arxiv.org/abs/1806.07172) | | Counterfactual Transportability | [Correia, Lee, Bareinboim, 2022](https://proceedings.mlr.press/v162/correa22a.html) | | Cyclic ID | [Forré and Mooij, 2019](https://arxiv.org/abs/1901.00433v2) | +| Cyclic IDC | _ours_ | Apply an algorithm to an Acyclic Directed Mixed Graph (ADMG) and a causal query to generate an estimand represented in the DSL like: diff --git a/src/y0/algorithm/do_calculus.py b/src/y0/algorithm/do_calculus.py index f7d5a89e..0fdd507a 100644 --- a/src/y0/algorithm/do_calculus.py +++ b/src/y0/algorithm/do_calculus.py @@ -1,8 +1,9 @@ """Do Calculus.""" -from typing import Annotated +from typing import Annotated, Literal from .conditional_independencies import are_d_separated +from .separation import are_sigma_separated from ..dsl import Variable from ..graph import NxMixedGraph from ..util import InPaperAs @@ -23,6 +24,7 @@ def rule_2_of_do_calculus_applies( outcomes: Annotated[set[Variable], InPaperAs(r"\mathbf{Y}")], conditions: Annotated[set[Variable], InPaperAs(r"\mathbf{Z}")], condition: Variable, + separation_implementation: Literal["d", "sigma"] | None = None, ) -> bool: r"""Check if Rule 2 of the Do-Calculus applies to the conditioned variable. @@ -31,6 +33,8 @@ def rule_2_of_do_calculus_applies( :param conditions: :param outcomes: :param condition: The condition to check + :param separation_implementation: The separation implementation. Defaults to d + separation, but can be generalized to sigma separation :returns: If rule 2 applies, see below. @@ -50,7 +54,15 @@ def rule_2_of_do_calculus_applies( reduced_graph: Annotated[NxMixedGraph, InPaperAs(r"G_{\bar{x}, \underbar{z}}")] = ( graph.remove_in_edges(treatments).remove_out_edges(condition) ) - return all( - are_d_separated(reduced_graph, outcome, condition, conditions=reduced_conditions) - for outcome in outcomes - ) + if separation_implementation == "d" or separation_implementation is None: + return all( + are_d_separated(reduced_graph, outcome, condition, conditions=reduced_conditions) + for outcome in outcomes + ) + elif separation_implementation == "sigma": + return all( + are_sigma_separated(reduced_graph, outcome, condition, conditions=reduced_conditions) + for outcome in outcomes + ) + else: + raise ValueError(f"Unknown separation implementation: {separation_implementation}") diff --git a/src/y0/algorithm/identify/__init__.py b/src/y0/algorithm/identify/__init__.py index 26ca579b..723c8c99 100644 --- a/src/y0/algorithm/identify/__init__.py +++ b/src/y0/algorithm/identify/__init__.py @@ -13,6 +13,7 @@ ID* [shpitser2012]_ :mod:`y0.algorithm.identify.id_star` IDC* [shpitser2012]_ :mod:`y0.algorithm.identify.idc_star` Cyclic ID [forre2019]_ :mod:`y0.algorithm.identify.cyclic_id` +Cyclic IDC :mod:`y0.algorithm.identify.cyclic_idc` gID [correa2019]_ `Issue #72 `_ gID* [correa2021]_ `Issue #121 @@ -71,6 +72,7 @@ from .api import identify_outcomes from .cyclic_id import cyclic_id +from .cyclic_idc import cyclic_idc from .id_c import idc from .id_star import id_star from .id_std import identify @@ -82,6 +84,7 @@ "Query", "Unidentifiable", "cyclic_id", + "cyclic_idc", "id_star", "idc", "idc_star", diff --git a/src/y0/algorithm/identify/cyclic_idc.py b/src/y0/algorithm/identify/cyclic_idc.py new file mode 100644 index 00000000..620d9689 --- /dev/null +++ b/src/y0/algorithm/identify/cyclic_idc.py @@ -0,0 +1,54 @@ +"""An implementation of Cyclic IDC, based off of the extension from ID to IDC.""" + +from collections.abc import Iterable, Sequence + +from .cyclic_id import cyclic_id +from .utils import Identification +from ..do_calculus import rule_2_of_do_calculus_applies +from ...dsl import Expression, Variable +from ...graph import NxMixedGraph + +__all__ = ["cyclic_idc"] + + +def cyclic_idc( + graph: NxMixedGraph, + outcomes: Variable | Iterable[Variable], + interventions: Variable | Iterable[Variable], + conditions: Variable | Iterable[Variable], + *, + ordering: Sequence[Variable] | None = None, +) -> Expression: + """Run cyclic ID with support for conditions.""" + identification = Identification.from_parts( + graph=graph, + outcomes=outcomes, + treatments=interventions, + conditions=conditions, + ) + for condition in identification.conditions: + if rule_2_of_do_calculus_applies( + graph=identification.graph, + treatments=identification.treatments, + outcomes=identification.outcomes, + conditions=identification.conditions, + condition=condition, + separation_implementation="sigma", + ): + modified = identification.exchange_observation_with_action(condition) + return cyclic_idc( + graph=modified.graph, + outcomes=modified.outcomes, + interventions=modified.treatments, + conditions=modified.conditions, + ordering=ordering, + ) + + modified = identification.uncondition() + id_estimand = cyclic_id( + graph=modified.graph, + outcomes=modified.outcomes, + interventions=modified.treatments, + ordering=ordering, + ) + return id_estimand.normalize_marginalize(identification.outcomes)