|
18 | 18 | VALUE_KEYS = {REGEX_KEY} |
19 | 19 | ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS |
20 | 20 |
|
| 21 | +GT_KEY = "$gt" |
| 22 | +LT_KEY = "$lt" |
| 23 | +GTE_KEY = "$gte" |
| 24 | +LTE_KEY = "$lte" |
| 25 | + |
21 | 26 |
|
22 | 27 | # TODO: move to `libsonata` library |
23 | 28 | def _complex_query(prop, query): |
@@ -158,3 +163,45 @@ def _collect(queries, queries_key): |
158 | 163 | queries = deepcopy(queries) |
159 | 164 | traverse_queries_bottom_up(queries, _collect) |
160 | 165 | return _merge_queries_masks(queries) |
| 166 | + |
| 167 | + |
| 168 | +def _convert(queries, node_sets, node_set_name): |
| 169 | + for queries_key, queries_value in queries.items(): |
| 170 | + if queries_key == OR_KEY: |
| 171 | + # create a node_set for each item, and a combined node_set with the list |
| 172 | + assert len(queries) == 1, f"Mixing {OR_KEY} and other keys isn't supported yet" |
| 173 | + names = [] |
| 174 | + for n, val in enumerate(queries_value): |
| 175 | + assert isinstance(val, dict) |
| 176 | + name = f"{node_set_name}_{n}" |
| 177 | + _convert(val, node_sets, name) |
| 178 | + names.append(name) |
| 179 | + node_sets[node_set_name] = names |
| 180 | + elif queries_key == AND_KEY: |
| 181 | + assert len(queries) == 1, f"Mixing {AND_KEY} and other keys isn't supported yet" |
| 182 | + raise NotImplementedError |
| 183 | + else: |
| 184 | + if isinstance(queries_value, tuple) and len(queries_value) == 2: |
| 185 | + start, stop = queries_value |
| 186 | + queries_value = {GTE_KEY: start, LTE_KEY: stop} |
| 187 | + assert ( |
| 188 | + isinstance(queries_value, (str, int, float)) |
| 189 | + or isinstance(queries_value, list) |
| 190 | + and all(isinstance(i, (str, int, float)) for i in queries_value) |
| 191 | + or isinstance(queries_value, dict) |
| 192 | + and {REGEX_KEY, GT_KEY, LT_KEY, GTE_KEY, LTE_KEY}.issuperset(queries_value) |
| 193 | + and all(isinstance(i, (str, int, float)) for i in queries_value.values()) |
| 194 | + ), ( |
| 195 | + "Value should be a scalar, a list of scalars, a dict of operators, " |
| 196 | + "or a tuple of 2 elements representing an interval." |
| 197 | + ) |
| 198 | + node_sets.setdefault(node_set_name, {})[queries_key] = deepcopy(queries_value) |
| 199 | + |
| 200 | + |
| 201 | +def to_node_set(queries): |
| 202 | + """Convert a query to node_sets.""" |
| 203 | + name = "ns" |
| 204 | + node_sets = {} |
| 205 | + _convert(queries, node_sets, name) |
| 206 | + return node_sets, name |
| 207 | + # return NodeSets.from_dict(node_sets)[name] |
0 commit comments