Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
907d48c
refactor: simplify TSQueryPredicate test logic and add null checks
DavidBakerEffendi Feb 21, 2026
a4808ba
feat: add query predicate parsing to TSQuery
DavidBakerEffendi Feb 21, 2026
293560b
enh: implement query predicate filtering in TSQueryCursor
DavidBakerEffendi Feb 21, 2026
45f1e19
test: remove redundant null checks in TSQueryTest
DavidBakerEffendi Feb 21, 2026
0278aa5
test: refactor query cursor and predicate tests
DavidBakerEffendi Feb 21, 2026
9f45a84
Remove accidentally added new line
DavidBakerEffendi Feb 21, 2026
fffbe2d
refactor: fix byte-to-char mapping in TSQueryCursor and predicate logic
DavidBakerEffendi Feb 22, 2026
edd502a
refactor: use capture IDs in TSQuery predicates and fix typo
DavidBakerEffendi Feb 22, 2026
64b9344
refactor: improve query predicate logic and remove redundant checks
DavidBakerEffendi Feb 22, 2026
b288ad4
refactor: update TSQuery and TSQueryCursor tests and fix naming
DavidBakerEffendi Feb 22, 2026
ba62b20
test: simplify query in TSQueryCursorTest
DavidBakerEffendi Feb 22, 2026
a0b45ea
test: add multi-byte character predicate tests to TSQueryPredicateTest
DavidBakerEffendi Feb 22, 2026
df4f625
docs: clarify UTF-8 encoding requirement in TSQueryCursor exec methods
DavidBakerEffendi Feb 22, 2026
22a525f
enh: add byte-based test method to TSQueryPredicate for performance
DavidBakerEffendi Feb 22, 2026
3aa4735
Use import instead of fully-qualified type
DavidBakerEffendi Feb 25, 2026
d012c44
fix: throw TSQueryException on invalid query syntax
DavidBakerEffendi Feb 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 133 additions & 2 deletions tree-sitter/src/main/java/org/treesitter/TSQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

import java.lang.ref.Cleaner.Cleanable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.treesitter.TSParser.*;

