Skip to content

Commit 2702c9f

Browse files
committed
Convert websocket to niquests
1 parent f44bf3e commit 2702c9f

File tree

2 files changed

+69
-31
lines changed

2 files changed

+69
-31
lines changed

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
aiohttp>=3.9.1,<4
1+
aiohttp>=3.9.1,<4
2+
niquests[ws]>=3.9.0,<4
3+
# niquests[speedups,ws]>=3.9.0,<4

twitchio/eventsub/websockets.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626

2727
import asyncio
2828
import datetime
29+
import inspect
2930
import logging
30-
from typing import TYPE_CHECKING, cast
31+
from typing import TYPE_CHECKING, Any, cast
3132

32-
import aiohttp
33+
import niquests
3334

3435
from ..backoff import Backoff
3536
from ..exceptions import HTTPException, WebsocketConnectionException
@@ -64,6 +65,10 @@
6465
WSS: str = "wss://eventsub.wss.twitch.tv/ws"
6566

6667

68+
async def _resolve_awaitable(value: Any) -> Any:
69+
return await value if inspect.isawaitable(value) else value
70+
71+
6772
class WebsocketClosed:
6873
# TODO: Docs...
6974

@@ -95,6 +100,8 @@ class Websocket:
95100
"_session_id",
96101
"_shard_id",
97102
"_socket",
103+
"_socket_response",
104+
"_socket_session",
98105
"_subscriptions",
99106
"_token_for",
100107
)
@@ -116,7 +123,9 @@ def __init__(
116123

117124
self._session_id: str | None = None
118125

119-
self._socket: aiohttp.ClientWebSocketResponse | None = None
126+
self._socket: Any = None
127+
self._socket_response: niquests.Response | niquests.AsyncResponse | None = None
128+
self._socket_session: niquests.AsyncSession | None = None
120129
self._listen_task: asyncio.Task[None] | None = None
121130

122131
self._ready: asyncio.Event = asyncio.Event()
@@ -143,12 +152,12 @@ def __init__(
143152

144153
self._connection_tasks: set[asyncio.Task[None]] = set()
145154

146-
msg = "Websocket %s is being used without a Client/Bot. Event dispatching is disabled for this websocket."
147155
if not client:
148156
if shard_id is not None:
149157
# TODO: Proper Exception...
150158
raise RuntimeError("...")
151159

160+
msg = "Websocket %s is being used without a Client/Bot. Event dispatching is disabled for this websocket."
152161
logger.warning(msg, self)
153162

154163
self._log_name = "EventSub Websocket" if self._shard_id is None else "Conduit Websocket"
@@ -173,10 +182,7 @@ def session_id(self) -> str | None:
173182

174183
@property
175184
def can_subscribe(self) -> bool:
176-
if self._shard_id is not None:
177-
return False
178-
179-
return self.subscription_count < 300
185+
return False if self._shard_id is not None else self.subscription_count < 300
180186

181187
@property
182188
def subscription_count(self) -> int:
@@ -197,11 +203,26 @@ async def connect(self, *, url: str | None = None, reconnect: bool = False, fail
197203
return await self.close()
198204

199205
while True:
206+
session: niquests.AsyncSession = niquests.AsyncSession()
200207
try:
201-
async with aiohttp.ClientSession() as session:
202-
new = await session.ws_connect(url_, heartbeat=self._heartbeat)
203-
session.detach()
208+
response: niquests.Response | niquests.AsyncResponse = await session.get(url_, timeout=self._heartbeat)
209+
status = response.status_code or 0
210+
new: Any = getattr(response, "extension", None)
211+
212+
if status != 101 or new is None:
213+
await _resolve_awaitable(response.close())
214+
await session.close()
215+
216+
raise WebsocketConnectionException(
217+
f'Failed to connect to {self._log_name} "{self}" with status {status}. '
218+
"Please attempt to reconnect or re-subscribe this eventsub connection."
219+
)
204220
except Exception as e:
221+
try:
222+
await session.close()
223+
except Exception:
224+
pass
225+
205226
logger.debug('Failed to connect to %s "%s>"": %s.', self._log_name, self, e)
206227

207228
if fail_once:
@@ -233,6 +254,8 @@ async def connect(self, *, url: str | None = None, reconnect: bool = False, fail
233254
await self.close(cleanup=False)
234255

235256
self._socket = new
257+
self._socket_response = response
258+
self._socket_session = session
236259

237260
if not self._listen_task:
238261
self._listen_task = asyncio.create_task(self._listen())
@@ -347,39 +370,38 @@ async def _listen(self) -> None:
347370

348371
while True:
349372
try:
350-
message: aiohttp.WSMessage = await self._socket.receive()
373+
message = await _resolve_awaitable(self._socket.next_payload())
351374
except Exception:
375+
if self._closing or self._closed:
376+
break
377+
352378
await self._create_connection_task()
353379
break
354380

355-
type_: aiohttp.WSMsgType = message.type
356-
if type_ in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
357-
logger.debug('Received close message [%s] on %s: "%s"', self._socket.close_code, self._log_name, self)
358-
359-
if self._socket.close_code == 4001:
360-
logger.critical(
361-
'%s "%s" attempted to send an outgoing message to Twitch. '
362-
"Twitch prohibits sending outgoing messages to the server, this will result in a disconnect. "
363-
"This websocket will NOT attempt to reconnect.",
364-
self._log_name,
365-
self,
366-
)
367-
return await self.close()
381+
if message is None:
382+
logger.debug('Received close message on %s: "%s"', self._log_name, self)
368383

369-
elif self._socket.close_code == 4003:
370-
return await self.close()
384+
if self._closing or self._closed:
385+
break
371386

372387
await self._create_connection_task()
373388
break
374389

375-
if type_ is not aiohttp.WSMsgType.TEXT:
390+
if isinstance(message, bytes):
391+
try:
392+
message = message.decode("utf-8")
393+
except Exception:
394+
logger.debug('Received undecodable bytes message from %s: "%s>"', self._log_name, self)
395+
continue
396+
397+
if not isinstance(message, str):
376398
logger.debug('Received unknown message from %s: "%s>"', self._log_name, self)
377399
continue
378400

379401
self._last_keepalive = datetime.datetime.now()
380402

381403
try:
382-
data: WebsocketMessages = cast("WebsocketMessages", _from_json(message.data))
404+
data: WebsocketMessages = cast("WebsocketMessages", _from_json(message))
383405
except Exception:
384406
logger.warning('Unable to parse JSON in %s: "%s"', self._log_name, self)
385407
continue
@@ -524,12 +546,26 @@ async def close(self, cleanup: bool = True, *, reassociate: bool = True) -> None
524546

525547
if self._socket:
526548
try:
527-
await self._socket.close()
549+
await _resolve_awaitable(self._socket.close())
550+
except Exception:
551+
pass
552+
553+
if self._socket_response:
554+
try:
555+
await _resolve_awaitable(self._socket_response.close())
556+
except Exception:
557+
pass
558+
559+
if self._socket_session:
560+
try:
561+
await self._socket_session.close()
528562
except Exception:
529563
pass
530564

531565
self._keep_alive_task = None
532566
self._socket = None
567+
self._socket_response = None
568+
self._socket_session = None
533569

534570
if self._listen_task:
535571
try:

0 commit comments

Comments
 (0)