Skip to content

Commit faae1f9

Browse files
committed
fixup! [substrait] Add support for ExtensionTable
1 parent 20ba857 commit faae1f9

File tree

4 files changed

+114
-84
lines changed

4 files changed

+114
-84
lines changed

datafusion/expr/src/registry.rs

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ pub trait SerializerRegistry: Debug + Send + Sync {
132132
fn serialize_logical_plan(
133133
&self,
134134
node: &dyn UserDefinedLogicalNode,
135-
) -> Result<Vec<u8>> {
135+
) -> Result<NamedBytes> {
136136
not_impl_err!(
137137
"Serializing user defined logical plan node `{}` is not supported",
138138
node.name()
@@ -151,17 +151,16 @@ pub trait SerializerRegistry: Debug + Send + Sync {
151151
)
152152
}
153153

154-
/// Serialized table definition for UDTFs or manually registered table providers that can't be
155-
/// marshaled by reference. Should return some benign error for regular tables that can be
156-
/// found/restored by name in the destination execution context.
157-
fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result<Vec<u8>> {
158-
not_impl_err!("No custom table support")
154+
/// Serialized table definition for UDTFs or some other table provider implementation that
155+
/// can't be marshaled by reference.
156+
fn serialize_custom_table(
157+
&self,
158+
_table: &dyn TableSource,
159+
) -> Result<Option<NamedBytes>> {
160+
Ok(None)
159161
}
160162

161-
/// Deserialize the custom table with the given name.
162-
/// Note: more often than not, the name can't be used as a discriminator if multiple different
163-
/// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true
164-
/// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`).
163+
/// Deserialize a custom table.
165164
fn deserialize_custom_table(
166165
&self,
167166
name: &str,
@@ -171,6 +170,11 @@ pub trait SerializerRegistry: Debug + Send + Sync {
171170
}
172171
}
173172

173+
/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions
174+
/// that need to carry their type, e.g. the `type_url` for protobuf messages.
175+
#[derive(Debug, Clone)]
176+
pub struct NamedBytes(pub String, pub Vec<u8>);
177+
174178
/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
175179
#[derive(Default, Debug)]
176180
pub struct MemoryFunctionRegistry {

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 37 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
3030

3131
use datafusion::logical_expr::{
3232
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
33-
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableScan, TryCast, Values,
33+
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values,
3434
};
3535
use substrait::proto::aggregate_rel::Grouping;
3636
use substrait::proto::expression as substrait_expression;
@@ -462,9 +462,7 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
462462
fn consume_extension_table(
463463
&self,
464464
extension_table: &ExtensionTable,
465-
_schema: &DFSchema,
466-
_projection: &Option<MaskExpression>,
467-
) -> Result<LogicalPlan> {
465+
) -> Result<Arc<dyn TableSource>> {
468466
if let Some(ext_detail) = extension_table.detail.as_ref() {
469467
substrait_err!(
470468
"Missing handler for extension table: {}",
@@ -599,24 +597,11 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
599597
fn consume_extension_table(
600598
&self,
601599
extension_table: &ExtensionTable,
602-
schema: &DFSchema,
603-
projection: &Option<MaskExpression>,
604-
) -> Result<LogicalPlan> {
600+
) -> Result<Arc<dyn TableSource>> {
605601
if let Some(ext_detail) = &extension_table.detail {
606-
let source = self
607-
.state
602+
self.state
608603
.serializer_registry()
609-
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)?;
610-
let table_name = ext_detail
611-
.type_url
612-
.rsplit_once('/')
613-
.map(|(_, name)| name)
614-
.unwrap_or(&ext_detail.type_url);
615-
let table_scan = TableScan::try_new(table_name, source, None, vec![], None)?;
616-
let plan = LogicalPlan::TableScan(table_scan);
617-
ensure_schema_compatibility(plan.schema(), schema.clone())?;
618-
let schema = apply_masking(schema.clone(), projection)?;
619-
apply_projection(plan, schema)
604+
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)
620605
} else {
621606
substrait_err!("Unexpected empty detail in ExtensionTable")
622607
}
@@ -1366,26 +1351,14 @@ pub async fn from_read_rel(
13661351
read: &ReadRel,
13671352
) -> Result<LogicalPlan> {
13681353
async fn read_with_schema(
1369-
consumer: &impl SubstraitConsumer,
13701354
table_ref: TableReference,
1355+
table_source: Arc<dyn TableSource>,
13711356
schema: DFSchema,
13721357
projection: &Option<MaskExpression>,
13731358
) -> Result<LogicalPlan> {
13741359
let schema = schema.replace_qualifier(table_ref.clone());
13751360

1376-
let plan = {
1377-
let provider = match consumer.resolve_table_ref(&table_ref).await? {
1378-
Some(ref provider) => Arc::clone(provider),
1379-
_ => return plan_err!("No table named '{table_ref}'"),
1380-
};
1381-
1382-
LogicalPlanBuilder::scan(
1383-
table_ref,
1384-
provider_as_source(Arc::clone(&provider)),
1385-
None,
1386-
)?
1387-
.build()?
1388-
};
1361+
let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? };
13891362

13901363
ensure_schema_compatibility(plan.schema(), schema.clone())?;
13911364

@@ -1394,6 +1367,17 @@ pub async fn from_read_rel(
13941367
apply_projection(plan, schema)
13951368
}
13961369

1370+
async fn table_source(
1371+
consumer: &impl SubstraitConsumer,
1372+
table_ref: &TableReference,
1373+
) -> Result<Arc<dyn TableSource>> {
1374+
if let Some(provider) = consumer.resolve_table_ref(table_ref).await? {
1375+
Ok(provider_as_source(provider))
1376+
} else {
1377+
plan_err!("No table named '{table_ref}'")
1378+
}
1379+
}
1380+
13971381
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
13981382
substrait_datafusion_err!("No base schema provided for Read Relation")
13991383
})?;
@@ -1419,10 +1403,10 @@ pub async fn from_read_rel(
14191403
table: nt.names[2].clone().into(),
14201404
},
14211405
};
1422-
1406+
let table_source = table_source(consumer, &table_reference).await?;
14231407
read_with_schema(
1424-
consumer,
14251408
table_reference,
1409+
table_source,
14261410
substrait_schema,
14271411
&read.projection,
14281412
)
@@ -1501,17 +1485,31 @@ pub async fn from_read_rel(
15011485
let name = filename.unwrap();
15021486
// directly use unwrap here since we could determine it is a valid one
15031487
let table_reference = TableReference::Bare { table: name.into() };
1488+
let table_source = table_source(consumer, &table_reference).await?;
15041489

15051490
read_with_schema(
1506-
consumer,
15071491
table_reference,
1492+
table_source,
15081493
substrait_schema,
15091494
&read.projection,
15101495
)
15111496
.await
15121497
}
15131498
Some(ReadType::ExtensionTable(ext)) => {
1514-
consumer.consume_extension_table(ext, &substrait_schema, &read.projection)
1499+
let name_hint = read
1500+
.common
1501+
.as_ref()
1502+
.and_then(|rel_common| rel_common.hint.as_ref())
1503+
.map(|hint| hint.alias.as_str())
1504+
.filter(|alias| !alias.is_empty());
1505+
let table_name = name_hint.unwrap_or("tmp_table");
1506+
read_with_schema(
1507+
TableReference::from(table_name),
1508+
consumer.consume_extension_table(ext)?,
1509+
substrait_schema,
1510+
&read.projection,
1511+
)
1512+
.await
15151513
}
15161514
None => {
15171515
substrait_err!("Unexpected empty read_type")
@@ -1917,7 +1915,7 @@ pub async fn from_substrait_sorts(
19171915
},
19181916
None => not_impl_err!("Sort without sort kind is invalid"),
19191917
};
1920-
let (asc, nulls_first) = asc_nullfirst.unwrap();
1918+
let (asc, nulls_first) = asc_nullfirst?;
19211919
sorts.push(Sort {
19221920
expr,
19231921
asc,

datafusion/substrait/src/logical_plan/producer.rs

Lines changed: 60 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,11 @@ use std::sync::Arc;
2222
use substrait::proto::expression_reference::ExprType;
2323

2424
use datafusion::arrow::datatypes::{Field, IntervalUnit};
25-
use datafusion::logical_expr::{Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit, Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan, TableSource, TryCast, Union, Values, Window, WindowFrameUnits};
25+
use datafusion::logical_expr::{
26+
Aggregate, Distinct, EmptyRelation, Extension, Filter, Join, Like, Limit,
27+
Partitioning, Projection, Repartition, Sort, SortExpr, SubqueryAlias, TableScan,
28+
TableSource, TryCast, Union, Values, Window, WindowFrameUnits,
29+
};
2630
use datafusion::{
2731
arrow::datatypes::{DataType, TimeUnit},
2832
error::{DataFusionError, Result},
@@ -50,9 +54,10 @@ use datafusion::execution::SessionState;
5054
use datafusion::logical_expr::expr::{
5155
Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction,
5256
};
57+
use datafusion::logical_expr::registry::NamedBytes;
5358
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
5459
use datafusion::prelude::Expr;
55-
use pbjson_types::{Any as ProtoAny, Any};
60+
use pbjson_types::Any as ProtoAny;
5661
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
5762
use substrait::proto::expression::cast::FailureBehavior;
5863
use substrait::proto::expression::field_reference::{RootReference, RootType};
@@ -66,8 +71,8 @@ use substrait::proto::expression::subquery::InPredicate;
6671
use substrait::proto::expression::window_function::BoundsType;
6772
use substrait::proto::expression::ScalarFunction;
6873
use substrait::proto::read_rel::{ExtensionTable, VirtualTable};
69-
use substrait::proto::rel_common::EmitKind;
7074
use substrait::proto::rel_common::EmitKind::Emit;
75+
use substrait::proto::rel_common::{EmitKind, Hint};
7176
use substrait::proto::{
7277
fetch_rel, rel_common, ExchangeRel, ExpressionReference, ExtendedExpression,
7378
RelCommon,
@@ -363,10 +368,10 @@ pub trait SubstraitProducer: Send + Sync + Sized {
363368
from_in_subquery(self, in_subquery, schema)
364369
}
365370

366-
fn handle_extension_table(
371+
fn handle_custom_table(
367372
&mut self,
368373
_table: &dyn TableSource,
369-
) -> Result<ExtensionTable> {
374+
) -> Result<Option<ExtensionTable>> {
370375
not_impl_err!("Not implemented")
371376
}
372377
}
@@ -395,12 +400,12 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
395400
}
396401

397402
fn handle_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> {
398-
let extension_bytes = self
403+
let NamedBytes(qualifier, bytes) = self
399404
.serializer_registry
400405
.serialize_logical_plan(plan.node.as_ref())?;
401406
let detail = ProtoAny {
402-
type_url: plan.node.name().to_string(),
403-
value: extension_bytes.into(),
407+
type_url: qualifier.to_string(),
408+
value: bytes.to_owned().into(),
404409
};
405410
let mut inputs_rel = plan
406411
.node
@@ -429,14 +434,22 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
429434
}))
430435
}
431436

432-
fn handle_extension_table(&mut self, table: &dyn TableSource) -> Result<ExtensionTable> {
433-
let bytes = self.serializer_registry.serialize_custom_table(table)?;
434-
Ok(ExtensionTable {
435-
detail: Some(Any {
436-
type_url: "/substrait.ExtensionTable".into(),
437-
value: bytes.into(),
438-
})
439-
})
437+
fn handle_custom_table(
438+
&mut self,
439+
table: &dyn TableSource,
440+
) -> Result<Option<ExtensionTable>> {
441+
if let Some(NamedBytes(qualifier, bytes)) =
442+
self.serializer_registry.serialize_custom_table(table)?
443+
{
444+
Ok(Some(ExtensionTable {
445+
detail: Some(ProtoAny {
446+
type_url: qualifier.to_string(),
447+
value: bytes.to_owned().into(),
448+
}),
449+
}))
450+
} else {
451+
Ok(None)
452+
}
440453
}
441454
}
442455

@@ -572,21 +585,31 @@ pub fn from_table_scan(
572585
let table_schema = scan.source.schema().to_dfschema_ref()?;
573586
let base_schema = to_substrait_named_struct(&table_schema)?;
574587

575-
let table = if let Ok(ext_table) = producer
576-
.handle_extension_table(scan.source.as_ref())
577-
{
578-
ReadType::ExtensionTable(ext_table)
579-
} else {
580-
ReadType::NamedTable(NamedTable {
581-
names: scan.table_name.to_vec(),
582-
advanced_extension: None,
583-
})
584-
};
585-
588+
let (table, common) =
589+
if let Ok(Some(ext_table)) = producer.handle_custom_table(scan.source.as_ref()) {
590+
(
591+
ReadType::ExtensionTable(ext_table),
592+
Some(RelCommon {
593+
hint: Some(Hint {
594+
alias: scan.table_name.to_string(),
595+
..Default::default()
596+
}),
597+
..Default::default()
598+
}),
599+
)
600+
} else {
601+
(
602+
ReadType::NamedTable(NamedTable {
603+
names: scan.table_name.to_vec(),
604+
advanced_extension: None,
605+
}),
606+
None,
607+
)
608+
};
586609

587610
Ok(Box::new(Rel {
588611
rel_type: Some(RelType::Read(Box::new(ReadRel {
589-
common: None,
612+
common,
590613
base_schema: Some(base_schema),
591614
filter: None,
592615
best_effort_filter: None,
@@ -1715,7 +1738,7 @@ pub fn from_in_subquery(
17151738
subquery_type: Some(
17161739
substrait::proto::expression::subquery::SubqueryType::InPredicate(
17171740
Box::new(InPredicate {
1718-
needles: (vec![substrait_expr]),
1741+
needles: vec![substrait_expr],
17191742
haystack: Some(subquery_plan),
17201743
}),
17211744
),
@@ -2909,6 +2932,7 @@ mod test {
29092932
#[tokio::test]
29102933
async fn round_trip_extension_table() {
29112934
const TABLE_NAME: &str = "custom_table";
2935+
const TYPE_URL: &str = "/substrait.test.CustomTable";
29122936
const SERIALIZED: &[u8] = "table definition".as_bytes();
29132937

29142938
fn custom_table() -> Arc<dyn TableProvider> {
@@ -2921,9 +2945,12 @@ mod test {
29212945
#[derive(Debug)]
29222946
struct Registry;
29232947
impl SerializerRegistry for Registry {
2924-
fn serialize_custom_table(&self, table: &dyn TableSource) -> Result<Vec<u8>> {
2948+
fn serialize_custom_table(
2949+
&self,
2950+
table: &dyn TableSource,
2951+
) -> Result<Option<NamedBytes>> {
29252952
if table.schema() == custom_table().schema() {
2926-
Ok(SERIALIZED.to_vec())
2953+
Ok(Some(NamedBytes(TYPE_URL.to_string(), SERIALIZED.to_vec())))
29272954
} else {
29282955
Err(DataFusionError::Internal("Not our table".into()))
29292956
}
@@ -2933,7 +2960,7 @@ mod test {
29332960
name: &str,
29342961
bytes: &[u8],
29352962
) -> Result<Arc<dyn TableSource>> {
2936-
if name == TABLE_NAME && bytes == SERIALIZED {
2963+
if name == TYPE_URL && bytes == SERIALIZED {
29372964
Ok(Arc::new(DefaultTableSource::new(custom_table())))
29382965
} else {
29392966
panic!("Unexpected extension table: {name}");
@@ -2965,7 +2992,7 @@ mod test {
29652992
assert_contains!(
29662993
// confirm that the Substrait plan contains our custom_table as an ExtensionTable
29672994
serde_json::to_string(substrait.as_ref()).unwrap(),
2968-
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#)
2995+
format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TYPE_URL}","#)
29692996
);
29702997
remote // make sure the restored plan is fully working in the remote context
29712998
.execute_logical_plan(restored.clone())

0 commit comments

Comments
 (0)