11from typing import Any , Generic , Protocol , Self , TypeAlias , final , type_check_only
2- from typing_extensions import TypeAliasType , TypeVar
2+ from typing_extensions import TypeAliasType , TypeVar , TypeVarTuple , override
33
4- from ._shape import Shape , Shape0 , Shape0N , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
4+ from ._shape import AnyShape , Shape , Shape0 , Shape1 , Shape1N , Shape2 , Shape2N , Shape3 , Shape3N , Shape4 , Shape4N
55
66__all__ = [
7+ "HasInnerShape" ,
78 "HasRankGE" ,
89 "HasRankLE" ,
910 "Rank" ,
@@ -21,56 +22,71 @@ __all__ = [
2122
2223###
2324
24- _Shape00 : TypeAlias = Shape0
25- _Shape01 : TypeAlias = _Shape00 | Shape1
25+ _Shape01 : TypeAlias = Shape0 | Shape1
2626_Shape02 : TypeAlias = _Shape01 | Shape2
2727_Shape03 : TypeAlias = _Shape02 | Shape3
2828_Shape04 : TypeAlias = _Shape03 | Shape4
2929
3030###
3131
32- _UpperT = TypeVar ("_UpperT" , bound = Shape )
33- _LowerT = TypeVar ("_LowerT" , bound = Shape )
32+ # TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
33+ _UpperT = TypeVar ("_UpperT" , bound = Shape | Rank0 | Rank )
34+ _LowerT = TypeVar ("_LowerT" , bound = Shape | Rank0 | Rank )
3435_RankT = TypeVar ("_RankT" , bound = Shape , default = Any )
3536
37+ # TODO(jorenham): remove `| Rank0 | Rank` once python/mypy#19110 is fixed
38+ _RankLE : TypeAlias = _CanBroadcast [Any , _UpperT , _RankT ] | Shape0 | Rank0 | Rank
39+ # TODO(jorenham): remove `| Rank` once python/mypy#19110 is fixed
40+ _RankGE : TypeAlias = _CanBroadcast [_LowerT , Any , _RankT ] | _LowerT | Rank
41+
3642HasRankLE = TypeAliasType (
3743 "HasRankLE" ,
38- _HasShape [ Shape0 | _HasOwnShape [ _UpperT ] | _CanBroadcast [ Any , _UpperT , _RankT ]],
44+ _HasInnerShape [ _RankLE [ _UpperT , _RankT ]],
3945 type_params = (_UpperT , _RankT ),
4046)
4147HasRankGE = TypeAliasType (
4248 "HasRankGE" ,
43- _HasShape [ _LowerT | _CanBroadcast [_LowerT , Any , _RankT ]],
49+ _HasInnerShape [ _RankGE [_LowerT , _RankT ]],
4450 type_params = (_LowerT , _RankT ),
4551)
4652
47- ###
53+ _ShapeT = TypeVar ( "_ShapeT" , bound = Shape )
4854
49- _ShapeT_co = TypeVar ("_ShapeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast , covariant = True )
55+ # for unwrapping potential rank types as shape tuples
56+ HasInnerShape = TypeAliasType (
57+ "HasInnerShape" ,
58+ _HasInnerShape [_HasOwnShape [Any , _ShapeT ]],
59+ type_params = (_ShapeT ,),
60+ )
5061
51- @type_check_only
52- class _HasShape (Protocol [_ShapeT_co ]):
53- @property
54- def shape (self , / ) -> _ShapeT_co : ...
62+ ###
63+
64+ _ShapeLikeT_co = TypeVar ("_ShapeLikeT_co" , bound = Shape | _HasOwnShape | _CanBroadcast [Any , Any ], covariant = True )
5565
56- _FromT_contra = TypeVar ("_FromT_contra" , default = Any , contravariant = True )
57- _ToT_contra = TypeVar ("_ToT_contra" , bound = Shape , default = Any , contravariant = True )
66+ _FromT_contra = TypeVar ("_FromT_contra" , contravariant = True )
67+ _ToT_contra = TypeVar ("_ToT_contra" , bound = tuple [ Any , ...] , contravariant = True )
5868_EquivT_co = TypeVar ("_EquivT_co" , bound = Shape , default = Any , covariant = True )
5969
70+ # __broadcast__ is the type-check-only interface order of ranks
6071@final
6172@type_check_only
6273class _CanBroadcast (Protocol [_FromT_contra , _ToT_contra , _EquivT_co ]):
6374 def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> _EquivT_co : ...
6475
76+ # __inner_shape__ is similar to `shape`, but directly exposes the `Rank` type.
77+ @final
78+ @type_check_only
79+ class _HasInnerShape (Protocol [_ShapeLikeT_co ]):
80+ @property
81+ def __inner_shape__ (self , / ) -> _ShapeLikeT_co : ...
82+
83+ _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = tuple [Any , ...], default = Any , contravariant = True )
84+ _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
85+
6586# This double shape-type parameter is a sneaky way to annotate a doubly-bound nominal type range,
6687# e.g. `_HasOwnShape[Shape2N, Shape0N]` accepts `Shape2N`, `Shape1N`, and `Shape0N`, but
6788# rejects `Shape3N` and `Shape1`. Besides brevity, it also works around several mypy bugs that
6889# are related to "unions vs joins".
69-
70- _OwnShapeT_contra = TypeVar ("_OwnShapeT_contra" , bound = Shape , default = Any , contravariant = True )
71- _OwnShapeT_co = TypeVar ("_OwnShapeT_co" , bound = Shape , default = _OwnShapeT_contra , covariant = True )
72- _OwnShapeT = TypeVar ("_OwnShapeT" , bound = tuple [Any , ...], default = Any )
73-
7490@final
7591@type_check_only
7692class _HasOwnShape (Protocol [_OwnShapeT_contra , _OwnShapeT_co ]):
@@ -79,59 +95,74 @@ class _HasOwnShape(Protocol[_OwnShapeT_contra, _OwnShapeT_co]):
7995###
8096# TODO(jorenham): embed the array-like types, e.g. `Sequence[Sequence[T]]`
8197
82- @type_check_only
83- class _BaseRank (Generic [_FromT_contra , _OwnShapeT , _ToT_contra ]):
84- def __broadcast__ (self , from_ : _FromT_contra , to : _ToT_contra , / ) -> Self : ...
85- def __own_shape__ (self , shape : _OwnShapeT , / ) -> _OwnShapeT : ...
98+ _Ts = TypeVarTuple ("_Ts" ) # should only contain `int`s
8699
100+ # https://github.com/python/mypy/issues/19093
87101@type_check_only
88- class _BaseRankM (
89- _BaseRank [_FromT_contra | _HasOwnShape [_ToT_contra , Shape ], _OwnShapeT , _ToT_contra ],
90- Generic [_FromT_contra , _OwnShapeT , _ToT_contra ],
91- ): ...
102+ class BaseRank (tuple [* _Ts ], Generic [* _Ts ]):
103+ def __broadcast__ (self , from_ : tuple [* _Ts ], to : tuple [* _Ts ], / ) -> Self : ...
104+ def __own_shape__ (self , shape : tuple [* _Ts ], / ) -> tuple [* _Ts ]: ...
92105
93106@final
94107@type_check_only
95- class Rank0 (_BaseRankM [_Shape00 , Shape0 , Shape0N ], tuple [()]): ...
108+ class Rank0 (BaseRank [()]):
109+ @override
110+ def __broadcast__ (self , from_ : Shape0 | _HasOwnShape [Shape , Any ], to : Shape , / ) -> Self : ...
96111
97112@final
98113@type_check_only
99- class Rank1 (_BaseRankM [_Shape01 , Shape1 , Shape1N ], tuple [int ]): ...
114+ class Rank1 (BaseRank [int ]):
115+ @override
116+ def __broadcast__ (self , from_ : _Shape01 | _HasOwnShape [Shape1N , Any ], to : Shape1N , / ) -> Self : ...
100117
101118@final
102119@type_check_only
103- class Rank2 (_BaseRankM [_Shape02 , Shape2 , Shape2N ], tuple [int , int ]): ...
120+ class Rank2 (BaseRank [int , int ]):
121+ @override
122+ def __broadcast__ (self , from_ : _Shape02 | _HasOwnShape [Shape2N , Any ], to : Shape2N , / ) -> Self : ...
104123
105124@final
106125@type_check_only
107- class Rank3 (_BaseRankM [_Shape03 , Shape3 , Shape3N ], tuple [int , int , int ]): ...
126+ class Rank3 (BaseRank [int , int , int ]):
127+ @override
128+ def __broadcast__ (self , from_ : _Shape03 | _HasOwnShape [Shape3N , Any ], to : Shape3N , / ) -> Self : ...
108129
109130@final
110131@type_check_only
111- class Rank4 (_BaseRankM [_Shape04 , Shape4 , Shape4N ], tuple [int , int , int , int ]): ...
132+ class Rank4 (BaseRank [int , int , int , int ]):
133+ @override
134+ def __broadcast__ (self , from_ : _Shape04 | _HasOwnShape [Shape4N , Any ], to : Shape4N , / ) -> Self : ...
112135
113- # this emulates `AnyOf`, rather than a `Union`.
114- @type_check_only
115- class _BaseRankMToN (_BaseRank [Shape0N , _OwnShapeT , _OwnShapeT ], Generic [_OwnShapeT ]): ...
136+ # these emulates `AnyOf` (gradual union), rather than a `Union`.
116137
117138@final
118139@type_check_only
119- class Rank (_BaseRankMToN [Shape0N ], tuple [int , ...]): ...
140+ class Rank (BaseRank [* tuple [int , ...]]):
141+ @override
142+ def __broadcast__ (self , from_ : AnyShape , to : tuple [* _Ts ], / ) -> Self : ...
120143
121144@final
122145@type_check_only
123- class Rank1N (_BaseRankMToN [Shape1N ], tuple [int , * tuple [int , ...]]): ...
146+ class Rank1N (BaseRank [int , * tuple [int , ...]]):
147+ @override
148+ def __broadcast__ (self , from_ : AnyShape , to : Shape1N , / ) -> Self : ...
124149
125150@final
126151@type_check_only
127- class Rank2N (_BaseRankMToN [Shape2N ], tuple [int , int , * tuple [int , ...]]): ...
152+ class Rank2N (BaseRank [int , int , * tuple [int , ...]]):
153+ @override
154+ def __broadcast__ (self , from_ : AnyShape , to : Shape2N , / ) -> Self : ...
128155
129156@final
130157@type_check_only
131- class Rank3N (_BaseRankMToN [Shape3N ], tuple [int , int , int , * tuple [int , ...]]): ...
158+ class Rank3N (BaseRank [int , int , int , * tuple [int , ...]]):
159+ @override
160+ def __broadcast__ (self , from_ : AnyShape , to : Shape3N , / ) -> Self : ...
132161
133162@final
134163@type_check_only
135- class Rank4N (_BaseRankMToN [Shape4N ], tuple [int , int , int , int , * tuple [int , ...]]): ...
164+ class Rank4N (BaseRank [int , int , int , int , * tuple [int , ...]]):
165+ @override
166+ def __broadcast__ (self , from_ : AnyShape , to : Shape4N , / ) -> Self : ...
136167
137168Rank0N : TypeAlias = Rank
0 commit comments