2626
2727import asyncio
2828import datetime
29+ import inspect
2930import logging
30- from typing import TYPE_CHECKING , cast
31+ from typing import TYPE_CHECKING , Any , cast
3132
32- import aiohttp
33+ import niquests
3334
3435from ..backoff import Backoff
3536from ..exceptions import HTTPException , WebsocketConnectionException
6465WSS : str = "wss://eventsub.wss.twitch.tv/ws"
6566
6667
68+ async def _resolve_awaitable (value : Any ) -> Any :
69+ return await value if inspect .isawaitable (value ) else value
70+
71+
6772class WebsocketClosed :
6873 # TODO: Docs...
6974
@@ -95,6 +100,8 @@ class Websocket:
95100 "_session_id" ,
96101 "_shard_id" ,
97102 "_socket" ,
103+ "_socket_response" ,
104+ "_socket_session" ,
98105 "_subscriptions" ,
99106 "_token_for" ,
100107 )
@@ -116,7 +123,9 @@ def __init__(
116123
117124 self ._session_id : str | None = None
118125
119- self ._socket : aiohttp .ClientWebSocketResponse | None = None
126+ self ._socket : Any = None
127+ self ._socket_response : niquests .Response | niquests .AsyncResponse | None = None
128+ self ._socket_session : niquests .AsyncSession | None = None
120129 self ._listen_task : asyncio .Task [None ] | None = None
121130
122131 self ._ready : asyncio .Event = asyncio .Event ()
@@ -143,12 +152,12 @@ def __init__(
143152
144153 self ._connection_tasks : set [asyncio .Task [None ]] = set ()
145154
146- msg = "Websocket %s is being used without a Client/Bot. Event dispatching is disabled for this websocket."
147155 if not client :
148156 if shard_id is not None :
149157 # TODO: Proper Exception...
150158 raise RuntimeError ("..." )
151159
160+ msg = "Websocket %s is being used without a Client/Bot. Event dispatching is disabled for this websocket."
152161 logger .warning (msg , self )
153162
154163 self ._log_name = "EventSub Websocket" if self ._shard_id is None else "Conduit Websocket"
@@ -173,10 +182,7 @@ def session_id(self) -> str | None:
173182
174183 @property
175184 def can_subscribe (self ) -> bool :
176- if self ._shard_id is not None :
177- return False
178-
179- return self .subscription_count < 300
185+ return False if self ._shard_id is not None else self .subscription_count < 300
180186
181187 @property
182188 def subscription_count (self ) -> int :
@@ -197,11 +203,26 @@ async def connect(self, *, url: str | None = None, reconnect: bool = False, fail
197203 return await self .close ()
198204
199205 while True :
206+ session : niquests .AsyncSession = niquests .AsyncSession ()
200207 try :
201- async with aiohttp .ClientSession () as session :
202- new = await session .ws_connect (url_ , heartbeat = self ._heartbeat )
203- session .detach ()
208+ response : niquests .Response | niquests .AsyncResponse = await session .get (url_ , timeout = self ._heartbeat )
209+ status = response .status_code or 0
210+ new : Any = getattr (response , "extension" , None )
211+
212+ if status != 101 or new is None :
213+ await _resolve_awaitable (response .close ())
214+ await session .close ()
215+
216+ raise WebsocketConnectionException (
217+ f'Failed to connect to { self ._log_name } "{ self } " with status { status } . '
218+ "Please attempt to reconnect or re-subscribe this eventsub connection."
219+ )
204220 except Exception as e :
221+ try :
222+ await session .close ()
223+ except Exception :
224+ pass
225+
205226 logger .debug ('Failed to connect to %s "%s>"": %s.' , self ._log_name , self , e )
206227
207228 if fail_once :
@@ -233,6 +254,8 @@ async def connect(self, *, url: str | None = None, reconnect: bool = False, fail
233254 await self .close (cleanup = False )
234255
235256 self ._socket = new
257+ self ._socket_response = response
258+ self ._socket_session = session
236259
237260 if not self ._listen_task :
238261 self ._listen_task = asyncio .create_task (self ._listen ())
@@ -347,39 +370,38 @@ async def _listen(self) -> None:
347370
348371 while True :
349372 try :
350- message : aiohttp . WSMessage = await self ._socket .receive ( )
373+ message = await _resolve_awaitable ( self ._socket .next_payload () )
351374 except Exception :
375+ if self ._closing or self ._closed :
376+ break
377+
352378 await self ._create_connection_task ()
353379 break
354380
355- type_ : aiohttp .WSMsgType = message .type
356- if type_ in (aiohttp .WSMsgType .CLOSED , aiohttp .WSMsgType .CLOSE , aiohttp .WSMsgType .CLOSING ):
357- logger .debug ('Received close message [%s] on %s: "%s"' , self ._socket .close_code , self ._log_name , self )
358-
359- if self ._socket .close_code == 4001 :
360- logger .critical (
361- '%s "%s" attempted to send an outgoing message to Twitch. '
362- "Twitch prohibits sending outgoing messages to the server, this will result in a disconnect. "
363- "This websocket will NOT attempt to reconnect." ,
364- self ._log_name ,
365- self ,
366- )
367- return await self .close ()
381+ if message is None :
382+ logger .debug ('Received close message on %s: "%s"' , self ._log_name , self )
368383
369- elif self ._socket . close_code == 4003 :
370- return await self . close ()
384+ if self ._closing or self . _closed :
385+ break
371386
372387 await self ._create_connection_task ()
373388 break
374389
375- if type_ is not aiohttp .WSMsgType .TEXT :
390+ if isinstance (message , bytes ):
391+ try :
392+ message = message .decode ("utf-8" )
393+ except Exception :
394+ logger .debug ('Received undecodable bytes message from %s: "%s>"' , self ._log_name , self )
395+ continue
396+
397+ if not isinstance (message , str ):
376398 logger .debug ('Received unknown message from %s: "%s>"' , self ._log_name , self )
377399 continue
378400
379401 self ._last_keepalive = datetime .datetime .now ()
380402
381403 try :
382- data : WebsocketMessages = cast ("WebsocketMessages" , _from_json (message . data ))
404+ data : WebsocketMessages = cast ("WebsocketMessages" , _from_json (message ))
383405 except Exception :
384406 logger .warning ('Unable to parse JSON in %s: "%s"' , self ._log_name , self )
385407 continue
@@ -524,12 +546,26 @@ async def close(self, cleanup: bool = True, *, reassociate: bool = True) -> None
524546
525547 if self ._socket :
526548 try :
527- await self ._socket .close ()
549+ await _resolve_awaitable (self ._socket .close ())
550+ except Exception :
551+ pass
552+
553+ if self ._socket_response :
554+ try :
555+ await _resolve_awaitable (self ._socket_response .close ())
556+ except Exception :
557+ pass
558+
559+ if self ._socket_session :
560+ try :
561+ await self ._socket_session .close ()
528562 except Exception :
529563 pass
530564
531565 self ._keep_alive_task = None
532566 self ._socket = None
567+ self ._socket_response = None
568+ self ._socket_session = None
533569
534570 if self ._listen_task :
535571 try :
0 commit comments