public class TSQuery implements AutoCloseable {
private final long ptr;
private TSLanguage lang;
private List<List<TSQueryPredicate>> predicates;
private final Cleanable cleanable;
private boolean closed = false;

Expand All @@ -25,7 +29,9 @@ public TSQueryCleanRunner(long ptr) {

@Override
public void run() {
    // Release the native query exactly once. A zero handle is skipped so that
    // ts_query_delete is never called with it.
    if (ptr != 0) {
        ts_query_delete(ptr);
    }
}
}

Expand Down Expand Up @@ -55,7 +61,11 @@ public void close() {
*/
public TSQuery(TSLanguage language, String query){
// Constructor chaining has to be the first statement, so the result of
// ts_query_new can only be validated after the delegated constructor has run.
this(ts_query_new(language.getPtr(), query));
// NOTE(review): a zero ptr is treated as "the query could not be created";
// presumably the delegated constructor stores the native handle in ptr — confirm.
if (ptr == 0) {
throw new TSQueryException("Syntax error in query: " + query);
}
this.lang = language;
// Predicates are parsed once here so that getPredicatesForPattern can answer
// from the cached lists.
this.predicates = parsePredicates();
}

protected long getPtr() {
Expand Down Expand Up @@ -209,6 +219,122 @@ public String getCaptureNameForId(int captureId) {
return ts_query_capture_name_for_id(ptr, captureId);
}

/**
 * Look up the predicates that were parsed for one pattern of this query.
 * <p>
 * The list held by this query is returned as-is (no copy is made), so callers
 * should treat it as read-only.
 *
 * @param patternIndex The index of the pattern.
 * @return The list of predicates for the pattern.
 * @throws IndexOutOfBoundsException if the pattern index is out of bounds.
 */
public List<TSQueryPredicate> getPredicatesForPattern(int patternIndex) {
    final boolean inRange = patternIndex >= 0 && patternIndex < predicates.size();
    if (inRange) {
        return predicates.get(patternIndex);
    }
    throw new IndexOutOfBoundsException("Pattern index " + patternIndex + " is out of bounds");
}

/**
 * Parse the raw predicate steps of every pattern into {@link TSQueryPredicate} objects.
 * <p>
 * The steps of a pattern are a flat array in which each predicate is a run of steps
 * terminated by a {@code Done} sentinel. The first step of a run must be a string
 * naming the predicate. Names not claimed by the eq, match or any-of predicate
 * classes are kept as {@link TSQueryPredicate.TSQueryPredicateGeneric}.
 *
 * @return one unmodifiable list per pattern, in pattern order; the list is empty for
 *         patterns without predicates.
 * @throws TSQueryException if a predicate does not begin with a string or one of the
 *         known predicates has invalid arguments.
 */
private List<List<TSQueryPredicate>> parsePredicates() {
    int patternCount = getPatternCount();
    List<List<TSQueryPredicate>> result = new ArrayList<>(patternCount);
    for (int i = 0; i < patternCount; i++) {
        TSQueryPredicateStep[] steps = getPredicateForPattern(i);
        if (steps == null || steps.length == 0) {
            // Nothing to parse: share the immutable empty list instead of allocating one.
            result.add(Collections.<TSQueryPredicate>emptyList());
            continue;
        }

        List<TSQueryPredicate> patternPredicates = new ArrayList<>();
        int stepIndex = 0;
        while (stepIndex < steps.length) {
            // Count the steps (name + arguments) up to the Done sentinel.
            int nargs = 0;
            while (stepIndex + nargs < steps.length
                    && steps[stepIndex + nargs].getType() != TSQueryPredicateStepType.TSQueryPredicateStepTypeDone) {
                nargs++;
            }

            // nargs == 0 means two sentinels in a row; there is no predicate to build.
            if (nargs > 0) {
                TSQueryPredicateStep firstStep = steps[stepIndex];
                if (firstStep.getType() != TSQueryPredicateStepType.TSQueryPredicateStepTypeString) {
                    throw new TSQueryException("Predicate must begin with a string");
                }
                String name = getStringValueForId(firstStep.getValueId());

                if (TSQueryPredicate.TSQueryPredicateEq.NAMES.contains(name)) {
                    patternPredicates.add(handleEq(name, steps, stepIndex, nargs));
                } else if (TSQueryPredicate.TSQueryPredicateMatch.NAMES.contains(name)) {
                    patternPredicates.add(handleMatch(name, steps, stepIndex, nargs));
                } else if (TSQueryPredicate.TSQueryPredicateAnyOf.NAMES.contains(name)) {
                    patternPredicates.add(handleAnyOf(name, steps, stepIndex, nargs));
                } else {
                    patternPredicates.add(new TSQueryPredicate.TSQueryPredicateGeneric(name));
                }
            }
            stepIndex += nargs + 1; // Move past arguments and the Done sentinel
        }
        // These lists are handed out by getPredicatesForPattern, so freeze them to keep
        // callers from altering the query's cached predicates.
        result.add(Collections.unmodifiableList(patternPredicates));
    }
    return result;
}

/**
 * Build an equality predicate from one run of predicate steps.
 *
 * @param name  the predicate name taken from the first step of the run.
 * @param steps all predicate steps of the pattern.
 * @param start index in {@code steps} of the name step.
 * @param nargs number of steps in the run, counting the name step.
 * @return the equality predicate comparing a capture with another capture or a string literal.
 * @throws TSQueryException if the run does not have exactly two arguments or the
 *         first argument is not a capture.
 */
private TSQueryPredicate handleEq(String name, TSQueryPredicateStep[] steps, int start, int nargs) {
    final int expectedSteps = 3; // the name plus two operands
    if (nargs != expectedSteps) {
        throw new TSQueryException(
                String.format("Predicate #%s expects 2 arguments, got %d", name, nargs - 1));
    }

    final TSQueryPredicateStep left = steps[start + 1];
    final boolean leftIsCapture =
            left.getType() == TSQueryPredicateStepType.TSQueryPredicateStepTypeCapture;
    if (!leftIsCapture) {
        throw new TSQueryException(
                String.format("First argument to #%s must be a capture", name));
    }
    final int leftCaptureId = left.getValueId();

    final TSQueryPredicateStep right = steps[start + 2];
    final int rightValueId = right.getValueId();
    final boolean rightIsCapture;
    final String literal;
    if (right.getType() == TSQueryPredicateStepType.TSQueryPredicateStepTypeCapture) {
        // Capture-to-capture comparison: there is no literal to resolve.
        rightIsCapture = true;
        literal = null;
    } else {
        rightIsCapture = false;
        literal = getStringValueForId(rightValueId);
    }

    return new TSQueryPredicate.TSQueryPredicateEq(name, leftCaptureId, literal, rightValueId, rightIsCapture);
}

/**
 * Build a pattern-matching predicate from one run of predicate steps.
 *
 * @param name  the predicate name taken from the first step of the run.
 * @param steps all predicate steps of the pattern.
 * @param start index in {@code steps} of the name step.
 * @param nargs number of steps in the run, counting the name step.
 * @return the match predicate for the capture and pattern string.
 * @throws TSQueryException if the run does not have exactly two arguments, the first
 *         argument is not a capture, or the second is not a string literal.
 */
private TSQueryPredicate handleMatch(String name, TSQueryPredicateStep[] steps, int start, int nargs) {
    final int expectedSteps = 3; // the name, a capture and a pattern string
    if (nargs != expectedSteps) {
        throw new TSQueryException(
                String.format("Predicate #%s expects 2 arguments, got %d", name, nargs - 1));
    }

    final TSQueryPredicateStep subject = steps[start + 1];
    final boolean subjectIsCapture =
            subject.getType() == TSQueryPredicateStepType.TSQueryPredicateStepTypeCapture;
    if (!subjectIsCapture) {
        throw new TSQueryException(
                String.format("First argument to #%s must be a capture", name));
    }
    final int subjectCaptureId = subject.getValueId();

    final TSQueryPredicateStep patternStep = steps[start + 2];
    final boolean patternIsString =
            patternStep.getType() == TSQueryPredicateStepType.TSQueryPredicateStepTypeString;
    if (!patternIsString) {
        throw new TSQueryException(
                String.format("Second argument to #%s must be a string literal", name));
    }
    final String patternText = getStringValueForId(patternStep.getValueId());

    return new TSQueryPredicate.TSQueryPredicateMatch(name, subjectCaptureId, patternText);
}

/**
 * Build an any-of predicate from one run of predicate steps.
 *
 * @param name  the predicate name taken from the first step of the run.
 * @param steps all predicate steps of the pattern.
 * @param start index in {@code steps} of the name step.
 * @param nargs number of steps in the run, counting the name step.
 * @return the any-of predicate for the capture and its candidate strings.
 * @throws TSQueryException if the run has fewer than two arguments, the first argument
 *         is not a capture, or any later argument is not a string literal.
 */
private TSQueryPredicate handleAnyOf(String name, TSQueryPredicateStep[] steps, int start, int nargs) {
    if (nargs < 3) {
        throw new TSQueryException(
                String.format("Predicate #%s expects at least 2 arguments, got %d", name, nargs - 1));
    }

    final TSQueryPredicateStep subject = steps[start + 1];
    final boolean subjectIsCapture =
            subject.getType() == TSQueryPredicateStepType.TSQueryPredicateStepTypeCapture;
    if (!subjectIsCapture) {
        throw new TSQueryException(
                String.format("First argument to #%s must be a capture", name));
    }
    final int subjectCaptureId = subject.getValueId();

    // Everything after the name and the capture is a candidate string.
    final List<String> candidates = new ArrayList<>(nargs - 2);
    final int end = start + nargs;
    for (int idx = start + 2; idx < end; idx++) {
        final TSQueryPredicateStep candidate = steps[idx];
        if (candidate.getType() != TSQueryPredicateStepType.TSQueryPredicateStepTypeString) {
            throw new TSQueryException(
                    String.format("Arguments to #%s must be string literals", name));
        }
        candidates.add(getStringValueForId(candidate.getValueId()));
    }

    return new TSQueryPredicate.TSQueryPredicateAnyOf(name, subjectCaptureId, candidates);
}

/**
* Get the quantifier of the query's captures. Each capture is associated
* with a numeric id based on the order that it appeared in the query's source.
Expand All @@ -235,7 +361,12 @@ public TSQuantifier getCaptureQuantifierForId(int patternId, int captureId) {
* Get TSQueryPredicateStepTypeString by id. See {@link #getPredicateForPattern(int)}
* @param id the <code>valueId</code> got from {@link #getPredicateForPattern(int)}.
* @return the literal string value.
* @throws TSException if the id is invalid.
*/
public String getStringValueForId(int id) {
ensureOpen();
Expand Down
75 changes: 66 additions & 9 deletions tree-sitter/src/main/java/org/treesitter/TSQueryCursor.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package org.treesitter;

import java.lang.ref.Cleaner.Cleanable;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

import static org.treesitter.TSParser.*;
import static org.treesitter.TSParser.ts_query_cursor_next_match;
Expand All @@ -24,6 +26,7 @@ private void ensureOpen() {

private TSNode node;
private TSQuery query;
private byte[] sourceBytes;

private static class TSQueryCursorCleanAction implements Runnable {
private final long ptr;
Expand Down Expand Up @@ -87,10 +90,27 @@ public TSQueryCursor() {
* @param node The node to run the query on.
*/
public void exec(TSQuery query, TSNode node){
    // No source text is supplied on this path; hand over to the three-argument
    // overload with a null text.
    final CharSequence noSourceText = null;
    exec(query, node, noSourceText);
}

/**
 * Start running a given query on a given node with source text for predicate filtering.
 * <p>
 * Note: The {@code sourceText} is encoded as <b>UTF-8</b> to align with Tree-sitter's
 * default byte offsets. If the tree was parsed with a different encoding (e.g. UTF-16),
 * predicate results may be incorrect.
 *
 * @param query The query to run.
 * @param node The node to run the query on.
 * @param sourceText The source text used to resolve predicates like {@code #eq?};
 *                   may be {@code null} when no source text is available.
 * @throws NullPointerException if {@code query} is {@code null}.
 */
public void exec(TSQuery query, TSNode node, CharSequence sourceText){
    ensureOpen();
    // Validate and resolve everything that can throw before touching cursor state, so a
    // rejected call cannot leave the cursor marked as executed with a stale query.
    Objects.requireNonNull(query, "query");
    final long queryPtr = query.getPtr();
    final byte[] encodedSource = sourceText == null ? null :
            sourceText.toString().getBytes(StandardCharsets.UTF_8);

    executed = true;
    this.node = node;
    this.query = query;
    this.sourceBytes = encodedSource;
    ts_query_cursor_exec(ptr, queryPtr, node);
}

Expand All @@ -103,10 +123,29 @@ public void exec(TSQuery query, TSNode node){
* @param progress The progress callback.
*/
public void execWithOptions(TSQuery query, TSNode node, TSQueryProgress progress){
    // No source text is supplied on this path; hand over to the four-argument
    // overload with a null text.
    final CharSequence noSourceText = null;
    execWithOptions(query, node, noSourceText, progress);
}

/**
 * Start running a given query on a given node, with some options and source text.
 * <p>
 * Note: The {@code sourceText} is encoded as <b>UTF-8</b> to align with Tree-sitter's
 * default byte offsets. If the tree was parsed with a different encoding (e.g. UTF-16),
 * predicate results may be incorrect.
 *
 * @see #exec(TSQuery, TSNode, CharSequence)
 * @param query The query to run.
 * @param node The node to run the query on.
 * @param sourceText The source text for predicates; may be {@code null} when no
 *                   source text is available.
 * @param progress The progress callback.
 * @throws NullPointerException if {@code query} is {@code null}.
 */
public void execWithOptions(TSQuery query, TSNode node, CharSequence sourceText, TSQueryProgress progress){
    ensureOpen();
    // Validate and resolve everything that can throw before touching cursor state, so a
    // rejected call cannot leave the cursor marked as executed with a stale query.
    Objects.requireNonNull(query, "query");
    final long queryPtr = query.getPtr();
    final byte[] encodedSource = sourceText == null ? null :
            sourceText.toString().getBytes(StandardCharsets.UTF_8);

    executed = true;
    this.node = node;
    this.query = query;
    this.sourceBytes = encodedSource;
    ts_query_cursor_exec_with_options(ptr, queryPtr, node, progress, progressPayloadPtr);
}

Expand Down Expand Up @@ -238,9 +277,13 @@ public boolean setContainingPointRange(TSPoint startPoint, TSPoint endPoint){
public boolean nextMatch(TSQueryMatch match){
    ensureOpen();
    assertExecuted();
    // Keep advancing until a match passes the predicates of its pattern; matches that
    // fail are skipped rather than reported to the caller.
    while (ts_query_cursor_next_match(ptr, match)) {
        // Runs before the predicate check, as in the original ordering.
        // NOTE(review): presumably predicates need the match's nodes tied to the tree — confirm.
        addTsTreeRef(match);
        if (satisfiesPredicates(match)) {
            return true;
        }
    }
    // The native cursor has no further matches.
    return false;
}


Expand All @@ -267,9 +310,13 @@ public void removeMatch(int matchId){
public boolean nextCapture(TSQueryMatch match){
    ensureOpen();
    assertExecuted();
    // Keep advancing until a capture's match passes the predicates of its pattern;
    // captures whose match fails are skipped rather than reported to the caller.
    while (ts_query_cursor_next_capture(ptr, match)) {
        // Runs before the predicate check, as in the original ordering.
        // NOTE(review): presumably predicates need the match's nodes tied to the tree — confirm.
        addTsTreeRef(match);
        if (satisfiesPredicates(match)) {
            return true;
        }
    }
    // The native cursor has no further captures.
    return false;
}

private void addTsTreeRef(TSQueryMatch match){
Expand All @@ -282,6 +329,16 @@ private void addTsTreeRef(TSQueryMatch match){
}
}

/**
 * Check a match against every predicate of the pattern it belongs to.
 *
 * @param match the match to check.
 * @return {@code true} if no query is set, the pattern has no predicates, or every
 *         predicate accepts the match; {@code false} as soon as one rejects it.
 */
private boolean satisfiesPredicates(TSQueryMatch match) {
    if (query != null) {
        List<TSQueryPredicate> candidates = query.getPredicatesForPattern(match.getPatternIndex());
        if (candidates != null) {
            // Stops at the first rejecting predicate; an empty list accepts the match.
            for (TSQueryPredicate candidate : candidates) {
                if (!candidate.test(match, sourceBytes)) {
                    return false;
                }
            }
        }
    }
    return true;
}

private void assertExecuted(){
if(!executed){
throw new TSException("Query not executed, call exec() first.");
Expand Down
Loading
Loading