diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java index 013785ecc6c..3005259e4cb 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java @@ -246,7 +246,7 @@ public long pushData(boolean growThreshold) throws IOException { if (offSet + recordSize > dataBuf.length) { try { - dataPusher.addTask(partition, dataBuf, offSet); + dataBuf = dataPusher.swapBufferWithIdleTask(partition, dataBuf, offSet); memoryThresholdManager.updateStats(offSet, true); } catch (InterruptedException e) { TaskInterruptedHelper.throwTaskKillException(); @@ -261,7 +261,7 @@ public long pushData(boolean growThreshold) throws IOException { } if (offSet > 0) { try { - dataPusher.addTask(currentPartition, dataBuf, offSet); + dataPusher.swapBufferWithIdleTask(currentPartition, dataBuf, offSet); memoryThresholdManager.updateStats(offSet, offSet == pushBufferMaxSize); } catch (InterruptedException e) { TaskInterruptedHelper.throwTaskKillException(); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index d5a0fdf2233..d868c10bb94 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -310,19 +310,20 @@ private int getOrUpdateOffset(int partitionId, int serializedRecordSize) } if ((buffer.length - offset) < serializedRecordSize) { - flushSendBuffer(partitionId, buffer, offset); + sendBuffers[partitionId] = swapAndFlushSendBuffer(partitionId, buffer, offset); updateRecordsWrittenMetrics(); offset = 0; } return offset; } - private void flushSendBuffer(int partitionId, byte[] buffer, int size) + private byte[] swapAndFlushSendBuffer(int partitionId, byte[] buffer, int size) throws IOException, InterruptedException { long start = System.nanoTime(); logger.debug("Flush buffer, size {}.", Utils.bytesToString(size)); - dataPusher.addTask(partitionId, buffer, size); + byte[] newBuffer = dataPusher.swapBufferWithIdleTask(partitionId, buffer, size); writeMetrics.incWriteTime(System.nanoTime() - start); + return newBuffer; } protected void closeWrite() throws IOException { diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java index 73af67b0733..1fd314c95f0 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java @@ -151,8 +151,9 @@ public void run() { pushThread.start(); } - public void addTask(int partitionId, byte[] buffer, int size) + public byte[] swapBufferWithIdleTask(int partitionId, byte[] buffer, int size) throws IOException, InterruptedException { + byte[] returnBuffer = null; try { PushTask task = null; while (task == null) { @@ -161,7 +162,9 @@ public void addTask(int partitionId, byte[] buffer, int size) } task.setSize(size); task.setPartitionId(partitionId); - System.arraycopy(buffer, 0, task.getBuffer(), 0, size); + // swap buffer + returnBuffer = task.getBuffer(); + task.setBuffer(buffer); while (!dataPushQueue.addPushTask(task)) { checkException(); } @@ -170,6 +173,7 @@ public void addTask(int partitionId, byte[] buffer, int size) pushThread.interrupt(); throw e; } + return returnBuffer; } public void waitOnTermination() throws IOException, InterruptedException { diff --git a/client/src/main/java/org/apache/celeborn/client/write/PushTask.java b/client/src/main/java/org/apache/celeborn/client/write/PushTask.java index f4dbd61d4df..dced9158293 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/PushTask.java +++ b/client/src/main/java/org/apache/celeborn/client/write/PushTask.java @@ -21,7 +21,7 @@ public class PushTask { private int partitionId; private int size; - private byte[] buffer; + private volatile byte[] buffer; public PushTask(int bufferSize) { this.buffer = new byte[bufferSize]; @@ -46,6 +46,10 @@ public void setSize(int size) { this.size = size; } + public void setBuffer(byte[] buffer) { + this.buffer = buffer; + } + public byte[] getBuffer() { return buffer; } diff --git a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java index 841e04eee07..7954012395d 100644 --- a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java @@ -126,7 +126,7 @@ protected void pushData(PushTask task) throws IOException { int batchId = pushState.nextBatchId(); pushState.addBatch(batchId, b.length, reducePartitionMap.get(i).hostAndPushPort()); partitionBatchIdMap.put(i, batchId); - dataPusher.addTask(i, b, b.length); + dataPusher.swapBufferWithIdleTask(i, b, b.length); } dataPusher.waitOnTermination(); @@ -169,7 +169,7 @@ protected void pushData(PushTask task) throws IOException { throw new OutOfMemoryError(); } }; - dataPusher.addTask(0, new byte[10], 0); + dataPusher.swapBufferWithIdleTask(0, new byte[10], 0); try { dataPusher.waitOnTermination(); } catch (Throwable e) {