Skip to content

Commit f02d437

Browse files
committed
fix(#384): Refactor BatchStatementIT and add LWTLoadBalancingMultiDcIT for local DC routing
1 parent 2300517 commit f02d437

File tree

5 files changed

+306
-48
lines changed

5 files changed

+306
-48
lines changed

core/src/main/java/com/datastax/oss/driver/api/core/session/Request.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,14 @@ default Partitioner getPartitioner() {
203203
Node getNode();
204204

205205
/**
206-
* Returns the routing type configured on this statement.
206+
* Returns the routing type for this request.
207207
*
208-
* @return The routing method configured on this statement, or {@link RequestRoutingType#REGULAR}
209-
* if none is configured.
208+
* <p>The value represents how the request is handled on the server side (for example, regular vs
209+
* lightweight transaction). Load balancing policies use this signal to shape the execution plan
210+
* (eligible coordinators and ordering).
211+
*
212+
* @return The routing type configured on this request, or {@link RequestRoutingType#REGULAR} if
213+
* none is configured.
210214
*/
211215
@Nullable
212216
RequestRoutingType getRequestType();

core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/DefaultLoadBalancingPolicy.java

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import java.util.concurrent.ConcurrentMap;
5252
import java.util.concurrent.ThreadLocalRandom;
5353
import java.util.concurrent.atomic.AtomicLongArray;
54+
import java.util.stream.Collectors;
5455
import net.jcip.annotations.ThreadSafe;
5556
import org.apache.commons.lang3.tuple.Pair;
5657
import org.slf4j.Logger;
@@ -139,7 +140,10 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
139140
Objects.nonNull(request) ? request.getRequestType() : RequestRoutingType.REGULAR;
140141
boolean isLWT = requestType == RequestRoutingType.LWT;
141142
Object[] currentNodes =
142-
isLWT ? replicas.toArray() : getLiveNodes().dc(getLocalDatacenter()).toArray();
143+
isLWT
144+
? getReplicasFromLocalDcForLwt(replicas)
145+
: getLiveNodes().dc(getLocalDatacenter()).toArray();
146+
143147
if (Objects.nonNull(request)
144148
&& request.getRoutingMethod() == RequestRoutingMethod.PRESERVE_REPLICA_ORDER) {
145149
return new SimpleQueryPlan(currentNodes);
@@ -174,8 +178,29 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
174178
return maybeAddDcFailover(request, plan);
175179
}
176180

181+
/** For LWT requests, prefer replicas in the local DC to avoid cross-DC coordination */
182+
private Object[] getReplicasFromLocalDcForLwt(List<Node> replicas) {
183+
// For LWT requests, start from replicas; if a local DC is configured, prefer replicas
184+
// in the local DC to avoid cross-DC coordination. Preserve original replica order.
185+
String localDc = getLocalDatacenter();
186+
if (localDc != null) {
187+
List<Node> filtered =
188+
replicas.stream()
189+
.filter(n -> Objects.equals(n.getDatacenter(), localDc))
190+
.collect(Collectors.toList());
191+
// Fallback to all replicas if none are in the local DC
192+
if (!filtered.isEmpty()) {
193+
return filtered.toArray();
194+
}
195+
}
196+
return replicas.toArray();
197+
}
198+
177199
private Pair<Integer, Integer> moveReplicasToFront(
178200
RequestRoutingType routingType, Object[] currentNodes, List<Node> allReplicas) {
201+
// Note: local rack prioritization is intentionally ignored for LWT requests to prevent
202+
// congestion when different loaders from different racks target distinct rack-local LWT
203+
// leaders.
179204
int replicaCount = 0, localRackReplicaCount = 0;
180205
for (int i = 0; i < currentNodes.length; i++) {
181206
Node node = (Node) currentNodes[i];
@@ -199,6 +224,8 @@ private void shuffleLocalRackReplicasAndReplicas(
199224
Object[] currentNodes,
200225
int replicaCount,
201226
int localRackReplicaCount) {
227+
// For LWT, ignore local rack prioritization to avoid rack-local leader congestion; treat
228+
// all local-DC replicas uniformly.
202229
if (routingType != RequestRoutingType.LWT
203230
&& getLocalRack() != null
204231
&& localRackReplicaCount > 0) {

integration-tests/src/test/java/com/datastax/oss/driver/core/cql/BatchStatementIT.java

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import ch.qos.logback.classic.spi.ILoggingEvent;
3232
import com.datastax.oss.driver.api.core.ConsistencyLevel;
3333
import com.datastax.oss.driver.api.core.CqlSession;
34+
import com.datastax.oss.driver.api.core.RequestRoutingType;
3435
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
3536
import com.datastax.oss.driver.api.core.config.DriverConfigLoader;
3637
import com.datastax.oss.driver.api.core.cql.BatchStatement;
@@ -66,11 +67,11 @@
6667
@Category(ParallelizableTests.class)
6768
public class BatchStatementIT {
6869

69-
private CcmRule ccmRule = CcmRule.getInstance();
70+
private final CcmRule CCM_RULE = CcmRule.getInstance();
7071

71-
private SessionRule<CqlSession> sessionRule = SessionRule.builder(ccmRule).build();
72+
private final SessionRule<CqlSession> SESSION_RULE = SessionRule.builder(CCM_RULE).build();
7273

73-
@Rule public TestRule chain = RuleChain.outerRule(ccmRule).around(sessionRule);
74+
@Rule public TestRule chain = RuleChain.outerRule(CCM_RULE).around(SESSION_RULE);
7475

7576
@Rule public TestName name = new TestName();
7677

@@ -89,11 +90,11 @@ public void createTable() {
8990
SchemaChangeSynchronizer.withLock(
9091
() -> {
9192
for (String schemaStatement : schemaStatements) {
92-
sessionRule
93+
SESSION_RULE
9394
.session()
9495
.execute(
9596
SimpleStatement.newInstance(schemaStatement)
96-
.setExecutionProfile(sessionRule.slowProfile()));
97+
.setExecutionProfile(SESSION_RULE.slowProfile()));
9798
}
9899
});
99100
}
@@ -103,7 +104,7 @@ public void should_issue_log_warn_if_batched_statement_have_consistency_level_se
103104
SimpleStatement simpleStatement =
104105
SimpleStatement.builder("INSERT INTO test (k0, k1, v) values ('123123', ?, ?)").build();
105106

106-
try (CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace())) {
107+
try (CqlSession session = SessionUtils.newSession(CCM_RULE, SESSION_RULE.keyspace())) {
107108
PreparedStatement prep = session.prepare(simpleStatement);
108109
BatchStatementBuilder batch = BatchStatement.builder(DefaultBatchType.UNLOGGED);
109110
batch.addStatement(prep.bind(1, 2).setConsistencyLevel(ConsistencyLevel.QUORUM));
@@ -139,7 +140,7 @@ public void should_execute_batch_of_simple_statements_with_variables() {
139140
}
140141

141142
BatchStatement batchStatement = builder.build();
142-
sessionRule.session().execute(batchStatement);
143+
SESSION_RULE.session().execute(batchStatement);
143144

144145
verifyBatchInsert();
145146
}
@@ -154,14 +155,14 @@ public void should_execute_batch_of_bound_statements_with_variables() {
154155
String.format(
155156
"INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName()))
156157
.build();
157-
PreparedStatement preparedStatement = sessionRule.session().prepare(insert);
158+
PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert);
158159

159160
for (int i = 0; i < batchCount; i++) {
160161
builder.addStatement(preparedStatement.bind(i, i + 1));
161162
}
162163

163164
BatchStatement batchStatement = builder.build();
164-
sessionRule.session().execute(batchStatement);
165+
SESSION_RULE.session().execute(batchStatement);
165166

166167
verifyBatchInsert();
167168
}
@@ -178,14 +179,14 @@ public void should_execute_batch_of_bound_statements_with_unset_values() {
178179
String.format(
179180
"INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName()))
180181
.build();
181-
PreparedStatement preparedStatement = sessionRule.session().prepare(insert);
182+
PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert);
182183

183184
for (int i = 0; i < batchCount; i++) {
184185
builder.addStatement(preparedStatement.bind(i, i + 1));
185186
}
186187

187188
BatchStatement batchStatement = builder.build();
188-
sessionRule.session().execute(batchStatement);
189+
SESSION_RULE.session().execute(batchStatement);
189190

190191
verifyBatchInsert();
191192

@@ -196,17 +197,17 @@ public void should_execute_batch_of_bound_statements_with_unset_values() {
196197
if (i % 20 == 0) {
197198
boundStatement = boundStatement.unset(1);
198199
}
199-
builder.addStatement(boundStatement);
200+
builder2.addStatement(boundStatement);
200201
}
201202

202-
sessionRule.session().execute(builder2.build());
203+
SESSION_RULE.session().execute(builder2.build());
203204

204205
Statement<?> select =
205206
SimpleStatement.builder("SELECT * from test where k0 = ?")
206207
.addPositionalValue(name.getMethodName())
207208
.build();
208209

209-
ResultSet result = sessionRule.session().execute(select);
210+
ResultSet result = SESSION_RULE.session().execute(select);
210211

211212
List<Row> rows = result.all();
212213
assertThat(rows).hasSize(100);
@@ -230,7 +231,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() {
230231
// variable values.
231232
BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED);
232233
PreparedStatement preparedStatement =
233-
sessionRule.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)");
234+
SESSION_RULE.session().prepare("INSERT INTO test (k0, k1, v) values (:k0, :k1, :v)");
234235

235236
for (int i = 0; i < batchCount; i++) {
236237
builder.addStatement(
@@ -243,7 +244,7 @@ public void should_execute_batch_of_bound_statements_with_named_variables() {
243244
}
244245

245246
BatchStatement batchStatement = builder.build();
246-
sessionRule.session().execute(batchStatement);
247+
SESSION_RULE.session().execute(batchStatement);
247248

248249
verifyBatchInsert();
249250
}
@@ -257,7 +258,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables()
257258
String.format(
258259
"INSERT INTO test (k0, k1, v) values ('%s', ? , ?)", name.getMethodName()))
259260
.build();
260-
PreparedStatement preparedStatement = sessionRule.session().prepare(insert);
261+
PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert);
261262

262263
for (int i = 0; i < batchCount; i++) {
263264
if (i % 2 == 1) {
@@ -274,7 +275,7 @@ public void should_execute_batch_of_bound_and_simple_statements_with_variables()
274275
}
275276

276277
BatchStatement batchStatement = builder.build();
277-
sessionRule.session().execute(batchStatement);
278+
SESSION_RULE.session().execute(batchStatement);
278279

279280
verifyBatchInsert();
280281
}
@@ -284,25 +285,45 @@ public void should_execute_cas_batch() {
284285
// Build a batch with CAS operations on the same partition.
285286
BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.UNLOGGED);
286287
SimpleStatement insert =
287-
SimpleStatement.builder(
288-
String.format(
289-
"INSERT INTO test (k0, k1, v) values ('%s', ? , ?) IF NOT EXISTS",
290-
name.getMethodName()))
288+
SimpleStatement.builder("INSERT INTO test (k0, k1, v) values (?, ?, ?) IF NOT EXISTS")
291289
.build();
292-
PreparedStatement preparedStatement = sessionRule.session().prepare(insert);
290+
PreparedStatement preparedStatement = SESSION_RULE.session().prepare(insert);
293291

294292
for (int i = 0; i < batchCount; i++) {
295-
builder.addStatement(preparedStatement.bind(i, i + 1));
293+
builder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1));
296294
}
297295

296+
// Ensure LWT routing has a concrete routing key to compute replicas.
297+
BoundStatement routingKeyStmt = preparedStatement.bind(name.getMethodName(), 0, 1);
298+
builder.setRoutingKey(routingKeyStmt.getRoutingKey());
299+
builder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL);
300+
298301
BatchStatement batchStatement = builder.build();
299-
ResultSet result = sessionRule.session().execute(batchStatement);
302+
// Validate serial consistency and LWT routing on the batch itself.
303+
assertThat(batchStatement.getSerialConsistencyLevel()).isEqualTo(ConsistencyLevel.SERIAL);
304+
assertThat(batchStatement.getRequestType()).isEqualTo(RequestRoutingType.LWT);
305+
assertThat(batchStatement.getRoutingKey()).isNotNull();
306+
307+
ResultSet result = SESSION_RULE.session().execute(batchStatement);
308+
// Validate that executed request preserved serial consistency level.
309+
assertThat(result.getExecutionInfo().getRequest()).isInstanceOf(Statement.class);
310+
assertThat(((Statement<?>) result.getExecutionInfo().getRequest()).getSerialConsistencyLevel())
311+
.isEqualTo(ConsistencyLevel.SERIAL);
300312
assertThat(result.wasApplied()).isTrue();
301313

302314
verifyBatchInsert();
303315

304-
// re execute same batch and ensure wasn't applied.
305-
result = sessionRule.session().execute(batchStatement);
316+
// Rebuild an equivalent batch and ensure it isn't applied.
317+
BatchStatementBuilder rerunBuilder = BatchStatement.builder(DefaultBatchType.UNLOGGED);
318+
rerunBuilder.setSerialConsistencyLevel(ConsistencyLevel.SERIAL);
319+
for (int i = 0; i < batchCount; i++) {
320+
rerunBuilder.addStatement(preparedStatement.bind(name.getMethodName(), i, i + 1));
321+
}
322+
// Use the same routing key to target the same partition for LWT.
323+
rerunBuilder.setRoutingKey(routingKeyStmt.getRoutingKey());
324+
BatchStatement rerunBatch = rerunBuilder.build();
325+
assertThat(rerunBatch.getRequestType()).isEqualTo(RequestRoutingType.LWT);
326+
result = SESSION_RULE.session().execute(rerunBatch);
306327
assertThat(result.wasApplied()).isFalse();
307328
}
308329

@@ -322,11 +343,11 @@ public void should_execute_counter_batch() {
322343
}
323344

324345
BatchStatement batchStatement = builder.build();
325-
sessionRule.session().execute(batchStatement);
346+
SESSION_RULE.session().execute(batchStatement);
326347

327348
for (int i = 1; i <= 3; i++) {
328349
ResultSet result =
329-
sessionRule
350+
SESSION_RULE
330351
.session()
331352
.execute(
332353
String.format(
@@ -356,7 +377,7 @@ public void should_fail_logged_batch_with_counter_increment() {
356377
}
357378

358379
BatchStatement batchStatement = builder.build();
359-
sessionRule.session().execute(batchStatement);
380+
SESSION_RULE.session().execute(batchStatement);
360381
}
361382

362383
@Test(expected = InvalidQueryException.class)
@@ -383,7 +404,7 @@ public void should_fail_counter_batch_with_non_counter_increment() {
383404
builder.addStatement(simpleInsert);
384405

385406
BatchStatement batchStatement = builder.build();
386-
sessionRule.session().execute(batchStatement);
407+
SESSION_RULE.session().execute(batchStatement);
387408
}
388409

389410
@Test
@@ -394,13 +415,13 @@ public void should_not_allow_unset_value_when_protocol_less_than_v4() {
394415
SessionUtils.configLoaderBuilder()
395416
.withString(DefaultDriverOption.PROTOCOL_VERSION, "V3")
396417
.build();
397-
try (CqlSession v3Session = SessionUtils.newSession(ccmRule, loader)) {
418+
try (CqlSession v3Session = SessionUtils.newSession(CCM_RULE, loader)) {
398419
// Intentionally use fully qualified table here to avoid warnings as these are not supported
399420
// by v3 protocol version, see JAVA-3068
400421
PreparedStatement prepared =
401422
v3Session.prepare(
402423
String.format(
403-
"INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", sessionRule.keyspace()));
424+
"INSERT INTO %s.test (k0, k1, v) values (?, ?, ?)", SESSION_RULE.keyspace()));
404425

405426
BatchStatementBuilder builder = BatchStatement.builder(DefaultBatchType.LOGGED);
406427
builder.addStatements(
@@ -427,7 +448,7 @@ private void verifyBatchInsert() {
427448
.addPositionalValue(name.getMethodName())
428449
.build();
429450

430-
ResultSet result = sessionRule.session().execute(select);
451+
ResultSet result = SESSION_RULE.session().execute(select);
431452

432453
List<Row> rows = result.all();
433454
assertThat(rows).hasSize(100);

0 commit comments

Comments
 (0)