3939from functools import partial , wraps
4040from typing import TYPE_CHECKING , TypeVar , cast
4141from unittest .mock import Mock , patch
42+ from importlib .metadata import version
4243
4344import attr
4445import pytest
8990 from collections .abc import Awaitable , Callable
9091 from wsproto .events import Event
9192
92- from typing_extensions import ParamSpec
93+ from typing_extensions import ParamSpec , TypeAlias
9394 PS = ParamSpec ("PS" )
9495
96+ StapledMemoryStream : TypeAlias = trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
97+
9598WS_PROTO_VERSION = tuple (map (int , wsproto .__version__ .split ('.' )))
9699
97100HOST = '127.0.0.1'
@@ -116,6 +119,8 @@ async def echo_server(nursery: trio.Nursery) -> AsyncGenerator[WebSocketServer,
116119 serve_fn = partial (serve_websocket , echo_request_handler , HOST , 0 ,
117120 ssl_context = None )
118121 server = await nursery .start (serve_fn )
122+ # Cast needed because currently `nursery.start` has typing issues
123+ # blocked by https://github.com/python/mypy/pull/17512
119124 yield cast (WebSocketServer , server )
120125
121126
@@ -147,37 +152,28 @@ def __init__(self, seconds: int) -> None:
147152 self ._seconds = seconds
148153
149154 def __call__ (self , fn : Callable [PS , Awaitable [T ]]) -> Callable [PS , Awaitable [T | None ]]:
155+ # Type of decorated function contains type `Any`
150156 @wraps (fn )
151- async def wrapper (* args : PS .args , ** kwargs : PS .kwargs ) -> T | None :
152- result : T | None = None
157+ async def wrapper ( # type: ignore[misc]
158+ * args : PS .args ,
159+ ** kwargs : PS .kwargs ,
160+ ) -> T :
153161 with trio .move_on_after (self ._seconds ) as cancel_scope :
154- result = await fn (* args , ** kwargs )
162+ return await fn (* args , ** kwargs )
155163 if cancel_scope .cancelled_caught :
156164 pytest .fail (f'Test runtime exceeded the maximum { self ._seconds } seconds' )
157- return result
165+ raise AssertionError ( "Should be unreachable" )
158166 return wrapper
159167
160168
161169@attr .s (hash = False , eq = False )
162- class MemoryListener (
163- trio .abc .Listener [
164- "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
165- ]
166- ):
170+ class MemoryListener (trio .abc .Listener ["StapledMemoryStream" ]):
167171 closed : bool = attr .ib (default = False )
168- accepted_streams : list [
169- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
170- ] = attr .ib (factory = list )
172+ accepted_streams : list [StapledMemoryStream ] = attr .ib (factory = list )
171173 queued_streams : tuple [
172- trio .MemorySendChannel [
173- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
174- ],
175- trio .MemoryReceiveChannel [
176- trio .StapledStream [trio .testing .MemorySendStream , trio .testing .MemoryReceiveStream ]
177- ],
178- ] = attr .ib (factory = lambda : trio .open_memory_channel [
179- "trio.StapledStream[trio.testing.MemorySendStream, trio.testing.MemoryReceiveStream]"
180- ](1 ))
174+ trio .MemorySendChannel [StapledMemoryStream ],
175+ trio .MemoryReceiveChannel [StapledMemoryStream ],
176+ ] = attr .ib (factory = lambda : trio .open_memory_channel ["StapledMemoryStream" ](1 ))
181177 accept_hook : Callable [[], Awaitable [object ]] | None = attr .ib (default = None )
182178
183179 async def connect (self ) -> trio .StapledStream [
@@ -385,8 +381,11 @@ async def test_ascii_encoded_path_is_ok(echo_server: WebSocketServer) -> None:
385381 assert conn .path == RESOURCE + '/' + path
386382
387383
384+ # Type ignore because @patch contains `Any`
388385@patch ('trio_websocket._impl.open_websocket' )
389- def test_client_open_url_options (open_websocket_mock : Mock ) -> None :
386+ def test_client_open_url_options ( # type: ignore[misc]
387+ open_websocket_mock : Mock ,
388+ ) -> None :
390389 """open_websocket_url() must pass its options on to open_websocket()"""
391390 port = 1234
392391 url = f'ws://{ HOST } :{ port } { RESOURCE } '
@@ -618,7 +617,7 @@ async def handler(request: WebSocketRequest) -> None:
618617 assert exc_info .value .__context__ is user_cancelled_context
619618
620619def _trio_default_non_strict_exception_groups () -> bool :
621- version = trio . __version__ # type: ignore[attr-defined]
620+ version = version ( " trio" )
622621 assert re .match (r'^0\.\d\d\.' , version ), "unexpected trio versioning scheme"
623622 return int (version [2 :4 ]) < 25
624623
0 commit comments