Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ unimplemented for the last 15 years of publications including:
| Surrogate Outcomes | [Tikka and Karvanen, 2018](https://arxiv.org/abs/1806.07172) |
| Counterfactual Transportability | [Correia, Lee, Bareinboim, 2022](https://proceedings.mlr.press/v162/correa22a.html) |
| Cyclic ID | [Forré and Mooij, 2019](https://arxiv.org/abs/1901.00433v2) |
| Cyclic IDC | _ours_ |

Apply an algorithm to an Acyclic Directed Mixed Graph (ADMG) and a causal query
to generate an estimand represented in the DSL like:
Expand Down
22 changes: 17 additions & 5 deletions src/y0/algorithm/do_calculus.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Do Calculus."""

from typing import Annotated
from typing import Annotated, Literal

from .conditional_independencies import are_d_separated
from .separation import are_sigma_separated
from ..dsl import Variable
from ..graph import NxMixedGraph
from ..util import InPaperAs
Expand All @@ -23,6 +24,7 @@ def rule_2_of_do_calculus_applies(
outcomes: Annotated[set[Variable], InPaperAs(r"\mathbf{Y}")],
conditions: Annotated[set[Variable], InPaperAs(r"\mathbf{Z}")],
condition: Variable,
separation_implementation: Literal["d", "sigma"] | None = None,
) -> bool:
r"""Check if Rule 2 of the Do-Calculus applies to the conditioned variable.

Expand All @@ -31,6 +33,8 @@ def rule_2_of_do_calculus_applies(
:param conditions:
:param outcomes:
:param condition: The condition to check
:param separation_implementation: The separation implementation. Defaults to d
separation, but can be generalized to sigma separation

:returns: If rule 2 applies, see below.

Expand All @@ -50,7 +54,15 @@ def rule_2_of_do_calculus_applies(
reduced_graph: Annotated[NxMixedGraph, InPaperAs(r"G_{\bar{x}, \underbar{z}}")] = (
graph.remove_in_edges(treatments).remove_out_edges(condition)
)
return all(
are_d_separated(reduced_graph, outcome, condition, conditions=reduced_conditions)
for outcome in outcomes
)
if separation_implementation == "d" or separation_implementation is None:
return all(
are_d_separated(reduced_graph, outcome, condition, conditions=reduced_conditions)
for outcome in outcomes
)
elif separation_implementation == "sigma":
return all(
are_sigma_separated(reduced_graph, outcome, condition, conditions=reduced_conditions)
for outcome in outcomes
)
else:
raise ValueError(f"Unknown separation implementation: {separation_implementation}")
3 changes: 3 additions & 0 deletions src/y0/algorithm/identify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ID* [shpitser2012]_ :mod:`y0.algorithm.identify.id_star`
IDC* [shpitser2012]_ :mod:`y0.algorithm.identify.idc_star`
Cyclic ID [forre2019]_ :mod:`y0.algorithm.identify.cyclic_id`
Cyclic IDC :mod:`y0.algorithm.identify.cyclic_idc`
gID [correa2019]_ `Issue #72
<https://github.com/y0-causal-inference/y0/issues/72>`_
gID* [correa2021]_ `Issue #121
Expand Down Expand Up @@ -71,6 +72,7 @@

from .api import identify_outcomes
from .cyclic_id import cyclic_id
from .cyclic_idc import cyclic_idc
from .id_c import idc
from .id_star import id_star
from .id_std import identify
Expand All @@ -82,6 +84,7 @@
"Query",
"Unidentifiable",
"cyclic_id",
"cyclic_idc",
"id_star",
"idc",
"idc_star",
Expand Down
54 changes: 54 additions & 0 deletions src/y0/algorithm/identify/cyclic_idc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""An implementation of Cyclic IDC, based off of the extension from ID to IDC."""

from collections.abc import Iterable, Sequence

from .cyclic_id import cyclic_id
from .utils import Identification
from ..do_calculus import rule_2_of_do_calculus_applies
from ...dsl import Expression, Variable
from ...graph import NxMixedGraph

__all__ = ["cyclic_idc"]


def cyclic_idc(
graph: NxMixedGraph,
outcomes: Variable | Iterable[Variable],
interventions: Variable | Iterable[Variable],
conditions: Variable | Iterable[Variable],
*,
ordering: Sequence[Variable] | None = None,
) -> Expression:
"""Run cyclic ID with support for conditions."""
identification = Identification.from_parts(
graph=graph,
outcomes=outcomes,
treatments=interventions,
conditions=conditions,
)
for condition in identification.conditions:
if rule_2_of_do_calculus_applies(
graph=identification.graph,
treatments=identification.treatments,
outcomes=identification.outcomes,
conditions=identification.conditions,
condition=condition,
separation_implementation="sigma",
):
modified = identification.exchange_observation_with_action(condition)
return cyclic_idc(
graph=modified.graph,
outcomes=modified.outcomes,
interventions=modified.treatments,
conditions=modified.conditions,
ordering=ordering,
)

modified = identification.uncondition()
id_estimand = cyclic_id(
graph=modified.graph,
outcomes=modified.outcomes,
interventions=modified.treatments,
ordering=ordering,
)
return id_estimand.normalize_marginalize(identification.outcomes)
Loading