|
5 | 5 | from typing import TYPE_CHECKING, Any, Literal, cast, overload |
6 | 6 |
|
7 | 7 | import pyarrow as pa # ignore-banned-import |
8 | | -import pyarrow.compute as pc # ignore-banned-import |
9 | 8 |
|
10 | 9 | from narwhals._arrow.utils import narwhals_to_native_dtype |
11 | 10 | from narwhals._plan._guards import is_tuple_of |
|
18 | 17 | from narwhals.exceptions import InvalidOperationError |
19 | 18 |
|
20 | 19 | if TYPE_CHECKING: |
21 | | - from collections.abc import Callable, Iterable, Iterator, Sequence |
| 20 | + from collections.abc import Iterable, Iterator, Sequence |
| 21 | + |
| 22 | + from typing_extensions import TypeAlias |
22 | 23 |
|
23 | 24 | from narwhals._arrow.typing import ChunkedArrayAny |
| 25 | + from narwhals._plan._dispatch import BoundMethod |
24 | 26 | from narwhals._plan.arrow.dataframe import ArrowDataFrame as Frame |
25 | 27 | from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar |
26 | 28 | from narwhals._plan.arrow.series import ArrowSeries as Series |
27 | | - from narwhals._plan.arrow.typing import ChunkedArray, IntegerScalar |
| 29 | + from narwhals._plan.arrow.typing import ( |
| 30 | + BinaryFunction, |
| 31 | + ChunkedArray, |
| 32 | + IntegerScalar, |
| 33 | + VariadicFunction, |
| 34 | + ) |
28 | 35 | from narwhals._plan.expressions import expr, functions as F |
29 | 36 | from narwhals._plan.expressions.boolean import AllHorizontal, AnyHorizontal |
30 | 37 | from narwhals._plan.expressions.expr import FunctionExpr as FExpr, RangeExpr |
|
40 | 47 | PythonLiteral, |
41 | 48 | ) |
42 | 49 |
|
| 50 | + Wrapper: TypeAlias = BoundMethod[FExpr[Any], Frame, Expr | Scalar] |
| 51 | + |
43 | 52 |
|
44 | 53 | Int64 = Version.MAIN.dtypes.Int64() |
45 | 54 |
|
@@ -102,53 +111,71 @@ def lit( |
102 | 111 | nw_ser.to_native(), name or node.name, nw_ser.version |
103 | 112 | ) |
104 | 113 |
|
105 | | - def _horizontal_function( |
106 | | - self, fn_native: Callable[[Any, Any], Any], /, fill: NonNestedLiteral = None |
107 | | - ) -> Callable[[FExpr[Any], Frame, str], Expr | Scalar]: |
| 114 | + @overload |
| 115 | + def _horizontal( |
| 116 | + self, function: BinaryFunction, /, fill: NonNestedLiteral = None |
| 117 | + ) -> Wrapper: ... |
| 118 | + @overload |
| 119 | + def _horizontal( |
| 120 | + self, function: VariadicFunction, /, *, variadic: Literal[True] |
| 121 | + ) -> Wrapper: ... |
| 122 | + def _horizontal( |
| 123 | + self, |
| 124 | + function: BinaryFunction | VariadicFunction, |
| 125 | + /, |
| 126 | + fill: NonNestedLiteral = None, |
| 127 | + *, |
| 128 | + variadic: bool = False, |
| 129 | + ) -> Wrapper: |
| 130 | + """Generate a horizontal wrapper function. |
| 131 | +
|
| 132 | + Arguments: |
| 133 | + function: Native binary or variadic function. |
| 134 | + fill: Fill value to use when nulls should *not* be ignored. |
| 135 | + variadic: If False (default), perform a binary reduction. |
| 136 | + Otherwise, assume we can unpack directly into `function`. |
| 137 | + """ |
| 138 | + |
108 | 139 | def func(node: FExpr[Any], frame: Frame, name: str) -> Expr | Scalar: |
109 | 140 | it = (self._expr.from_ir(e, frame, name).native for e in node.input) |
110 | 141 | if fill is not None: |
111 | 142 | it = (fn.fill_null(native, fill) for native in it) |
112 | | - result = reduce(fn_native, it) |
| 143 | + result = function(*it) if variadic else reduce(function, it) |
113 | 144 | if isinstance(result, pa.Scalar): |
114 | 145 | return self._scalar.from_native(result, name, self.version) |
115 | 146 | return self._expr.from_native(result, name, self.version) |
116 | 147 |
|
117 | 148 | return func |
118 | 149 |
|
119 | 150 | def coalesce(self, node: FExpr[F.Coalesce], frame: Frame, name: str) -> Expr | Scalar: |
120 | | - it = (self._expr.from_ir(e, frame, name).native for e in node.input) |
121 | | - result = pc.coalesce(*it) |
122 | | - if isinstance(result, pa.Scalar): |
123 | | - return self._scalar.from_native(result, name, self.version) |
124 | | - return self._expr.from_native(result, name, self.version) |
| 151 | + return self._horizontal(fn.coalesce, variadic=True)(node, frame, name) |
125 | 152 |
|
126 | 153 | def any_horizontal( |
127 | 154 | self, node: FExpr[AnyHorizontal], frame: Frame, name: str |
128 | 155 | ) -> Expr | Scalar: |
129 | 156 | fill = False if node.function.ignore_nulls else None |
130 | | - return self._horizontal_function(fn.or_, fill)(node, frame, name) |
| 157 | + return self._horizontal(fn.or_, fill)(node, frame, name) |
131 | 158 |
|
132 | 159 | def all_horizontal( |
133 | 160 | self, node: FExpr[AllHorizontal], frame: Frame, name: str |
134 | 161 | ) -> Expr | Scalar: |
135 | 162 | fill = True if node.function.ignore_nulls else None |
136 | | - return self._horizontal_function(fn.and_, fill)(node, frame, name) |
| 163 | + return self._horizontal(fn.and_, fill)(node, frame, name) |
137 | 164 |
|
138 | 165 | def sum_horizontal( |
139 | 166 | self, node: FExpr[F.SumHorizontal], frame: Frame, name: str |
140 | 167 | ) -> Expr | Scalar: |
141 | | - return self._horizontal_function(fn.add, fill=0)(node, frame, name) |
| 168 | + return self._horizontal(fn.add, fill=0)(node, frame, name) |
142 | 169 |
|
143 | 170 | def min_horizontal( |
144 | 171 | self, node: FExpr[F.MinHorizontal], frame: Frame, name: str |
145 | 172 | ) -> Expr | Scalar: |
146 | | - return self._horizontal_function(fn.min_horizontal)(node, frame, name) |
| 173 | + return self._horizontal(fn.min_horizontal, variadic=True)(node, frame, name) |
147 | 174 |
|
148 | 175 | def max_horizontal( |
149 | 176 | self, node: FExpr[F.MaxHorizontal], frame: Frame, name: str |
150 | 177 | ) -> Expr | Scalar: |
151 | | - return self._horizontal_function(fn.max_horizontal)(node, frame, name) |
| 178 | + return self._horizontal(fn.max_horizontal, variadic=True)(node, frame, name) |
152 | 179 |
|
153 | 180 | def mean_horizontal( |
154 | 181 | self, node: FExpr[F.MeanHorizontal], frame: Frame, name: str |
|
0 commit comments