Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 3262bbd

Browse files
WIP Convert query to nodesets
1 parent 2da1c6d commit 3262bbd

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

bluepysnap/query.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
VALUE_KEYS = {REGEX_KEY}
1919
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS
2020

21+
GT_KEY = "$gt"
22+
LT_KEY = "$lt"
23+
GTE_KEY = "$gte"
24+
LTE_KEY = "$lte"
25+
2126

2227
# TODO: move to `libsonata` library
2328
def _complex_query(prop, query):
@@ -158,3 +163,45 @@ def _collect(queries, queries_key):
158163
queries = deepcopy(queries)
159164
traverse_queries_bottom_up(queries, _collect)
160165
return _merge_queries_masks(queries)
166+
167+
168+
def _convert(queries, node_sets, node_set_name):
169+
for queries_key, queries_value in queries.items():
170+
if queries_key == OR_KEY:
171+
# create a node_set for each item, and a combined node_set with the list
172+
assert len(queries) == 1, f"Mixing {OR_KEY} and other keys isn't supported yet"
173+
names = []
174+
for n, val in enumerate(queries_value):
175+
assert isinstance(val, dict)
176+
name = f"{node_set_name}_{n}"
177+
_convert(val, node_sets, name)
178+
names.append(name)
179+
node_sets[node_set_name] = names
180+
elif queries_key == AND_KEY:
181+
assert len(queries) == 1, f"Mixing {AND_KEY} and other keys isn't supported yet"
182+
raise NotImplementedError
183+
else:
184+
if isinstance(queries_value, tuple) and len(queries_value) == 2:
185+
start, stop = queries_value
186+
queries_value = {GTE_KEY: start, LTE_KEY: stop}
187+
assert (
188+
isinstance(queries_value, (str, int, float))
189+
or isinstance(queries_value, list)
190+
and all(isinstance(i, (str, int, float)) for i in queries_value)
191+
or isinstance(queries_value, dict)
192+
and {REGEX_KEY, GT_KEY, LT_KEY, GTE_KEY, LTE_KEY}.issuperset(queries_value)
193+
and all(isinstance(i, (str, int, float)) for i in queries_value.values())
194+
), (
195+
"Value should be a scalar, a list of scalars, a dict of operators, "
196+
"or a tuple of 2 elements representing an interval."
197+
)
198+
node_sets.setdefault(node_set_name, {})[queries_key] = deepcopy(queries_value)
199+
200+
201+
def to_node_set(queries):
202+
"""Convert a query to node_sets."""
203+
name = "ns"
204+
node_sets = {}
205+
_convert(queries, node_sets, name)
206+
return node_sets, name
207+
# return NodeSets.from_dict(node_sets)[name]

tests/test_query.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from bluepysnap import BluepySnapError
6-
from bluepysnap.query import _circuit_mask, _positional_mask, resolve_ids
6+
from bluepysnap.query import _circuit_mask, _positional_mask, resolve_ids, to_node_set
77

88

99
def test_positional_mask():
@@ -58,3 +58,32 @@ def test_resolve_ids():
5858
with pytest.raises(BluepySnapError) as e:
5959
resolve_ids(data, "", {"str": {"$regex": "*.some", "edge_id": 2}})
6060
assert "Value operators can't be used with plain values" in e.value.args[0]
61+
62+
63+
@pytest.mark.parametrize(
64+
"queries, expected",
65+
[
66+
(
67+
{"x": (0, 1), "mtype": "L1_SLAC"},
68+
{"ns": {"mtype": "L1_SLAC", "x": {"$gte": 0, "$lte": 1}}},
69+
),
70+
(
71+
{"$or": [{"layer": [2, 3]}, {"x": (0, 1), "mtype": "L1_SLAC"}]},
72+
{
73+
"ns_0": {"layer": [2, 3]},
74+
"ns_1": {"x": {"$gte": 0, "$lte": 1}, "mtype": "L1_SLAC"},
75+
"ns": ["ns_0", "ns_1"],
76+
},
77+
),
78+
],
79+
)
80+
def test_to_node_set(queries, expected):
81+
node_sets, name = to_node_set(queries)
82+
assert name == "ns"
83+
assert node_sets == expected
84+
85+
86+
def test_to_node_raises():
87+
queries = {"$and": [{"mtype": "L6_Y"}, {"morphology": "morph-B"}]}
88+
with pytest.raises(NotImplementedError):
89+
to_node_set(queries)

0 commit comments

Comments
 (0)