Skip to content

Commit 230fbfe

Browse files
committed
refactor(perf): Side-port (#3435) {min,max}_horizontal fix
See #3435 (comment)
Also:
- tweaked some typing
- added docs
- included `coalesce` in the same path
1 parent 0d81b8b commit 230fbfe

File tree

4 files changed

+59
-19
lines changed

4 files changed

+59
-19
lines changed

narwhals/_plan/arrow/functions/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,11 @@
9595
dtype_native,
9696
string_type,
9797
)
98-
from narwhals._plan.arrow.functions._horizontal import max_horizontal, min_horizontal
98+
from narwhals._plan.arrow.functions._horizontal import (
99+
coalesce,
100+
max_horizontal,
101+
min_horizontal,
102+
)
99103
from narwhals._plan.arrow.functions._lists import ExplodeBuilder
100104
from narwhals._plan.arrow.functions._multiplex import (
101105
fill_nan,
@@ -161,6 +165,7 @@
161165
"clip",
162166
"clip_lower",
163167
"clip_upper",
168+
"coalesce",
164169
"concat_horizontal",
165170
"concat_tables",
166171
"concat_vertical",

narwhals/_plan/arrow/functions/_horizontal.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
import pyarrow.compute as pc # ignore-banned-import
44

5-
__all__ = ["max_horizontal", "min_horizontal"]
5+
__all__ = ["coalesce", "max_horizontal", "min_horizontal"]
66

77
# TODO @dangotbanned: Wrap horizontal functions with correct typing
88
# Should only return scalar if all elements are as well
99
# NOTE: Changing typing will propagate to a lot of places (so be careful!):
1010
# - `_round.{clip,clip_lower,clip_upper}`
1111
# - `acero.join_asof_tables`
1212
# - `ArrowNamespace.{min,max}_horizontal`
13+
# - `ArrowNamespace.coalesce`
1314
# - `ArrowSeries.rolling_var`
1415
min_horizontal = pc.min_element_wise
1516
max_horizontal = pc.max_element_wise
17+
coalesce = pc.coalesce

narwhals/_plan/arrow/namespace.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import TYPE_CHECKING, Any, Literal, cast, overload
66

77
import pyarrow as pa # ignore-banned-import
8-
import pyarrow.compute as pc # ignore-banned-import
98

109
from narwhals._arrow.utils import narwhals_to_native_dtype
1110
from narwhals._plan._guards import is_tuple_of
@@ -18,13 +17,21 @@
1817
from narwhals.exceptions import InvalidOperationError
1918

2019
if TYPE_CHECKING:
21-
from collections.abc import Callable, Iterable, Iterator, Sequence
20+
from collections.abc import Iterable, Iterator, Sequence
21+
22+
from typing_extensions import TypeAlias
2223

2324
from narwhals._arrow.typing import ChunkedArrayAny
25+
from narwhals._plan._dispatch import BoundMethod
2426
from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame
2527
from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
2628
from narwhals._plan.arrow.series import ArrowSeries as Series
27-
from narwhals._plan.arrow.typing import ChunkedArray, IntegerScalar
29+
from narwhals._plan.arrow.typing import (
30+
BinaryFunction,
31+
ChunkedArray,
32+
IntegerScalar,
33+
VariadicFunction,
34+
)
2835
from narwhals._plan.expressions import expr, functions as F
2936
from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal
3037
from narwhals._plan.expressions.expr import FunctionExpr as FExpr, RangeExpr
@@ -40,6 +47,8 @@
4047
PythonLiteral,
4148
)
4249

50+
Wrapper: TypeAlias = BoundMethod[FExpr[Any], Frame, Expr | Scalar]
51+
4352

4453
Int64 = Version.MAIN.dtypes.Int64()
4554

@@ -102,53 +111,71 @@ def lit(
102111
nw_ser.to_native(), name or node.name, nw_ser.version
103112
)
104113

