79 changes: 76 additions & 3 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -483,15 +483,18 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
if (isOperatorEnabled(serde, op)) {
// For operators that require native children (like writes), check if all data-producing
- // children are CometNativeExec. This prevents runtime failures when the native operator
- // expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
+ // children produce Arrow data. If not, try to convert them using CometSparkToColumnarExec.
if (serde.requiresNativeChildren && op.children.nonEmpty) {
+ val convertedOp = tryConvertChildrenToArrow(op)
+ if (convertedOp.isDefined) {
+   return convertToComet(convertedOp.get, handler)
+ }
// Get the actual data-producing children (unwrap WriteFilesExec if present)
val dataProducingChildren = op.children.flatMap {
case writeFiles: WriteFilesExec => Seq(writeFiles.child)
case other => Seq(other)
}
- if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
+ if (!dataProducingChildren.forall(producesArrowData)) {
withInfo(op, "Cannot perform native operation because input is not in Arrow format")
return None
}
@@ -600,4 +603,74 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}
}

/**
* Check if a SparkPlan produces Arrow-formatted data. This handles wrapper operators like
* ReusedExchangeExec and QueryStageExec that wrap actual Comet operators.
*/
private def producesArrowData(plan: SparkPlan): Boolean = {
plan match {
case _: CometExec => true
case r: ReusedExchangeExec => producesArrowData(r.child)
case s: ShuffleQueryStageExec => producesArrowData(s.plan)
case b: BroadcastQueryStageExec => producesArrowData(b.plan)
case _ => false
}
}

/**
* Try to convert non-Arrow children to Arrow format using CometSparkToColumnarExec. This
* enables native writes even when the source is a Spark operator like RangeExec.
*
* @return
* Some(newOp) if any child was converted, None if no conversion was needed or possible
*/
private def tryConvertChildrenToArrow(op: SparkPlan): Option[SparkPlan] = {
var anyConverted = false
val fallbackReasons = new scala.collection.mutable.ListBuffer[String]()

val newChildren = op.children.map {
case writeFiles: WriteFilesExec =>
// For WriteFilesExec, we need to convert its child
val child = writeFiles.child
if (!producesArrowData(child) && canConvertToArrow(child, fallbackReasons)) {
anyConverted = true
val converted = convertToComet(child, CometSparkToColumnarExec)
converted match {
case Some(cometChild) => writeFiles.withNewChildren(Seq(cometChild))
case None => writeFiles
}
} else {
writeFiles
}
case child if !producesArrowData(child) && canConvertToArrow(child, fallbackReasons) =>
anyConverted = true
convertToComet(child, CometSparkToColumnarExec).getOrElse(child)
case other => other
}

if (anyConverted) {
Some(op.withNewChildren(newChildren))
} else {
None
}
}
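// Illustrative sketch: given a plan such as
//   WriteFilesExec
//   +- RangeExec
// this method is expected to return something like
//   Some(WriteFilesExec
//        +- CometSparkToColumnarExec
//           +- RangeExec)
// assuming RangeExec's schema is supported and "Range" is listed in
// COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST (see canConvertToArrow below).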

/**
* Check if a SparkPlan can be converted to Arrow format using CometSparkToColumnarExec.
*/
private def canConvertToArrow(
op: SparkPlan,
fallbackReasons: scala.collection.mutable.ListBuffer[String]): Boolean = {
if (!CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) {
return false
}
op match {
case _: LeafExecNode =>
val simpleClassName = Utils.getSimpleName(op.getClass)
val nodeName = simpleClassName.replaceAll("Exec$", "")
COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
case _ => false
}
}
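// For example, a RangeExec leaf has simple class name "RangeExec" and derived node name
// "Range", so it is convertible only when "Range" appears in
// COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST and its schema is supported.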

}
@@ -569,6 +569,176 @@ class CometParquetWriterSuite extends CometTestBase {
}
}

// Tests for issue #2957: INSERT OVERWRITE DIRECTORY with native writer
// Note: INSERT OVERWRITE DIRECTORY uses InsertIntoDataSourceDirCommand which internally
// executes InsertIntoHadoopFsRelationCommand. The outer plan shows ExecutedCommandExec,
// but the actual write happens in an internal execution which should use Comet's native writer.
//
// Fix: CometExecRule now auto-converts children to Arrow format when native write is requested.
// This enables native writes even when the source is a Spark operator like RangeExec.
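// Illustrative sketch of the expected write input after the rule runs (an assumption, not
// asserted by this test): the RangeExec source should end up wrapped as
//   CometSparkToColumnarExec
//   +- RangeExec
// so that the native writer receives Arrow data rather than non-Arrow data.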
test("INSERT OVERWRITE DIRECTORY using parquet - basic with RangeExec source") {
withTempPath { dir =>
val outputPath = dir.getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {
// Note: COMET_SPARK_TO_ARROW_ENABLED is NOT needed - the fix auto-converts RangeExec

// Create source table using RangeExec
spark.range(1, 10).toDF("id").createOrReplaceTempView("source_table")

// Execute INSERT OVERWRITE DIRECTORY
spark.sql(s"""
INSERT OVERWRITE DIRECTORY '$outputPath'
USING PARQUET
SELECT id FROM source_table
""")

// Verify data was written correctly
val result = spark.read.parquet(outputPath)
assert(result.count() == 9, "INSERT OVERWRITE DIRECTORY should write 9 rows")
}
}
}

// Test with Parquet source file (native scan) - this should use native writer
test("INSERT OVERWRITE DIRECTORY using parquet - with parquet source") {
withTempPath { srcDir =>
withTempPath { outDir =>
val srcPath = srcDir.getAbsolutePath
val outputPath = outDir.getAbsolutePath

// Create source parquet file
spark.range(1, 10).toDF("id").write.parquet(srcPath)

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "true") {

// Create table from parquet file
spark.read.parquet(srcPath).createOrReplaceTempView("parquet_source")

// Execute INSERT OVERWRITE DIRECTORY
spark.sql(s"""
INSERT OVERWRITE DIRECTORY '$outputPath'
USING PARQUET
SELECT id FROM parquet_source
""")

// Verify data was written correctly
val result = spark.read.parquet(outputPath)
assert(result.count() == 9, "INSERT OVERWRITE DIRECTORY should write 9 rows")
}
}
}
}

test("INSERT OVERWRITE DIRECTORY using parquet with repartition hint") {
withTempPath { dir =>
val outputPath = dir.getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Create source table
spark.range(1, 100).toDF("value").createOrReplaceTempView("df")

// Execute INSERT OVERWRITE DIRECTORY with REPARTITION hint (as in issue #2957)
val df = spark.sql(s"""
INSERT OVERWRITE DIRECTORY '$outputPath'
USING PARQUET
SELECT /*+ REPARTITION(3) */ value FROM df
""")

// Verify data was written correctly
val result = spark.read.parquet(outputPath)
assert(result.count() == 99)

// Check if native write was used
val plan = df.queryExecution.executedPlan
val hasNativeWrite = plan.collect { case _: CometNativeWriteExec => true }.nonEmpty
if (!hasNativeWrite) {
logWarning(s"Native write not used for INSERT OVERWRITE DIRECTORY:\n${plan.treeString}")
}
}
}
}

test("INSERT OVERWRITE DIRECTORY using parquet with aggregation") {
withTempPath { dir =>
val outputPath = dir.getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Create source table with some data for aggregation
spark.range(1, 100).toDF("id").createOrReplaceTempView("agg_source")

// Execute INSERT OVERWRITE DIRECTORY with aggregation
val df = spark.sql(s"""
INSERT OVERWRITE DIRECTORY '$outputPath'
USING PARQUET
SELECT id % 10 as group_id, count(*) as cnt
FROM agg_source
GROUP BY id % 10
""")

// Verify data was written correctly (10 groups: 0-9)
val result = spark.read.parquet(outputPath)
assert(result.count() == 10)

// Check if native write was used
val plan = df.queryExecution.executedPlan
val hasNativeWrite = plan.collect { case _: CometNativeWriteExec => true }.nonEmpty
if (!hasNativeWrite) {
logWarning(s"Native write not used for INSERT OVERWRITE DIRECTORY:\n${plan.treeString}")
}
}
}
}

test("INSERT OVERWRITE DIRECTORY using parquet with compression option") {
withTempPath { dir =>
val outputPath = dir.getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Create source table
spark.range(1, 50).toDF("id").createOrReplaceTempView("comp_source")

// Execute INSERT OVERWRITE DIRECTORY with compression option
val df = spark.sql(s"""
INSERT OVERWRITE DIRECTORY '$outputPath'
USING PARQUET
OPTIONS ('compression' = 'snappy')
SELECT id FROM comp_source
""")

// Verify data was written correctly
val result = spark.read.parquet(outputPath)
assert(result.count() == 49)

// Check if native write was used
val plan = df.queryExecution.executedPlan
val hasNativeWrite = plan.collect { case _: CometNativeWriteExec => true }.nonEmpty
if (!hasNativeWrite) {
logWarning(s"Native write not used for INSERT OVERWRITE DIRECTORY:\n${plan.treeString}")
}
}
}
}

private def readSparkRows(path: String): Array[Row] = {
var rows: Array[Row] = null
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {