Skip to content

Commit 6cd6207

Browse files
committed
adding transfer entropy
1 parent e8ac05d commit 6cd6207

20 files changed

+558
-4
lines changed

hoi/core/combinatory.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import itertools
22
from math import comb as ccomb
3+
from math import perm as pperm
34

45
import jax.numpy as jnp
56
import numpy as np
67

78

8-
def _combinations(n, k, order, target):
9-
for c in itertools.combinations(range(n), k):
9+
def _combinations(n, k, order, target, fnc=None):
10+
for c in fnc(range(n), k):
1011
# convert to list
1112
c = list(c) + target
1213

@@ -25,6 +26,7 @@ def combinations(
2526
order=False,
2627
fill_value=-1,
2728
target=None,
29+
directed=False,
2830
):
2931
"""Get combinations.
3032
@@ -53,13 +55,21 @@ def combinations(
5355
combinations of k elements.
5456
"""
5557
# ________________________________ ITERATOR _______________________________
58+
59+
if directed:
60+
fnc = itertools.permutations
61+
fnc_nmult = pperm
62+
else:
63+
fnc = itertools.combinations
64+
fnc_nmult = ccomb
65+
5666
if not isinstance(maxsize, int):
5767
maxsize = minsize
5868
target = [] if target is None else list(target)
5969
assert maxsize >= minsize
6070
iterators = []
6171
for msize in range(minsize, maxsize + 1):
62-
iterators.append(_combinations(n, msize, order, target))
72+
iterators.append(_combinations(n, msize, order, target, fnc))
6373
iterators = itertools.chain(*tuple(iterators))
6474

6575
if astype == "iterator":
@@ -70,7 +80,7 @@ def combinations(
7080
combs = np.asarray([c for c in iterators]).astype(int)
7181
else:
7282
# get the number of combinations
73-
n_mults = sum([ccomb(n, c) for c in range(minsize, maxsize + 1)])
83+
n_mults = sum([fnc_nmult(n, c) for c in range(minsize, maxsize + 1)])
7484

7585
# prepare output
7686
combs = np.full(

hoi/core/redundancies.py

Whitespace-only changes.

hoi/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
from .syn_mmi import SynergyMMI # noqa
1313
from .tc import TC # noqa
1414
from .pairwise_mi import MI # noqa
15+
from .transfer_entropy import TransferEntropy # noqa

hoi/metrics/base_hoi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def get_combinations(self, minsize, maxsize=None, astype="jax"):
283283
astype=astype,
284284
order=False,
285285
fill_value=-1,
286+
directed=self._directed,
286287
)
287288

288289
return self._multiplets, self.order

hoi/metrics/do_tot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class DOtot(HOIEstimator):
7272
_positive = "redundancy"
7373
_negative = "synergy"
7474
_symmetric = False
75+
_directed = False
7576

7677
def __init__(self, x, multiplets=None, verbose=None):
7778
HOIEstimator.__init__(

hoi/metrics/dtc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class DTC(HOIEstimator):
6666
_negative = "null"
6767
_positive = "null"
6868
_symmetric = True
69+
_directed = False
6970

7071
def __init__(self, x, y=None, multiplets=None, verbose=None):
7172
HOIEstimator.__init__(

hoi/metrics/gradient_oinfo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class GradientOinfo(HOIEstimator):
4242
_positive = "redundancy"
4343
_negative = "synergy"
4444
_symmetric = True
45+
_directed = False
4546

4647
def __init__(self, x, y, multiplets=None, base_model=Oinfo, verbose=None):
4748
kw_oinfo = dict(multiplets=multiplets, verbose=verbose)

hoi/metrics/info_tot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class InfoTot(HOIEstimator):
3838
_positive = "info"
3939
_negative = "null"
4040
_symmetric = False
41+
_directed = False
4142

4243
def __init__(self, x, y, multiplets=None, verbose=None):
4344
HOIEstimator.__init__(

hoi/metrics/infotopo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class InfoTopo(HOIEstimator):
7575
_positive = "redundancy"
7676
_negative = "synergy"
7777
_symmetric = True
78+
_directed = False
7879

7980
def __init__(self, x, y=None, verbose=None):
8081
# for infotopo, the multiplets are set to None because this metric

hoi/metrics/oinfo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class Oinfo(HOIEstimator):
7171
_positive = "redundancy"
7272
_negative = "synergy"
7373
_symmetric = True
74+
_directed = False
7475

7576
def __init__(self, x, y=None, multiplets=None, verbose=None):
7677
HOIEstimator.__init__(

0 commit comments

Comments
 (0)