105-
def _horizontal_function(
106-
self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None
107-
) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]:
114+
@overload
115+
def _horizontal(
116+
self, function: BinaryFunction, /, fill: NonNestedLiteral = None
117+
) -> Wrapper: ...
118+
@overload
119+
def _horizontal(
120+
self, function: VariadicFunction, /, *, variadic: Literal[True]
121+
) -> Wrapper: ...
122+
def _horizontal(
123+
self,
124+
function: BinaryFunction | VariadicFunction,
125+
/,
126+
fill: NonNestedLiteral = None,
127+
*,
128+
variadic: bool = False,
129+
) -> Wrapper:
130+
"""Generate a horizontal wrapper function.
131+
132+
Arguments:
133+
function: Native binary or variadic function.
134+
fill: Value substituted for nulls in each input before `function` is applied. If None (default), nulls are left as-is.
135+
variadic: If False (default), perform a binary reduction.
136+
Otherwise, assume we can unpack directly into `function`.
137+
"""
138+
108139
def func(node: FExpr[Any], frame: Frame, name: str) -> Expr | Scalar:
109140
it = (self._expr.from_ir(e, frame, name).native for e in node.input)
110141
if fill is not None:
111142
it = (fn.fill_null(native, fill) for native in it)
112-
result = reduce(fn_native, it)
143+
result = function(*it) if variadic else reduce(function, it)
113144
if isinstance(result, pa.Scalar):
114145
return self._scalar.from_native(result, name, self.version)
115146
return self._expr.from_native(result, name, self.version)
116147

117148
return func
118149

119150
def coalesce(self, node: FExpr[F.Coalesce], frame: Frame, name: str) -> Expr | Scalar:
120-
it = (self._expr.from_ir(e, frame, name).native for e in node.input)
121-
result = pc.coalesce(*it)
122-
if isinstance(result, pa.Scalar):
123-
return self._scalar.from_native(result, name, self.version)
124-
return self._expr.from_native(result, name, self.version)
151+
return self._horizontal(fn.coalesce, variadic=True)(node, frame, name)
125152

126153
def any_horizontal(
127154
self, node: FExpr[AnyHorizontal], frame: Frame, name: str
128155
) -> Expr | Scalar:
129156
fill = False if node.function.ignore_nulls else None
130-
return self._horizontal_function(fn.or_, fill)(node, frame, name)
157+
return self._horizontal(fn.or_, fill)(node, frame, name)
131158

132159
def all_horizontal(
133160
self, node: FExpr[AllHorizontal], frame: Frame, name: str
134161
) -> Expr | Scalar:
135162
fill = True if node.function.ignore_nulls else None
136-
return self._horizontal_function(fn.and_, fill)(node, frame, name)
163+
return self._horizontal(fn.and_, fill)(node, frame, name)
137164

138165
def sum_horizontal(
139166
self, node: FExpr[F.SumHorizontal], frame: Frame, name: str
140167
) -> Expr | Scalar:
141-
return self._horizontal_function(fn.add, fill=0)(node, frame, name)
168+
return self._horizontal(fn.add, fill=0)(node, frame, name)
142169

143170
def min_horizontal(
144171
self, node: FExpr[F.MinHorizontal], frame: Frame, name: str
145172
) -> Expr | Scalar:
146-
return self._horizontal_function(fn.min_horizontal)(node, frame, name)
173+
return self._horizontal(fn.min_horizontal, variadic=True)(node, frame, name)
147174

148175
def max_horizontal(
149176
self, node: FExpr[F.MaxHorizontal], frame: Frame, name: str
150177
) -> Expr | Scalar:
151-
return self._horizontal_function(fn.max_horizontal)(node, frame, name)
178+
return self._horizontal(fn.max_horizontal, variadic=True)(node, frame, name)
152179

153180
def mean_horizontal(
154181
self, node: FExpr[F.MeanHorizontal], frame: Frame, name: str

narwhals/_plan/arrow/typing.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ class BinaryComp(
206206
class BinaryLogical(BinaryFunction["BooleanScalar", "pa.BooleanScalar"], Protocol): ...
207207

208208

209+
# TODO @dangotbanned: Use stricter typing & revisit with
210+
# https://github.com/narwhals-dev/narwhals/blob/0d81b8b8cd1d24a68d360e6f8cf742b45cc2bdec/narwhals/_plan/arrow/functions/_horizontal.py#L7-L15
211+
class VariadicFunction(Protocol):
212+
def __call__(self, *args: Arrow) -> Any: ...
213+
214+
209215
BinaryNumericTemporal: TypeAlias = BinaryFunction[
210216
NumericOrTemporalScalarT, NumericOrTemporalScalarT
211217
]

0 commit comments

Comments
 (0)