Skip to content

Commit 334f5bc

Browse files
authored
Merge pull request #94 from etiennedm/smpclient-timeout-refactor
2 parents 76f09e0 + f9b98c1 commit 334f5bc

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

smpclient/__init__.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class SMPClient:
6363
Args:
6464
transport: the `SMPTransport` to use
6565
address: the address of the SMP server, see `smpclient.transport` for details
66+
timeout_s: the default timeout in seconds for SMP requests
6667
6768
Example:
6869
@@ -86,25 +87,30 @@ async def main():
8687
```
8788
"""
8889

89-
def __init__(self, transport: SMPTransport, address: str): # noqa: DOC301
90+
def __init__(
91+
self, transport: SMPTransport, address: str, timeout_s: float = 2.5
92+
): # noqa: DOC301
9093
self._transport: Final = transport
9194
self._address: Final = address
95+
self._timeout_s = timeout_s
9296

93-
async def connect(self, timeout_s: float = 5.0) -> None:
97+
async def connect(self, connect_timeout_s: float | None = None) -> None:
9498
"""Connect to the SMP server.
9599
96100
Args:
97-
timeout_s: the timeout for the connection attempt in seconds
101+
connect_timeout_s: the timeout for the connection attempt in seconds
98102
"""
99-
await self._transport.connect(self._address, timeout_s)
100-
await self._initialize()
103+
connect_timeout_s = connect_timeout_s if connect_timeout_s is not None else self._timeout_s
104+
105+
await self._transport.connect(self._address, connect_timeout_s)
106+
await self._initialize(self._timeout_s)
101107

102108
async def disconnect(self) -> None:
103109
"""Disconnect from the SMP server."""
104110
await self._transport.disconnect()
105111

106112
async def request(
107-
self, request: SMPRequest[TRep, TEr1, TEr2], timeout_s: float = 120.000
113+
self, request: SMPRequest[TRep, TEr1, TEr2], timeout_s: float | None = None
108114
) -> TRep | TEr1 | TEr2:
109115
"""Make an `SMPRequest` to the SMP server and return the Response or Error.
110116
@@ -159,6 +165,7 @@ async def request(
159165
```
160166
161167
"""
168+
timeout_s = timeout_s if timeout_s is not None else self._timeout_s
162169

163170
try:
164171
async with timeout(timeout_s):
@@ -198,8 +205,8 @@ async def upload(
198205
image: bytes,
199206
slot: int = 0,
200207
upgrade: bool = False,
201-
first_timeout_s: float = 40.000,
202-
subsequent_timeout_s: float = 2.500,
208+
first_timeout_s: float = 40.0,
209+
subsequent_timeout_s: float | None = None,
203210
use_sha: bool = True,
204211
) -> AsyncIterator[int]:
205212
"""Iteratively upload an `image` to `slot`, yielding the offset.
@@ -214,6 +221,8 @@ async def upload(
214221
[boot_write_img_confirmed()](https://docs.zephyrproject.org/apidoc/latest/group__mcuboot__api.html#ga95ccc9e1c7460fec16b9ce9ac8ad7a72)
215222
for this purpose.
216223
first_timeout_s: the timeout for the first `ImageUploadWrite` request
224+
which might take longer than subsequent requests (e.g. if a big
225+
chunk of flash memory has to be erased upfront).
217226
subsequent_timeout_s: the timeout for subsequent `ImageUploadWrite` requests
218227
use_sha: `True` to include the SHA256 hash of the image in the first
219228
packet.
@@ -228,6 +237,9 @@ async def upload(
228237
Raises:
229238
SMPUploadError: if the upload routine fails
230239
"""
240+
subsequent_timeout_s = (
241+
subsequent_timeout_s if subsequent_timeout_s is not None else self._timeout_s
242+
)
231243

232244
response = await self.request(
233245
self._maximize_image_upload_write_packet(
@@ -290,7 +302,7 @@ async def upload_file(
290302
self,
291303
file_data: bytes,
292304
file_path: str,
293-
timeout_s: float = 2.500,
305+
timeout_s: float | None = None,
294306
) -> AsyncIterator[int]:
295307
"""Iteratively upload a `file_data` to `file_path`, yielding the offset.
296308
@@ -305,6 +317,8 @@ async def upload_file(
305317
Raises:
306318
SMPUploadError: if the upload routine fails
307319
"""
320+
timeout_s = timeout_s if timeout_s is not None else self._timeout_s
321+
308322
response = await self.request(
309323
self._maximize_file_upload_packet(
310324
FileUpload(name=file_path, off=0, data=b"", len=len(file_data)),
@@ -342,7 +356,7 @@ async def upload_file(
342356
async def download_file(
343357
self,
344358
file_path: str,
345-
timeout_s: float = 2.500,
359+
timeout_s: float | None = None,
346360
) -> bytes:
347361
"""Download a file from the SMP server.
348362
@@ -356,6 +370,8 @@ async def download_file(
356370
Raises:
357371
SMPUploadError: if the download routine fails
358372
"""
373+
timeout_s = timeout_s if timeout_s is not None else self._timeout_s
374+
359375
response = await self.request(FileDownload(off=0, name=file_path), timeout_s=timeout_s)
360376
file_length = 0
361377

@@ -500,18 +516,18 @@ def _maximize_file_upload_packet(self, request: FileUpload, data: bytes) -> File
500516
len=request.len,
501517
)
502518

503-
async def _initialize(self) -> None:
519+
async def _initialize(self, timeout_s: float | None = None) -> None:
504520
"""Gather initialization information from the SMP server."""
521+
timeout_s = timeout_s if timeout_s is not None else self._timeout_s
505522

506523
try:
507-
async with timeout(2):
508-
mcumgr_parameters = await self.request(MCUMgrParametersRead())
509-
if success(mcumgr_parameters):
510-
logger.debug(f"MCUMgr parameters: {mcumgr_parameters}")
511-
self._transport.initialize(mcumgr_parameters.buf_size)
512-
elif error(mcumgr_parameters):
513-
logger.warning(f"Error reading MCUMgr parameters: {mcumgr_parameters}")
514-
else:
515-
assert_never(mcumgr_parameters)
516-
except asyncio.TimeoutError:
524+
mcumgr_parameters = await self.request(MCUMgrParametersRead(), timeout_s=timeout_s)
525+
if success(mcumgr_parameters):
526+
logger.debug(f"MCUMgr parameters: {mcumgr_parameters}")
527+
self._transport.initialize(mcumgr_parameters.buf_size)
528+
elif error(mcumgr_parameters):
529+
logger.warning(f"Error reading MCUMgr parameters: {mcumgr_parameters}")
530+
else:
531+
assert_never(mcumgr_parameters)
532+
except TimeoutError:
517533
logger.warning("Timeout waiting for MCUMgr parameters")

tests/test_smp_client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def test_constructor() -> None:
8585
@pytest.mark.asyncio
8686
async def test_connect() -> None:
8787
m = SMPMockTransport()
88-
s = SMPClient(m, "address")
88+
s = SMPClient(m, "address", 5.0)
8989
s._initialize = AsyncMock() # type: ignore
9090
await s.connect()
9191

9292
m.connect.assert_awaited_once_with("address", 5.0)
93-
s._initialize.assert_awaited_once_with()
93+
s._initialize.assert_awaited_once_with(5.0)
9494

9595

9696
@pytest.mark.asyncio
@@ -175,7 +175,7 @@ async def test_request() -> None:
175175
@pytest.mark.asyncio
176176
async def test_upload() -> None:
177177
m = SMPMockTransport()
178-
s = SMPClient(m, "address")
178+
s = SMPClient(m, "address", 2.5)
179179

180180
s.request = AsyncMock() # type: ignore
181181

@@ -239,7 +239,7 @@ async def test_upload() -> None:
239239
off=415,
240240
data=image[415 : 415 + 474],
241241
),
242-
timeout_s=2.500,
242+
timeout_s=2.5,
243243
)
244244

245245
# assert that upload() raises SMPUploadError
@@ -386,7 +386,7 @@ async def mock_request(
386386
@pytest.mark.asyncio
387387
async def test_upload_file() -> None:
388388
m = SMPMockTransport()
389-
s = SMPClient(m, "address")
389+
s = SMPClient(m, "address", 2.5)
390390

391391
s.request = AsyncMock() # type: ignore
392392

@@ -617,7 +617,7 @@ async def mock_request(
617617
@pytest.mark.asyncio
618618
async def test_download_file() -> None:
619619
m = SMPMockTransport()
620-
s = SMPClient(m, "address")
620+
s = SMPClient(m, "address", 2.5)
621621

622622
s.request = AsyncMock() # type: ignore
623623

0 commit comments

Comments
 (0)