Skip to content

Commit cd7b23f

Browse files
committed
Fix: connection_acquisition_timeout now covers TLS handshake
1 parent d10d2cd commit cd7b23f

File tree

3 files changed

+85
-73
lines changed

3 files changed

+85
-73
lines changed

src/neo4j/_async/io/_bolt_socket.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,11 @@ async def connect(
327327
s = None
328328
try:
329329
s = await cls._connect_secure(
330-
resolved_address, tcp_timeout, keep_alive, ssl_context
330+
resolved_address,
331+
tcp_timeout,
332+
deadline,
333+
keep_alive,
334+
ssl_context,
331335
)
332336
agreed_version = await s._handshake(resolved_address, deadline)
333337
return s, agreed_version

src/neo4j/_async_compat/network/_bolt_socket.py

Lines changed: 75 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,19 @@ def kill(self):
205205

206206
@classmethod
207207
async def _connect_secure(
208-
cls, resolved_address, timeout, keep_alive, ssl_context
208+
cls,
209+
resolved_address: Address,
210+
timeout: float | None,
211+
deadline: Deadline,
212+
keep_alive: bool,
213+
ssl_context: SSLContext | None,
209214
) -> t.Self:
210215
"""
211216
Connect to the address and return the socket.
212217
213218
:param resolved_address:
214219
:param timeout: seconds
220+
:param deadline: deadline for the whole operation
215221
:param keep_alive: True or False
216222
:param ssl_context: SSLContext or None
217223
@@ -223,24 +229,35 @@ async def _connect_secure(
223229

224230
# TODO: tomorrow me: fix this mess
225231
try:
226-
if len(resolved_address) == 2:
227-
s = socket(AF_INET)
228-
elif len(resolved_address) == 4:
229-
s = socket(AF_INET6)
230-
else:
231-
raise ValueError(f"Unsupported address {resolved_address!r}")
232-
s.setblocking(False) # asyncio + blocking = no-no!
233-
log.debug("[#0000] C: <OPEN> %s", resolved_address)
234-
await wait_for(loop.sock_connect(s, resolved_address), timeout)
235-
local_port = s.getsockname()[1]
232+
try:
233+
if len(resolved_address) == 2:
234+
s = socket(AF_INET)
235+
elif len(resolved_address) == 4:
236+
s = socket(AF_INET6)
237+
else:
238+
raise ValueError(
239+
f"Unsupported address {resolved_address!r}"
240+
)
241+
s.setblocking(False) # asyncio + blocking = no-no!
242+
log.debug("[#0000] C: <OPEN> %s", resolved_address)
243+
await wait_for(loop.sock_connect(s, resolved_address), timeout)
244+
local_port = s.getsockname()[1]
236245

237-
keep_alive = 1 if keep_alive else 0
238-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
246+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
247+
except asyncio.TimeoutError:
248+
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
249+
raise ServiceUnavailable(
250+
"Timed out trying to establish connection to "
251+
f"{resolved_address!r}"
252+
) from None
253+
except asyncio.CancelledError:
254+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
255+
raise
239256

240257
ssl_kwargs: dict[str, t.Any] = {}
241258

259+
hostname = resolved_address._host_name or None
242260
if ssl_context is not None:
243-
hostname = resolved_address._host_name or None
244261
sni_host = hostname if HAS_SNI and hostname else None
245262
ssl_kwargs.update(ssl=ssl_context, server_hostname=sni_host)
246263
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
@@ -250,9 +267,28 @@ async def _connect_secure(
250267
loop=loop,
251268
)
252269
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
253-
transport, _ = await loop.create_connection(
254-
lambda: protocol, sock=s, **ssl_kwargs
255-
)
270+
271+
try:
272+
transport, _ = await wait_for(
273+
loop.create_connection(
274+
lambda: protocol, sock=s, **ssl_kwargs
275+
),
276+
deadline.to_timeout(),
277+
)
278+
except (OSError, SSLError, CertificateError) as error:
279+
log.debug(
280+
"[#0000] S: <SECURE FAILURE> %s: %r",
281+
resolved_address,
282+
error,
283+
)
284+
raise BoltSecurityError(
285+
message="Failed to establish encrypted connection.",
286+
address=(hostname, local_port),
287+
) from error
288+
except asyncio.CancelledError:
289+
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
290+
raise
291+
256292
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
257293

258294
if ssl_context is not None:
@@ -265,39 +301,8 @@ async def _connect_secure(
265301
raise BoltProtocolError(
266302
"When using an encrypted socket, the server should "
267303
"always provide a certificate",
268-
address=(resolved_address._host_name, local_port),
304+
address=(hostname, local_port),
269305
)
270-
271-
return cls(reader, protocol, writer)
272-
273-
except asyncio.TimeoutError:
274-
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
275-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
276-
if s:
277-
cls._kill_raw_socket(s)
278-
raise ServiceUnavailable(
279-
"Timed out trying to establish connection to "
280-
f"{resolved_address!r}"
281-
) from None
282-
except asyncio.CancelledError:
283-
log.debug("[#0000] S: <CANCELLED> %s", resolved_address)
284-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
285-
if s:
286-
cls._kill_raw_socket(s)
287-
raise
288-
except (SSLError, CertificateError) as error:
289-
log.debug(
290-
"[#0000] S: <SECURE FAILURE> %s: %s",
291-
resolved_address,
292-
error,
293-
)
294-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
295-
if s:
296-
cls._kill_raw_socket(s)
297-
raise BoltSecurityError(
298-
message="Failed to establish encrypted connection.",
299-
address=(resolved_address._host_name, local_port),
300-
) from error
301306
except Exception as error:
302307
log.debug(
303308
"[#0000] S: <ERROR> %s %s",
@@ -314,6 +319,8 @@ async def _connect_secure(
314319
) from error
315320
raise
316321

322+
return cls(reader, protocol, writer)
323+
317324
@abc.abstractmethod
318325
async def _handshake(
319326
self,
@@ -463,13 +470,19 @@ def kill(self):
463470

464471
@classmethod
465472
def _connect_secure(
466-
cls, resolved_address, timeout, keep_alive, ssl_context
473+
cls,
474+
resolved_address: Address,
475+
timeout: float | None,
476+
deadline: Deadline,
477+
keep_alive: bool,
478+
ssl_context: SSLContext | None,
467479
):
468480
"""
469481
Connect to the address and return the socket.
470482
471483
:param resolved_address:
472484
:param timeout: seconds
485+
:param deadline: deadline for the whole operation
473486
:param keep_alive: True or False
474487
:returns: socket object
475488
"""
@@ -497,26 +510,14 @@ def _connect_secure(
497510
log.debug("[#0000] C: <OPEN> %s", resolved_address)
498511
s.connect(resolved_address)
499512
s.settimeout(t)
500-
keep_alive = 1 if keep_alive else 0
501-
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive)
513+
s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1 if keep_alive else 0)
502514
except TimeoutError:
503515
log.debug("[#0000] S: <TIMEOUT> %s", resolved_address)
504-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
505-
if s:
506-
cls._kill_raw_socket(s)
507516
raise ServiceUnavailable(
508517
"Timed out trying to establish connection to "
509518
f"{resolved_address!r}"
510519
) from None
511520
except Exception as error:
512-
log.debug(
513-
"[#0000] S: <ERROR> %s %s",
514-
type(error).__name__,
515-
" ".join(map(repr, error.args)),
516-
)
517-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
518-
if s:
519-
cls._kill_raw_socket(s)
520521
if isinstance(error, OSError):
521522
raise ServiceUnavailable(
522523
"Failed to establish connection to "
@@ -531,16 +532,17 @@ def _connect_secure(
531532
sni_host = hostname if HAS_SNI and hostname else None
532533
log.debug("[#%04X] C: <SECURE> %s", local_port, hostname)
533534
try:
535+
t = s.gettimeout()
536+
if timeout:
537+
s.settimeout(deadline.to_timeout())
534538
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
539+
s.settimeout(t)
535540
except (OSError, SSLError, CertificateError) as cause:
536541
log.debug(
537-
"[#0000] S: <SECURE FAILURE> %s: %s",
542+
"[#0000] S: <SECURE FAILURE> %s: %r",
538543
resolved_address,
539544
cause,
540545
)
541-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
542-
if s:
543-
cls._kill_raw_socket(s)
544546
raise BoltSecurityError(
545547
message="Failed to establish encrypted connection.",
546548
address=(hostname, local_port),
@@ -554,15 +556,17 @@ def _connect_secure(
554556
"[#0000] S: <SECURE FAILURE> %s: no certificate",
555557
resolved_address,
556558
)
557-
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
558-
if s:
559-
cls._kill_raw_socket(s)
560559
raise BoltProtocolError(
561560
"When using an encrypted socket, the server should"
562561
"always provide a certificate",
563562
address=(hostname, local_port),
564563
)
565-
except Exception:
564+
except Exception as error:
565+
log.debug(
566+
"[#0000] S: <ERROR> %s %s",
567+
type(error).__name__,
568+
" ".join(map(repr, error.args)),
569+
)
566570
if s is not None:
567571
log.debug("[#0000] C: <CLOSE> %s", resolved_address)
568572
cls._kill_raw_socket(s)

src/neo4j/_sync/io/_bolt_socket.py

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)