
feat: Add concat(..., how="*_relaxed") #3398

Draft

FBruzzesi wants to merge 24 commits into dtypes/supertyping from feat/supertyping-relaxed-concat

Conversation


@FBruzzesi FBruzzesi commented Jan 11, 2026

Description

As Marco often says, in for a penny in for a pound 😂

This PR builds on #3396 to showcase how that's useful for implementing relaxed versions of concat.
It's probably possible to reduce some of the repetition, but considering how much time I spent on this (not sooo much so far), I am quite happy with the shape of it.

On top of the new feature, there are some minor adjustments to the concat implementations across backends.

TODO:

  • Unknown types for narwhals should keep the original backend dtype and let that backend deal with it. See comment feat: Add concat(..., how="*_relaxed") #3398 (comment)
  • Backward compatibility: this is more of a concern for how={"vertical", "diagonal"}, and would be a follow-up to this PR

What type of PR is this? (check all applicable)

  • 💾 Refactor
  • ✨ Feature
  • 🐛 Bug Fix
  • 🔧 Optimization
  • 📝 Documentation
  • ✅ Test
  • 🐳 Other

Related issues

Checklist

  • Code follows style guide (ruff)
  • Tests added
  • Documented the changes

@FBruzzesi FBruzzesi added the enhancement New feature or request label Jan 11, 2026
Comment on lines +712 to +716
_DTYPE_BACKEND_PRIORITY: dict[DTypeBackend, Literal[0, 1, 2]] = {
"pyarrow": 2,
"numpy_nullable": 1,
None: 0,
}
Member Author

Strong opinions here!

Disclaimer: pandas automatically does something similar already when concatenating dataframes vertically:

  • {"pyarrow", None} -> "pyarrow"
  • {"numpy_nullable", None} -> "numpy_nullable"
  • {"numpy_nullable", "pyarrow"} -> dtype('O') 😭 WHY THOUGH?
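For intuition, a priority map like this can drive promotion with a simple max-by-priority reduce. This is a toy sketch, not the narwhals implementation; `promote` and `promote_all` are made-up names for illustration:

```python
# Toy sketch: promote dtype backends by picking the highest-priority one.
# Not the narwhals implementation; names here are illustrative.
from functools import reduce

_DTYPE_BACKEND_PRIORITY = {"pyarrow": 2, "numpy_nullable": 1, None: 0}

def promote(left, right):
    # Pick whichever backend has the higher priority.
    return max(left, right, key=_DTYPE_BACKEND_PRIORITY.__getitem__)

def promote_all(backends):
    # Fold the pairwise rule over any number of backends.
    return reduce(promote, backends)

print(promote_all(["numpy_nullable", None]))             # numpy_nullable
print(promote_all([None, "pyarrow", "numpy_nullable"]))  # pyarrow
```

Under this rule {"numpy_nullable", "pyarrow"} promotes to "pyarrow" rather than pandas' object dtype.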

return Schema(into_out_schema)


def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]:
Member Author

Reading this over, combine is definitely the wrong verb here, as we end up getting two schemas back and they preserve their original dtypes.

I think my first idea was to follow the meshing of the schemas with the to_supertype call, which indeed always happens when we call this function in the codebase. I will think about a better name, but I'm open to suggestions.

Member

Wasn't this called merge_schemas or something in polars?

Member

I will think about a better name but open to suggestions

@FBruzzesi I've got an idea that does something similar to (perf: Avoid unnecessary lambdas)

Since every use of combine_schemas currently looks like this:

lambda x, y: to_supertype(*combine_schemas(x, y)),

We can just define that as the function that gets passed to reduce instead.

I used the name to_supertype_diagonal for now, but I'm not too attached to it 😄

Show diff (sorry, very big)

diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py
index bc450265a..aabb81ec2 100644
--- a/narwhals/_arrow/namespace.py
+++ b/narwhals/_arrow/namespace.py
@@ -19,7 +19,7 @@ from narwhals._expression_parsing import (
     combine_evaluate_output_names,
 )
 from narwhals._utils import Implementation
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
 
 if TYPE_CHECKING:
     from collections.abc import Iterator, Sequence
@@ -180,7 +180,7 @@ class ArrowNamespace(
     def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table:
         native_schemas = tuple(table.schema for table in dfs)
         out_schema = reduce(
-            lambda x, y: to_supertype(*combine_schemas(x, y)),
+            to_supertype_diagonal,
             (Schema.from_arrow(pa_schema) for pa_schema in native_schemas),
         ).to_arrow()
         to_schemas = (
diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py
index 5254744ea..5114ee435 100644
--- a/narwhals/_dask/namespace.py
+++ b/narwhals/_dask/namespace.py
@@ -24,7 +24,7 @@ from narwhals._expression_parsing import (
 )
 from narwhals._pandas_like.utils import promote_dtype_backend
 from narwhals._utils import Implementation, zip_strict
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
@@ -177,7 +177,7 @@ class DaskNamespace(
             dtypes = tuple(df.dtypes.to_dict() for df in dfs)
             dtype_backend = promote_dtype_backend(dfs, self._implementation)
             out_schema = reduce(
-                lambda x, y: to_supertype(*combine_schemas(x, y)),
+                to_supertype_diagonal,
                 (Schema.from_pandas_like(dtype) for dtype in dtypes),
             ).to_pandas(dtype_backend=dtype_backend.values())
 
diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py
index 41c2e8d29..37f773de8 100644
--- a/narwhals/_duckdb/namespace.py
+++ b/narwhals/_duckdb/namespace.py
@@ -27,7 +27,7 @@ from narwhals._expression_parsing import (
 )
 from narwhals._sql.namespace import SQLNamespace
 from narwhals._utils import Implementation
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -116,9 +116,7 @@ class DuckDBNamespace(
 
         if how == "diagonal_relaxed":
             schemas = [Schema(df.collect_schema()) for df in items]
-            out_schema = reduce(
-                lambda x, y: to_supertype(*combine_schemas(x, y)), schemas
-            )
+            out_schema = reduce(to_supertype_diagonal, schemas)
             res, *others = (
                 item.select(
                     *(
diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py
index 293589471..d43e42d5a 100644
--- a/narwhals/_pandas_like/namespace.py
+++ b/narwhals/_pandas_like/namespace.py
@@ -24,7 +24,7 @@ from narwhals._pandas_like.utils import (
     promote_dtype_backend,
 )
 from narwhals._utils import zip_strict
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Sequence
@@ -270,8 +270,7 @@ class PandasLikeNamespace(
         dtypes = tuple(native_schema(df) for df in dfs)
         dtype_backend = promote_dtype_backend(dfs, self._implementation)
         out_schema = reduce(
-            lambda x, y: to_supertype(*combine_schemas(x, y)),
-            (Schema.from_pandas_like(dtype) for dtype in dtypes),
+            to_supertype_diagonal, (Schema.from_pandas_like(dtype) for dtype in dtypes)
         ).to_pandas(dtype_backend=dtype_backend.values())
 
         native_res = (
diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py
index 275f82589..e67a99e29 100644
--- a/narwhals/_spark_like/namespace.py
+++ b/narwhals/_spark_like/namespace.py
@@ -19,7 +19,7 @@ from narwhals._spark_like.utils import (
     true_divide,
 )
 from narwhals._sql.namespace import SQLNamespace
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -195,9 +195,7 @@ class SparkLikeNamespace(
 
         if how == "diagonal_relaxed":
             schemas = tuple(Schema(item.collect_schema()) for item in items)
-            out_schema = reduce(
-                lambda x, y: to_supertype(*combine_schemas(x, y)), schemas
-            )
+            out_schema = reduce(to_supertype_diagonal, schemas)
             native_items = (
                 item.select(
                     *(
diff --git a/narwhals/schema.py b/narwhals/schema.py
index 3f5ea8650..59a589034 100644
--- a/narwhals/schema.py
+++ b/narwhals/schema.py
@@ -387,20 +387,21 @@ def _ensure_names_match(left: Schema, right: Schema) -> tuple[Schema, Schema]:
 
 
 def to_supertype(left: Schema, right: Schema) -> Schema:
+    """Take two schemas and try to find the supertypes between them."""
     # Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96
     left, right = _ensure_names_match(left, right)
     it = zip(left.keys(), left.values(), right.values())
     return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it)
 
 
-def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]:
-    """Extend both schemas with names and dtypes missing from the other.
+def to_supertype_diagonal(left: Schema, right: Schema) -> Schema:
+    """Align schemas for `concat(how="diagonal*")` and try to find the supertypes between them.
 
-    Returns a tuple of two schemas where each original schema is extended
-    with the columns that exist in the other schema but not in itself.
+    Both schemas are extended with the columns that exist in the other schema but not in itself.
 
-    The final order for both schemas is: left schema keys first (in order),
-    followed by keys missing from left (in the order they appear in right).
+    The final order for both schemas is:
+    - left schema keys first
+    - followed by keys missing from left (in the order they appear in right)
     """
     left_names = set(left.keys())
     missing_in_left = (kv for kv in right.items() if kv[0] not in left_names)
@@ -408,4 +409,4 @@ def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]:
     extended_left = Schema((*left.items(), *missing_in_left))
     # Reorder right to match: left keys first, then right-only keys
     extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items())
-    return extended_left, extended_right
+    return to_supertype(extended_left, extended_right)
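For intuition, the alignment-plus-supertype step in that diff can be sketched with plain dicts. This is a toy illustration with a made-up supertype rule; the real narwhals/polars supertype logic is much richer:

```python
# Toy sketch of the diagonal alignment + supertype step, using plain dicts
# of column name -> dtype string. `toy_supertype` is a made-up rule.
def toy_supertype(a, b):
    order = {"int32": 0, "int64": 1, "float64": 2}
    return a if order[a] >= order[b] else b

def align_diagonal(left, right):
    # Extend left with right-only columns (left keys first, in order).
    extended_left = {**left, **{k: v for k, v in right.items() if k not in left}}
    # Reorder right to match, filling missing columns from extended_left.
    extended_right = {k: right.get(k, v) for k, v in extended_left.items()}
    # Finally, take the supertype column by column.
    return {k: toy_supertype(extended_left[k], extended_right[k]) for k in extended_left}

print(align_diagonal({"a": "int32", "b": "float64"}, {"a": "int64", "c": "int32"}))
# {'a': 'int64', 'b': 'float64', 'c': 'int32'}
```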

Member Author

I used the name to_supertype_diagonal for now, but I'm not too attached to it 😄

I don't love it either 😂

What if we just call it merge_schemas, as in polars, which adds the to_supertype at the end of the current combine_schemas implementation?

Member

What if we just call it merge_schemas, as in polars, which adds the to_supertype at the end of the current combine_schemas implementation?

I was wrong earlier in (#3398 (comment)).

AFAICT, merge_schemas is only used for "horizontal".

What I was trying to say yesterday on discord:

(1) Schema.to_supertype operates on the simplest case
By this I mean, we only have to do this for concat(how="vertical")

was that polars uses to_supertype for both paths:
"diagonal" goes through convert_diagonal_concat first, and then both go through the same convert_st_union bit.

I'm not sure if that helps anything - just trying to correct myself 😅

Member
@dangotbanned dangotbanned Jan 12, 2026

Was that polars uses to_supertype for both paths.

"diagonal" goes through convert_diagonal_concat first and then both go through the same convert_st_union bit

@FBruzzesi

Okay I just pulled at this thread a bit locally and now have a very rough narwhals port of convert_diagonal_concat.

It's operating over BaseFrame atm, but probably belongs in CompliantNamespace (and using the compliant-level).

I've tested with ibis and I'm pretty sure we can support both "diagonal*" variants this way 🫰

Show diff

diff --git a/narwhals/schema.py b/narwhals/schema.py
index 59a589034..5fead8f4e 100644
--- a/narwhals/schema.py
+++ b/narwhals/schema.py
@@ -7,7 +7,7 @@ https://github.com/pola-rs/polars/blob/main/py-polars/polars/schema.py.
 from __future__ import annotations
 
 from collections import OrderedDict
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from functools import partial
 from typing import TYPE_CHECKING, cast
 
@@ -26,13 +26,15 @@ from narwhals.exceptions import ComputeError, SchemaMismatchError
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
-    from typing import Any, ClassVar
+    from typing import Any, ClassVar, TypeVar
 
     import polars as pl
     import pyarrow as pa
     from typing_extensions import Self
 
+    from narwhals.dataframe import BaseFrame
     from narwhals.dtypes import DType
+    from narwhals.expr import Expr
     from narwhals.typing import (
         DTypeBackend,
         IntoArrowSchema,
@@ -40,6 +42,8 @@ if TYPE_CHECKING:
         IntoPolarsSchema,
     )
 
+    FrameT = TypeVar("FrameT", bound="BaseFrame[Any]")
+
 
 __all__ = ["Schema"]
 
@@ -410,3 +414,44 @@ def to_supertype_diagonal(left: Schema, right: Schema) -> Schema:
     # Reorder right to match: left keys first, then right-only keys
     extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items())
     return to_supertype(extended_left, extended_right)
+
+
+def convert_concat_diagonal(frames: Sequence[FrameT]) -> Sequence[FrameT]:
+    """Adapted from [`convert_diagonal_concat`].
+
+    [`convert_diagonal_concat`]: https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-plan/src/plans/conversion/dsl_to_ir/concat.rs#L10-L68
+    """
+    import narwhals.functions as nw_f
+
+    schemas = [frame.collect_schema() for frame in frames]
+    it_schemas = iter(schemas)
+    total_schema = dict(next(it_schemas))
+    seen_names = set(total_schema)
+    to_add_fields: dict[str, DType] = {}
+    for sch in it_schemas:
+        to_add_fields.update(
+            {name: dtype for name, dtype in sch.items() if name not in seen_names}
+        )
+        seen_names.update(to_add_fields)
+    if not seen_names:
+        return frames
+    total_schema.update(to_add_fields)
+    total_names = tuple(total_schema)
+    added_exprs: dict[str, Expr] = {}
+    results: list[FrameT] = []
+    for frame, schema in zip(frames, schemas):
+        to_add_exprs: list[Expr] = []
+        for name, dtype in total_schema.items():
+            if name not in schema:
+                maybe_seen = added_exprs.get(name)
+                if maybe_seen is None:
+                    to_add_expr = nw_f.lit(None, dtype).alias(name)
+                    to_add_exprs.append(to_add_expr)
+                    added_exprs[name] = to_add_expr
+                else:
+                    to_add_exprs.append(maybe_seen)
+        result = frame
+        if to_add_exprs:
+            result = result.with_columns(to_add_exprs)
+        results.append(result.select(total_names))
+    return results

We end up converting "diagonal*" -> "vertical*" (IIUC)
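A toy illustration of that conversion, with frames reduced to plain dicts of column name to list of values (`pad_columns` is a made-up name; the real `convert_concat_diagonal` uses lazy `lit(None, dtype)` expressions instead of materialized lists):

```python
# Toy sketch of "diagonal" -> "vertical" conversion: pad each frame with
# null columns so every frame shares the same column set, after which a
# plain vertical concat suffices. Frames here are dicts of name -> values.
def pad_columns(frames):
    # Collect the union of column names, preserving first-seen order.
    total = []
    for f in frames:
        for name in f:
            if name not in total:
                total.append(name)
    # Fill each frame's missing columns with nulls of the right length.
    # (Assumes every frame has at least one column.)
    return [
        {name: f.get(name, [None] * len(next(iter(f.values())))) for name in total}
        for f in frames
    ]

a = {"x": [1, 2]}
b = {"y": [3]}
print(pad_columns([a, b]))
# [{'x': [1, 2], 'y': [None, None]}, {'x': [None], 'y': [3]}]
```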

Updated

I've got the Compliant* version working now over on (https://github.com/narwhals-dev/narwhals/compare/convert-concat-diagonal)

Member Author

Thanks @dangotbanned - on a pragmatic side, I would say to keep this as a follow-up, separate from this PR. My understanding for now is that it would only benefit ibis. I would rather have correctness in relaxed and "strict" modes first, and then expand to ibis. Although convert_concat_diagonal might be reused in duckdb and spark-like as well, I would argue that the current implementation is a bit easier to read and maintain 🙈

Member

Thanks @dangotbanned - on a pragmatic side, I would say to keep this as a follow-up, separate from this PR.

Absolutely!

I would argue that the current implementation is a bit easier to read

Agreed on that part too @FBruzzesi
I've just finished cleaning things up some and documenting in these few commits (f5c6407...convert-concat-diagonal)


My understanding for now is that it would only benefit ibis.

For how="diagonal"?
You might be right there, but I was thinking this replaces the need for combine_schemas entirely.
That's used in all of the backends (besides polars).

Is combine_schemas shorter than align_diagonal?
Definitely! 😅

But it is also called n_schemas - 1 times per concat.
By my count, each of those calls does:

  • 1x Schema.keys()
  • 3x Schema.items()
  • 2x Schema(...)

I was quite alarmed by that when I noticed it 😳


So in my head, that ^^^

... and all of the extra stuff here ...

def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table:
    native_schemas = tuple(table.schema for table in dfs)
    out_schema = reduce(
        lambda x, y: to_supertype(*combine_schemas(x, y)),
        (Schema.from_arrow(pa_schema) for pa_schema in native_schemas),
    ).to_arrow()
    to_schemas = (
        pa.schema([out_schema.field(name) for name in native_schema.names])
        for native_schema in native_schemas
    )
    to_concat = tuple(
        table.cast(to_schema) for table, to_schema in zip(dfs, to_schemas)
    )
    return self._concat_diagonal(to_concat)

... seems like it could be avoided and we just do align_diagonal -> _concat_vertical_relaxed?

def _concat_vertical_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table:
    out_schema = reduce(
        to_supertype, (Schema.from_arrow(table.schema) for table in dfs)
    ).to_arrow()
    return pa.concat_tables([table.cast(out_schema) for table in dfs])
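The iteration-count concern can be made concrete with a toy instrumentation: a pairwise reduce re-walks the accumulated schema at every step, while a single pass walks each input schema once. Counts below are illustrative only; real Schema objects do more work per call:

```python
# Toy instrumentation: compare items walked by a pairwise reduce vs a
# single pass over all schemas. Illustrative numbers, not a benchmark.
from functools import reduce

calls = {"pairwise_items_walked": 0}

def pairwise_union(x, y):
    # Each reduce step walks both the accumulator and the next schema.
    calls["pairwise_items_walked"] += len(x) + len(y)
    return {**x, **y}

# Four schemas with two distinct columns each.
schemas = [{f"c{i}": "int64", f"d{i}": "int64"} for i in range(4)]
reduce(pairwise_union, schemas)

# Single pass: each schema is walked exactly once.
single_pass = 0
total = {}
for sch in schemas:
    single_pass += len(sch)
    total.update(sch)

print(calls["pairwise_items_walked"], single_pass)  # 18 8
```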

Note

I would rather fully dissolve 🫠 than benchmark again, so I'm just asking to consider how many iterations each is doing 🙏

Comment on lines 17 to 18
def _cast(frame: LazyFrameT, schema: Schema) -> LazyFrameT:
return frame.select(nw.col(name).cast(dtype) for name, dtype in schema.items())
Member Author

I still cry each time I reach for {DataFrame,LazyFrame}.cast 😂

Member

I want it too 😢

Member

@FBruzzesi let's revive #1045 as a new issue.

The example in that test isn't the motivating part.


What's here could be both more efficient and shorter

def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table:
    native_schemas = tuple(table.schema for table in dfs)
    out_schema = reduce(
        lambda x, y: to_supertype(*combine_schemas(x, y)),
        (Schema.from_arrow(pa_schema) for pa_schema in native_schemas),
    ).to_arrow()
    to_schemas = (
        pa.schema([out_schema.field(name) for name in native_schema.names])
        for native_schema in native_schemas
    )
    to_concat = tuple(
        table.cast(to_schema) for table, to_schema in zip(dfs, to_schemas)
    )
    return self._concat_diagonal(to_concat)

+++ b/narwhals/_arrow/dataframe.py
+    def cast(self, schema: IntoSchema | Iterable[tuple[str, DType]]) -> Self:
+        return self._with_native(
+            self.native.cast(Schema(schema).to_arrow()), validate_column_names=False
+        )
+

+++ b/narwhals/_arrow/namespace.py
-from narwhals._arrow.utils import cast_to_comparable_string_types
+from narwhals._arrow.utils import cast_to_comparable_string_types, concat_tables
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal

With that, we can avoid materializing as much all at once

    def concat_diagonal_relaxed(self, dfs: Sequence[ArrowDataFrame], /) -> ArrowDataFrame:
        schemas = [Schema(df.schema) for df in dfs]
        supertypes = reduce(to_supertype_diagonal, schemas)
        to_concat = (
            df.cast((name, supertypes[name]) for name in schema).native
            for df, schema in zip(dfs, schemas)
        )
        return dfs[0]._with_native(concat_tables(to_concat))

Member Author

Otherwise I am going to forget: #3402

- Reuse the schemas we've already collected
  - And rename the variables so it is obvious
- Skip checking if columns exist for the first schema
- Use a generator inside `dict.update` (instead of many `__setitem__`s)
- Use `iter_dtype_backends` instead of creating a new `partial`

 Unrelated to performance: Return `.values()` instead of a `dict` (no usage of the keys anywhere?)
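The generator-inside-`dict.update` point can be shown in isolation (toy data, not the narwhals code). Because `dict.update` consumes the generator lazily, items inserted earlier in the same update are visible to later membership checks:

```python
# Toy sketch: one dict.update call with a filtering generator, instead of
# a loop of individual __setitem__ calls.
seen = {"a": 1}
incoming = [("a", 9), ("b", 2), ("c", 3)]

# Only add keys not already present; existing "a" keeps its value.
seen.update((k, v) for k, v in incoming if k not in seen)

print(seen)  # {'a': 1, 'b': 2, 'c': 3}
```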
assert_equal_data(result.collect(), expected_data)


def test_pyarrow_concat_vertical_uuid() -> None:
Member Author

@MarcoGorelli I added 3 tests (pyarrow, duckdb and ibis) to check whether the changes in 4a1b946 and b205ddb support what we mentioned yesterday during the community call, namely that datatypes that are not supported by narwhals are preserved.

For polars, pandas and dask I am not applying these "workarounds", since we should support all their datatypes*.

*An exception is made for polars Float16 and Null, but for polars we also just dispatch the method without touching the datatypes. Am I missing something for pandas and dask?

Member
@dangotbanned dangotbanned Jan 31, 2026

pd.PeriodDtype?

But more generally, pandas supports registering extension types and so does polars (recently)

ns: CompliantNamespace[CompliantFrameT, CompliantExprT],
mapping: Mapping[str, dtypes.DType],
) -> Iterable[CompliantExprT]:
Unknown = ns._version.dtypes.Unknown() # noqa: N806
Member

I would:
