11import itertools
22from math import comb as ccomb
3+ from math import perm as pperm
34
45import jax .numpy as jnp
56import numpy as np
67
78
8- def _combinations (n , k , order , target ):
9- for c in itertools . combinations (range (n ), k ):
9+ def _combinations (n , k , order , target , fnc = None ):
10+ for c in fnc (range (n ), k ):
1011 # convert to list
1112 c = list (c ) + target
1213
@@ -25,6 +26,7 @@ def combinations(
2526 order = False ,
2627 fill_value = - 1 ,
2728 target = None ,
29+ directed = False ,
2830):
2931 """Get combinations.
3032
@@ -53,13 +55,21 @@ def combinations(
5355 combinations of k elements.
5456 """
5557 # ________________________________ ITERATOR _______________________________
58+
59+ if directed :
60+ fnc = itertools .permutations
61+ fnc_nmult = pperm
62+ else :
63+ fnc = itertools .combinations
64+ fnc_nmult = ccomb
65+
5666 if not isinstance (maxsize , int ):
5767 maxsize = minsize
5868 target = [] if target is None else list (target )
5969 assert maxsize >= minsize
6070 iterators = []
6171 for msize in range (minsize , maxsize + 1 ):
62- iterators .append (_combinations (n , msize , order , target ))
72+ iterators .append (_combinations (n , msize , order , target , fnc ))
6373 iterators = itertools .chain (* tuple (iterators ))
6474
6575 if astype == "iterator" :
@@ -70,7 +80,7 @@ def combinations(
7080 combs = np .asarray ([c for c in iterators ]).astype (int )
7181 else :
7282 # get the number of combinations
73- n_mults = sum ([ccomb (n , c ) for c in range (minsize , maxsize + 1 )])
83+ n_mults = sum ([fnc_nmult (n , c ) for c in range (minsize , maxsize + 1 )])
7484
7585 # prepare output
7686 combs = np .full (
0 commit comments