feat: Add concat(..., how="*_relaxed") #3398
FBruzzesi wants to merge 24 commits into dtypes/supertyping from concat(..., how="*_relaxed") #3398
Conversation
| _DTYPE_BACKEND_PRIORITY: dict[DTypeBackend, Literal[0, 1, 2]] = { | ||
| "pyarrow": 2, | ||
| "numpy_nullable": 1, | ||
| None: 0, | ||
| } |
There was a problem hiding this comment.
Strong opinions here!
Disclaimer: pandas automatically does something similar already when concatenating dataframes vertically:
{"pyarrow", None} -> "pyarrow"
{"numpy_nullable", None} -> "numpy_nullable"
{"numpy_nullable", "pyarrow"} -> dtype('O') 😭 WHY THOUGH?
narwhals/schema.py
Outdated
| return Schema(into_out_schema) | ||
|
|
||
|
|
||
| def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]: |
There was a problem hiding this comment.
Reading this over, combine is definitely the wrong verb here as we end up getting two schemas and they preserve their original dtypes.
I think my first idea was to follow the meshing of the schemas with the to_supertype call, which indeed is always happening when we call this function in the codebase. I will think about a better name but open to suggestions
There was a problem hiding this comment.
Wasn't this called merge_schemas or something in polars?
There was a problem hiding this comment.
I will think about a better name but open to suggestions
@FBruzzesi I've got an idea that does a similar thing to (perf: Avoid unnecessary lambdas)
Since every use of combine_schema currently looks like this:
lambda x, y: to_supertype(*combine_schemas(x, y)),
We can just define that as the function that gets passed to reduce instead.
I used the name to_supertype_diagonal for now, but I'm not too attached it 😄
Show diff (sorry, very big)
diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py
index bc450265a..aabb81ec2 100644
--- a/narwhals/_arrow/namespace.py
+++ b/narwhals/_arrow/namespace.py
@@ -19,7 +19,7 @@ from narwhals._expression_parsing import (
combine_evaluate_output_names,
)
from narwhals._utils import Implementation
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@@ -180,7 +180,7 @@ class ArrowNamespace(
def _concat_diagonal_relaxed(self, dfs: Sequence[pa.Table], /) -> pa.Table:
native_schemas = tuple(table.schema for table in dfs)
out_schema = reduce(
- lambda x, y: to_supertype(*combine_schemas(x, y)),
+ to_supertype_diagonal,
(Schema.from_arrow(pa_schema) for pa_schema in native_schemas),
).to_arrow()
to_schemas = (
diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py
index 5254744ea..5114ee435 100644
--- a/narwhals/_dask/namespace.py
+++ b/narwhals/_dask/namespace.py
@@ -24,7 +24,7 @@ from narwhals._expression_parsing import (
)
from narwhals._pandas_like.utils import promote_dtype_backend
from narwhals._utils import Implementation, zip_strict
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
@@ -177,7 +177,7 @@ class DaskNamespace(
dtypes = tuple(df.dtypes.to_dict() for df in dfs)
dtype_backend = promote_dtype_backend(dfs, self._implementation)
out_schema = reduce(
- lambda x, y: to_supertype(*combine_schemas(x, y)),
+ to_supertype_diagonal,
(Schema.from_pandas_like(dtype) for dtype in dtypes),
).to_pandas(dtype_backend=dtype_backend.values())
diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py
index 41c2e8d29..37f773de8 100644
--- a/narwhals/_duckdb/namespace.py
+++ b/narwhals/_duckdb/namespace.py
@@ -27,7 +27,7 @@ from narwhals._expression_parsing import (
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._utils import Implementation
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
if TYPE_CHECKING:
from collections.abc import Iterable
@@ -116,9 +116,7 @@ class DuckDBNamespace(
if how == "diagonal_relaxed":
schemas = [Schema(df.collect_schema()) for df in items]
- out_schema = reduce(
- lambda x, y: to_supertype(*combine_schemas(x, y)), schemas
- )
+ out_schema = reduce(to_supertype_diagonal, schemas)
res, *others = (
item.select(
*(
diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py
index 293589471..d43e42d5a 100644
--- a/narwhals/_pandas_like/namespace.py
+++ b/narwhals/_pandas_like/namespace.py
@@ -24,7 +24,7 @@ from narwhals._pandas_like.utils import (
promote_dtype_backend,
)
from narwhals._utils import zip_strict
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
@@ -270,8 +270,7 @@ class PandasLikeNamespace(
dtypes = tuple(native_schema(df) for df in dfs)
dtype_backend = promote_dtype_backend(dfs, self._implementation)
out_schema = reduce(
- lambda x, y: to_supertype(*combine_schemas(x, y)),
- (Schema.from_pandas_like(dtype) for dtype in dtypes),
+ to_supertype_diagonal, (Schema.from_pandas_like(dtype) for dtype in dtypes)
).to_pandas(dtype_backend=dtype_backend.values())
native_res = (
diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py
index 275f82589..e67a99e29 100644
--- a/narwhals/_spark_like/namespace.py
+++ b/narwhals/_spark_like/namespace.py
@@ -19,7 +19,7 @@ from narwhals._spark_like.utils import (
true_divide,
)
from narwhals._sql.namespace import SQLNamespace
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal
if TYPE_CHECKING:
from collections.abc import Iterable
@@ -195,9 +195,7 @@ class SparkLikeNamespace(
if how == "diagonal_relaxed":
schemas = tuple(Schema(item.collect_schema()) for item in items)
- out_schema = reduce(
- lambda x, y: to_supertype(*combine_schemas(x, y)), schemas
- )
+ out_schema = reduce(to_supertype_diagonal, schemas)
native_items = (
item.select(
*(
diff --git a/narwhals/schema.py b/narwhals/schema.py
index 3f5ea8650..59a589034 100644
--- a/narwhals/schema.py
+++ b/narwhals/schema.py
@@ -387,20 +387,21 @@ def _ensure_names_match(left: Schema, right: Schema) -> tuple[Schema, Schema]:
def to_supertype(left: Schema, right: Schema) -> Schema:
+ """Take two schemas and try to find the supertypes between them."""
# Adapted from polars https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-core/src/schema/mod.rs#L83-L96
left, right = _ensure_names_match(left, right)
it = zip(left.keys(), left.values(), right.values())
return Schema((name, _supertype(ltype, rtype)) for (name, ltype, rtype) in it)
-def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]:
- """Extend both schemas with names and dtypes missing from the other.
+def to_supertype_diagonal(left: Schema, right: Schema) -> Schema:
+ """Align schemas for `concat(how="diagonal*")` and try to find the supertypes between them.
- Returns a tuple of two schemas where each original schema is extended
- with the columns that exist in the other schema but not in itself.
+ Both schemas are extended with the columns that exist in the other schema but not in itself.
- The final order for both schemas is: left schema keys first (in order),
- followed by keys missing from left (in the order they appear in right).
+ The final order for both schemas is:
+ - left schema keys first
+ - followed by keys missing from left (in the order they appear in right)
"""
left_names = set(left.keys())
missing_in_left = (kv for kv in right.items() if kv[0] not in left_names)
@@ -408,4 +409,4 @@ def combine_schemas(left: Schema, right: Schema) -> tuple[Schema, Schema]:
extended_left = Schema((*left.items(), *missing_in_left))
# Reorder right to match: left keys first, then right-only keys
extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items())
- return extended_left, extended_right
+ return to_supertype(extended_left, extended_right)
There was a problem hiding this comment.
I used the name
to_supertype_diagonal for now, but I'm not too attached to it 😄
I don't love it either 😂
What if we just call it merge_schemas as polars, which adds the to_supertype at the end of the current combine_schemas implementation?
There was a problem hiding this comment.
What if we just call it
merge_schemas as polars, which adds the to_supertype at the end of the current combine_schemas implementation?
I was wrong earlier in (#3398 (comment)).
AFAICT, merge_schemas is only used for "horizontal".
What I was trying to say yesterday on discord:
(1)
Schema.to_supertypeoperates on the simplest caseBy this I mean, we only have to do this for
concat(how="vertical")
Was that polars uses to_supertype for both paths.
"diagonal" goes through convert_diagonal_concat first and then both go through the same convert_st_union bit
I'm not sure if that helps anything - just trying to correct myself 😅
There was a problem hiding this comment.
Was that
polars uses to_supertype for both paths.
"diagonal" goes through convert_diagonal_concat first and then both go through the same convert_st_union bit
Okay I just pulled at this thread a bit locally and now have a very rough narwhals port of convert_diagonal_concat.
It's operating over BaseFrame atm, but probably belongs in CompliantNamespace (and using the compliant-level).
I've tested with ibis and I'm pretty sure we can support both "diagonal*" variants this way 🫰
Show diff
diff --git a/narwhals/schema.py b/narwhals/schema.py
index 59a589034..5fead8f4e 100644
--- a/narwhals/schema.py
+++ b/narwhals/schema.py
@@ -7,7 +7,7 @@ https://github.com/pola-rs/polars/blob/main/py-polars/polars/schema.py.
from __future__ import annotations
from collections import OrderedDict
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, cast
@@ -26,13 +26,15 @@ from narwhals.exceptions import ComputeError, SchemaMismatchError
if TYPE_CHECKING:
from collections.abc import Iterable
- from typing import Any, ClassVar
+ from typing import Any, ClassVar, TypeVar
import polars as pl
import pyarrow as pa
from typing_extensions import Self
+ from narwhals.dataframe import BaseFrame
from narwhals.dtypes import DType
+ from narwhals.expr import Expr
from narwhals.typing import (
DTypeBackend,
IntoArrowSchema,
@@ -40,6 +42,8 @@ if TYPE_CHECKING:
IntoPolarsSchema,
)
+ FrameT = TypeVar("FrameT", bound="BaseFrame[Any]")
+
__all__ = ["Schema"]
@@ -410,3 +414,44 @@ def to_supertype_diagonal(left: Schema, right: Schema) -> Schema:
# Reorder right to match: left keys first, then right-only keys
extended_right = Schema((kv[0], right.get(*kv)) for kv in extended_left.items())
return to_supertype(extended_left, extended_right)
+
+
+def convert_concat_diagonal(frames: Sequence[FrameT]) -> Sequence[FrameT]:
+ """Adapted from [`convert_diagonal_concat`].
+
+ [`convert_diagonal_concat`]: https://github.com/pola-rs/polars/blob/c2412600210a21143835c9dfcb0a9182f462b619/crates/polars-plan/src/plans/conversion/dsl_to_ir/concat.rs#L10-L68
+ """
+ import narwhals.functions as nw_f
+
+ schemas = [frame.collect_schema() for frame in frames]
+ it_schemas = iter(schemas)
+ total_schema = dict(next(it_schemas))
+ seen_names = set(total_schema)
+ to_add_fields: dict[str, DType] = {}
+ for sch in it_schemas:
+ to_add_fields.update(
+ {name: dtype for name, dtype in sch.items() if name not in seen_names}
+ )
+ seen_names.update(to_add_fields)
+ if not seen_names:
+ return frames
+ total_schema.update(to_add_fields)
+ total_names = tuple(total_schema)
+ added_exprs: dict[str, Expr] = {}
+ results: list[FrameT] = []
+ for frame, schema in zip(frames, schemas):
+ to_add_exprs: list[Expr] = []
+ for name, dtype in total_schema.items():
+ if name not in schema:
+ maybe_seen = added_exprs.get(name)
+ if maybe_seen is None:
+ to_add_expr = nw_f.lit(None, dtype).alias(name)
+ to_add_exprs.append(to_add_expr)
+ added_exprs[name] = to_add_expr
+ else:
+ to_add_exprs.append(maybe_seen)
+ result = frame
+ if to_add_exprs:
+ result = result.with_columns(to_add_exprs)
+ results.append(result.select(total_names))
+    return results

We end up converting "diagonal*" -> "vertical*" (IIUC)
Updated
I've got the Compliant* version working now over on (https://github.com/narwhals-dev/narwhals/compare/convert-concat-diagonal)
There was a problem hiding this comment.
Thanks @dangotbanned - on a pragmatic side, I would say to keep this as a follow up separated from this PR. My understanding from now is that it would only benefit ibis. I would rather have correctness in relaxed and "strict" modes first, and then expand to ibis. Although convert_concat_diagonal might be re-used in duckdb and spark-like as well, I would argue that the current implementation is a bit easier to read and maintain 🙈
There was a problem hiding this comment.
Thanks @dangotbanned - on a pragmatic side, I would say to keep this as a follow up separated from this PR.
Absolutely!
I would argue that the current implementation is a bit easier to read
Agreed on that part too @FBruzzesi
I've just finished cleaning things up some and documenting in these few commits (f5c6407...convert-concat-diagonal)
My understanding from now is that it would only benefit ibis.
For how="diagonal"?
You might be right there, but I was thinking this replaces the need for combine_schemas entirely.
That's used in all of the backends (besides polars).
Is combine_schemas shorter than align_diagonal?
Definitely! 😅
But it is also called n_schemas - 1 times per concat.
By my count, each of those calls does:
- 1x
Schema.keys() - 3x
Schema.items() - 2x
Schema(...)
I was quite alarmed by that when I noticed it 😳
So in my head, that ^^^
... and all of the extra stuff here ...
narwhals/narwhals/_arrow/namespace.py
Lines 180 to 193 in 8fabb13
... seems like it could be avoided and we just do align_diagonal -> _concat_vertical_relaxed?
narwhals/narwhals/_arrow/namespace.py
Lines 213 to 218 in 8fabb13
Note
I would rather fully dissolve 🫠 than benchmark again, so I'm just asking to consider how many iterations each is doing 🙏
tests/frame/concat_test.py
Outdated
| def _cast(frame: LazyFrameT, schema: Schema) -> LazyFrameT: | ||
| return frame.select(nw.col(name).cast(dtype) for name, dtype in schema.items()) |
There was a problem hiding this comment.
I still cry each time I try to reach out for {DataFrame,LazyFrame}.cast 😂
There was a problem hiding this comment.
@FBruzzesi let's revive #1045 as a new issue.
The example in that test isn't the motivating part.
What's here could be both more efficient and shorter
narwhals/narwhals/_arrow/namespace.py
Lines 180 to 194 in 8fabb13
+++ b/narwhals/_arrow/dataframe.py
+ def cast(self, schema: IntoSchema | Iterable[tuple[str, DType]]) -> Self:
+ return self._with_native(
+ self.native.cast(Schema(schema).to_arrow()), validate_column_names=False
+ )
+
+++ b/narwhals/_arrow/namespace.py
-from narwhals._arrow.utils import cast_to_comparable_string_types
+from narwhals._arrow.utils import cast_to_comparable_string_types, concat_tables
-from narwhals.schema import Schema, combine_schemas, to_supertype
+from narwhals.schema import Schema, to_supertype, to_supertype_diagonal

With that, we can avoid materializing as much all at once
def concat_diagonal_relaxed(self, dfs: Sequence[ArrowDataFrame], /) -> ArrowDataFrame:
schemas = [Schema(df.schema) for df in dfs]
supertypes = reduce(to_supertype_diagonal, schemas)
to_concat = (
df.cast((name, supertypes[name]) for name in schema).native
for df, schema in zip(dfs, schemas)
)
    return dfs[0]._with_native(concat_tables(to_concat))

```py
error: Unsupported target for indexed assignment ("Mapping[str, DType]")
```
If the function already exists, we can use it instead 😄
- Reuse the schemas we've already collected - And rename the variables so it is obvious - Skip checking if columns exist for the first schema - Use a generator inside `dict.update` (instead of many `__setitem__`s) - Use `iter_dtype_backends` instead of creating a new `partial` Unrelated to performance: Return `.values()` instead of a `dict` (no usage of the keys anywhere?)
| assert_equal_data(result.collect(), expected_data) | ||
|
|
||
|
|
||
| def test_pyarrow_concat_vertical_uuid() -> None: |
There was a problem hiding this comment.
@MarcoGorelli I added 3 tests (pyarrow, duckdb and ibis) to check if the changes in 4a1b946 and b205ddb support what we mentioned yesterday during the community call. Namely that datatypes that are not supported by narwhals are preserved.
For polars, pandas and dask I am not applying these "workarounds" since we should support all their datatypes*
*Exception is made for polars Float16 and Null, but also for polars we just dispatch the method without touching the datatypes. Am I missing something for pandas and dask?
There was a problem hiding this comment.
pd.PeriodDtype?
But more generally, pandas supports registering extension types and so does polars (recently)
| ns: CompliantNamespace[CompliantFrameT, CompliantExprT], | ||
| mapping: Mapping[str, dtypes.DType], | ||
| ) -> Iterable[CompliantExprT]: | ||
| Unknown = ns._version.dtypes.Unknown() # noqa: N806 |
There was a problem hiding this comment.
I would:
- Move this to https://github.com/narwhals-dev/narwhals/blob/4a1b946eeb3627812a0d44fd270f77837629da3c/narwhals/dtypes/_utils.py
- Use
from narwhals.dtypes._classes import Unknownat the module-level - Replace
==withdtype.base_type() is Unknown
Description
As Marco often says, in for a penny in for a pound 😂
This PR builds up on #3396 to showcase how that's useful to implement concat relaxed versions.
It's probably possible to lower some of the repetition, but considering how much time I spent on this (not sooo much so far), I am quite happy with the shape of it.
On top of the new feature there are some minor adjustment for the concat implementations across backends
TODO:
- concat(..., how="*_relaxed") #3398 (comment): how={"vertical", "diagonal"}, and would be a followup to this PR

What type of PR is this? (check all applicable)
Related issues
- concat(..., how={"vertical_relaxed", "diagonal_relaxed"}) #3386

Checklist