Skip to content

Commit ec79e46

Browse files
ccciudatu authored and adragomir committed
[substrait] Add support for ExtensionTable
1 parent ec2026a commit ec79e46

File tree

5 files changed

+289
-66
lines changed

5 files changed

+289
-66
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ use datafusion_expr::{
6363
expr_rewriter::FunctionRewrite,
6464
logical_plan::{DdlStatement, Statement},
6565
planner::ExprPlanner,
66-
Expr, UserDefinedLogicalNode, WindowUDF,
66+
Expr, WindowUDF,
6767
};
6868

6969
// backwards compatibility
@@ -1682,27 +1682,7 @@ pub enum RegisterFunction {
16821682
#[derive(Debug)]
16831683
pub struct EmptySerializerRegistry;
16841684

1685-
impl SerializerRegistry for EmptySerializerRegistry {
1686-
fn serialize_logical_plan(
1687-
&self,
1688-
node: &dyn UserDefinedLogicalNode,
1689-
) -> Result<Vec<u8>> {
1690-
not_impl_err!(
1691-
"Serializing user defined logical plan node `{}` is not supported",
1692-
node.name()
1693-
)
1694-
}
1695-
1696-
fn deserialize_logical_plan(
1697-
&self,
1698-
name: &str,
1699-
_bytes: &[u8],
1700-
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
1701-
not_impl_err!(
1702-
"Deserializing user defined logical plan node `{name}` is not supported"
1703-
)
1704-
}
1705-
}
1685+
impl SerializerRegistry for EmptySerializerRegistry {}
17061686

17071687
/// Describes which SQL statements can be run.
17081688
///

datafusion/expr/src/registry.rs

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use crate::expr_rewriter::FunctionRewrite;
2121
use crate::planner::ExprPlanner;
22-
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
22+
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
2323
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
2424
use std::collections::HashSet;
2525
use std::fmt::Debug;
@@ -123,24 +123,58 @@ pub trait FunctionRegistry {
123123
}
124124
}
125125

126-
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
126+
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
127+
/// and custom table providers for which the name alone is meaningless in the target
128+
/// execution context, e.g. UDTFs, manually registered tables etc.
127129
pub trait SerializerRegistry: Debug + Send + Sync {
128130
/// Serialize this node to a byte array. This serialization should not include
129131
/// input plans.
130132
fn serialize_logical_plan(
131133
&self,
132134
node: &dyn UserDefinedLogicalNode,
133-
) -> Result<Vec<u8>>;
135+
) -> Result<NamedBytes> {
136+
not_impl_err!(
137+
"Serializing user defined logical plan node `{}` is not supported",
138+
node.name()
139+
)
140+
}
134141

135142
/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
136143
/// bytes.
137144
fn deserialize_logical_plan(
138145
&self,
139146
name: &str,
140-
bytes: &[u8],
141-
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
147+
_bytes: &[u8],
148+
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
149+
not_impl_err!(
150+
"Deserializing user defined logical plan node `{name}` is not supported"
151+
)
152+
}
153+
154+
/// Serialized table definition for UDTFs or some other table provider implementation that
155+
/// can't be marshaled by reference.
156+
fn serialize_custom_table(
157+
&self,
158+
_table: &dyn TableSource,
159+
) -> Result<Option<NamedBytes>> {
160+
Ok(None)
161+
}
162+
163+
/// Deserialize a custom table.
164+
fn deserialize_custom_table(
165+
&self,
166+
name: &str,
167+
_bytes: &[u8],
168+
) -> Result<Arc<dyn TableSource>> {
169+
not_impl_err!("Deserializing custom table `{name}` is not supported")
170+
}
142171
}
143172

173+
/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions
174+
/// that need to carry their type, e.g. the `type_url` for protobuf messages.
175+
#[derive(Debug, Clone)]
176+
pub struct NamedBytes(pub String, pub Vec<u8>);
177+
144178
/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
145179
#[derive(Default, Debug)]
146180
pub struct MemoryFunctionRegistry {

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
3030

3131
use datafusion::logical_expr::{
3232
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
33-
LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
33+
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values,
3434
};
3535
use substrait::proto::aggregate_rel::Grouping;
3636
use substrait::proto::expression as substrait_expression;
@@ -86,6 +86,7 @@ use substrait::proto::expression::{
8686
SingularOrList, SwitchExpression, WindowFunction,
8787
};
8888
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
89+
use substrait::proto::read_rel::ExtensionTable;
8990
use substrait::proto::rel_common::{Emit, EmitKind};
9091
use substrait::proto::set_rel::SetOp;
9192
use substrait::proto::{
@@ -457,6 +458,20 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
457458
user_defined_literal.type_reference
458459
)
459460
}
461+
462+
fn consume_extension_table(
463+
&self,
464+
extension_table: &ExtensionTable,
465+
) -> Result<Arc<dyn TableSource>> {
466+
if let Some(ext_detail) = extension_table.detail.as_ref() {
467+
substrait_err!(
468+
"Missing handler for extension table: {}",
469+
&ext_detail.type_url
470+
)
471+
} else {
472+
substrait_err!("Unexpected empty detail in ExtensionTable")
473+
}
474+
}
460475
}
461476

462477
/// Convert Substrait Rel to DataFusion DataFrame
@@ -578,6 +593,19 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
578593
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
579594
Ok(LogicalPlan::Extension(Extension { node: plan }))
580595
}
596+
597+
fn consume_extension_table(
598+
&self,
599+
extension_table: &ExtensionTable,
600+
) -> Result<Arc<dyn TableSource>> {
601+
if let Some(ext_detail) = &extension_table.detail {
602+
self.state
603+
.serializer_registry()
604+
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)
605+
} else {
606+
substrait_err!("Unexpected empty detail in ExtensionTable")
607+
}
608+
}
581609
}
582610

583611
// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
@@ -1323,26 +1351,14 @@ pub async fn from_read_rel(
13231351
read: &ReadRel,
13241352
) -> Result<LogicalPlan> {
13251353
async fn read_with_schema(
1326-
consumer: &impl SubstraitConsumer,
13271354
table_ref: TableReference,
1355+
table_source: Arc<dyn TableSource>,
13281356
schema: DFSchema,
13291357
projection: &Option<MaskExpression>,
13301358
) -> Result<LogicalPlan> {
13311359
let schema = schema.replace_qualifier(table_ref.clone());
13321360

1333-
let plan = {
1334-
let provider = match consumer.resolve_table_ref(&table_ref).await? {
1335-
Some(ref provider) => Arc::clone(provider),
1336-
_ => return plan_err!("No table named '{table_ref}'"),
1337-
};
1338-
1339-
LogicalPlanBuilder::scan(
1340-
table_ref,
1341-
provider_as_source(Arc::clone(&provider)),
1342-
None,
1343-
)?
1344-
.build()?
1345-
};
1361+
let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? };
13461362

13471363
ensure_schema_compatibility(plan.schema(), schema.clone())?;
13481364

@@ -1351,6 +1367,17 @@ pub async fn from_read_rel(
13511367
apply_projection(plan, schema)
13521368
}
13531369

1370+
async fn table_source(
1371+
consumer: &impl SubstraitConsumer,
1372+
table_ref: &TableReference,
1373+
) -> Result<Arc<dyn TableSource>> {
1374+
if let Some(provider) = consumer.resolve_table_ref(table_ref).await? {
1375+
Ok(provider_as_source(provider))
1376+
} else {
1377+
plan_err!("No table named '{table_ref}'")
1378+
}
1379+
}
1380+
13541381
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
13551382
substrait_datafusion_err!("No base schema provided for Read Relation")
13561383
})?;
@@ -1376,10 +1403,10 @@ pub async fn from_read_rel(
13761403
table: nt.names[2].clone().into(),
13771404
},
13781405
};
1379-
1406+
let table_source = table_source(consumer, &table_reference).await?;
13801407
read_with_schema(
1381-
consumer,
13821408
table_reference,
1409+
table_source,
13831410
substrait_schema,
13841411
&read.projection,
13851412
)
@@ -1458,17 +1485,38 @@ pub async fn from_read_rel(
14581485
let name = filename.unwrap();
14591486
// directly use unwrap here since we could determine it is a valid one
14601487
let table_reference = TableReference::Bare { table: name.into() };
1488+
let table_source = table_source(consumer, &table_reference).await?;
14611489

14621490
read_with_schema(
1463-
consumer,
14641491
table_reference,
1492+
table_source,
1493+
substrait_schema,
1494+
&read.projection,
1495+
)
1496+
.await
1497+
}
1498+
Some(ReadType::ExtensionTable(ext)) => {
1499+
// look for the original table name under `rel.common.hint.alias`
1500+
// in case the producer was kind enough to put it there.
1501+
let name_hint = read
1502+
.common
1503+
.as_ref()
1504+
.and_then(|rel_common| rel_common.hint.as_ref())
1505+
.map(|hint| hint.alias.as_str().trim())
1506+
.filter(|alias| !alias.is_empty());
1507+
// if no name hint was provided, use the name that datafusion
1508+
// sets for UDTFs
1509+
let table_name = name_hint.unwrap_or("tmp_table");
1510+
read_with_schema(
1511+
TableReference::from(table_name),
1512+
consumer.consume_extension_table(ext)?,
14651513
substrait_schema,
14661514
&read.projection,
14671515
)
14681516
.await
14691517
}
1470-
_ => {
1471-
not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
1518+
None => {
1519+
substrait_err!("Unexpected empty read_type")
14721520
}
14731521
}
14741522
}
@@ -1871,7 +1919,7 @@ pub async fn from_substrait_sorts(
18711919
},
18721920
None => not_impl_err!("Sort without sort kind is invalid"),
18731921
};
1874-
let (asc, nulls_first) = asc_nullfirst.unwrap();
1922+
let (asc, nulls_first) = asc_nullfirst?;
18751923
sorts.push(Sort {
18761924
expr,
18771925
asc,

0 commit comments

Comments
 (0)