Skip to content

Commit 60e5a9b

Browse files
authored
Merge pull request #7 from bennyz/limit-buffer-memory
fix OOM in fast networks
2 parents 0da515b + 764efde commit 60e5a9b

File tree

7 files changed

+532
-29
lines changed

7 files changed

+532
-29
lines changed

src/fls/byte_channel.rs

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
/// Byte-bounded channel wrapper for memory-safe streaming
2+
///
3+
/// Wraps `mpsc::channel` with a `tokio::sync::Semaphore` to bound total
4+
/// buffered bytes rather than item count. This prevents OOM when chunk
5+
/// sizes vary (e.g., reqwest delivering 64-256KB chunks on fast networks).
6+
use std::sync::Arc;
7+
use tokio::sync::{mpsc, Semaphore};
8+
9+
/// Trait for items that know their byte size.
10+
pub trait SizedItem {
11+
fn byte_size(&self) -> usize;
12+
}
13+
14+
impl SizedItem for bytes::Bytes {
15+
fn byte_size(&self) -> usize {
16+
self.len()
17+
}
18+
}
19+
20+
/// Sender half of a byte-bounded channel.
21+
///
22+
/// Acquires semaphore permits equal to `chunk.byte_size()` before sending,
23+
/// ensuring total buffered bytes never exceeds `max_bytes`.
24+
pub struct ByteBoundedSender<T: SizedItem> {
25+
inner: mpsc::Sender<T>,
26+
semaphore: Arc<Semaphore>,
27+
max_bytes: usize,
28+
}
29+
30+
impl<T: SizedItem> ByteBoundedSender<T> {
31+
/// Send an item, blocking (async) until enough byte budget is available.
32+
///
33+
/// Acquires `min(item.byte_size(), max_bytes)` permits so a single
34+
/// oversized chunk can still pass through without deadlocking.
35+
pub async fn send(&self, item: T) -> Result<(), mpsc::error::SendError<T>> {
36+
let permits_needed = item.byte_size().min(self.max_bytes);
37+
38+
let permits_needed_u32 = permits_needed as u32;
39+
40+
// acquire_many_owned returns OwnedSemaphorePermit which we intentionally
41+
// forget — the receiver side adds permits back after consuming the item.
42+
let permit = self
43+
.semaphore
44+
.acquire_many(permits_needed_u32)
45+
.await
46+
.expect("semaphore closed unexpectedly");
47+
permit.forget();
48+
49+
self.inner.send(item).await
50+
}
51+
}
52+
53+
/// Receiver half of a byte-bounded channel.
54+
///
55+
/// Returns semaphore permits after receiving each item, freeing byte budget
56+
/// for the sender.
57+
pub struct ByteBoundedReceiver<T: SizedItem> {
58+
inner: mpsc::Receiver<T>,
59+
semaphore: Arc<Semaphore>,
60+
max_bytes: usize,
61+
}
62+
63+
impl<T: SizedItem> ByteBoundedReceiver<T> {
64+
/// Receive an item asynchronously, releasing byte budget on success.
65+
pub async fn recv(&mut self) -> Option<T> {
66+
let item = self.inner.recv().await?;
67+
let to_release = item.byte_size().min(self.max_bytes);
68+
self.semaphore.add_permits(to_release);
69+
Some(item)
70+
}
71+
72+
/// Receive an item synchronously (for use in `spawn_blocking`),
73+
/// releasing byte budget on success.
74+
pub fn blocking_recv(&mut self) -> Option<T> {
75+
let item = self.inner.blocking_recv()?;
76+
let to_release = item.byte_size().min(self.max_bytes);
77+
self.semaphore.add_permits(to_release);
78+
Some(item)
79+
}
80+
}
81+
82+
/// Create a byte-bounded channel.
83+
///
84+
/// - `max_bytes`: maximum total bytes buffered at any time (must be ≤ u32::MAX)
85+
/// - `max_items`: underlying mpsc channel item capacity (prevents unbounded item queuing)
86+
///
87+
/// # Panics
88+
///
89+
/// Panics if `max_bytes` exceeds `u32::MAX` (4,294,967,295 bytes ≈ 4GB).
90+
/// This limit exists because the underlying semaphore uses u32 permit counts.
91+
pub fn byte_bounded_channel<T: SizedItem>(
92+
max_bytes: usize,
93+
max_items: usize,
94+
) -> (ByteBoundedSender<T>, ByteBoundedReceiver<T>) {
95+
// Guard against overflow in send() method's permits_needed as u32 cast
96+
if max_bytes > u32::MAX as usize {
97+
panic!(
98+
"max_bytes ({}) exceeds u32::MAX ({}). Maximum supported buffer size is ~4GB.",
99+
max_bytes,
100+
u32::MAX
101+
);
102+
}
103+
104+
let (tx, rx) = mpsc::channel::<T>(max_items);
105+
let semaphore = Arc::new(Semaphore::new(max_bytes));
106+
107+
let sender = ByteBoundedSender {
108+
inner: tx,
109+
semaphore: semaphore.clone(),
110+
max_bytes,
111+
};
112+
113+
let receiver = ByteBoundedReceiver {
114+
inner: rx,
115+
semaphore,
116+
max_bytes,
117+
};
118+
119+
(sender, receiver)
120+
}
121+
122+
#[cfg(test)]
123+
mod tests {
124+
use super::*;
125+
use bytes::Bytes;
126+
use std::time::Duration;
127+
use tokio::time::timeout;
128+
129+
#[tokio::test]
130+
async fn test_basic_send_receive() {
131+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(1024, 10);
132+
133+
let data = Bytes::from_static(b"hello");
134+
tx.send(data.clone()).await.unwrap();
135+
136+
let received = rx.recv().await.unwrap();
137+
assert_eq!(received, data);
138+
}
139+
140+
#[tokio::test]
141+
async fn test_byte_limit_enforcement() {
142+
// 100-byte limit, 5 item capacity
143+
let (tx, _rx) = byte_bounded_channel::<Bytes>(100, 5);
144+
145+
// Send 80 bytes (should succeed)
146+
let chunk1 = Bytes::from(vec![1u8; 80]);
147+
tx.send(chunk1).await.unwrap();
148+
149+
// Send 20 bytes (should succeed, total = 100)
150+
let chunk2 = Bytes::from(vec![2u8; 20]);
151+
tx.send(chunk2).await.unwrap();
152+
153+
// Try to send 1 more byte (should block)
154+
let chunk3 = Bytes::from(vec![3u8; 1]);
155+
let send_future = tx.send(chunk3);
156+
157+
// Should timeout because buffer is full
158+
assert!(timeout(Duration::from_millis(50), send_future)
159+
.await
160+
.is_err());
161+
}
162+
163+
#[tokio::test]
164+
async fn test_permits_released_after_recv() {
165+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(100, 5);
166+
167+
// Fill buffer to capacity
168+
let chunk1 = Bytes::from(vec![1u8; 60]);
169+
let chunk2 = Bytes::from(vec![2u8; 40]);
170+
tx.send(chunk1).await.unwrap();
171+
tx.send(chunk2).await.unwrap();
172+
173+
// Buffer should be full, next send should block
174+
let chunk3 = Bytes::from(vec![3u8; 1]);
175+
let send_future = tx.send(chunk3.clone());
176+
assert!(timeout(Duration::from_millis(50), send_future)
177+
.await
178+
.is_err());
179+
180+
// Consume one chunk, freeing 60 bytes
181+
let _received = rx.recv().await.unwrap();
182+
183+
// Now the blocked send should succeed
184+
let send_future = tx.send(chunk3);
185+
assert!(timeout(Duration::from_millis(50), send_future)
186+
.await
187+
.is_ok());
188+
}
189+
190+
#[tokio::test]
191+
async fn test_oversized_chunk_handling() {
192+
// 50-byte limit
193+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(50, 5);
194+
195+
// Send 100-byte chunk (larger than limit)
196+
let big_chunk = Bytes::from(vec![1u8; 100]);
197+
198+
// Should succeed (acquires min(100, 50) = 50 permits)
199+
tx.send(big_chunk.clone()).await.unwrap();
200+
201+
// Should be able to receive it
202+
let received = rx.recv().await.unwrap();
203+
assert_eq!(received, big_chunk);
204+
}
205+
206+
#[tokio::test]
207+
async fn test_multiple_chunk_sizes() {
208+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(1000, 100);
209+
210+
let chunks = vec![
211+
Bytes::from(vec![1u8; 100]), // Small
212+
Bytes::from(vec![2u8; 500]), // Medium
213+
Bytes::from(vec![3u8; 300]), // Large
214+
Bytes::from(vec![4u8; 50]), // Tiny
215+
];
216+
217+
// Send all chunks
218+
for chunk in &chunks {
219+
tx.send(chunk.clone()).await.unwrap();
220+
}
221+
222+
// Receive and verify
223+
for expected in &chunks {
224+
let received = rx.recv().await.unwrap();
225+
assert_eq!(received, *expected);
226+
}
227+
}
228+
229+
#[tokio::test]
230+
async fn test_channel_closure() {
231+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(100, 5);
232+
233+
tx.send(Bytes::from_static(b"data")).await.unwrap();
234+
drop(tx); // Close sender
235+
236+
// Should receive the sent data
237+
let received = rx.recv().await.unwrap();
238+
assert_eq!(received, Bytes::from_static(b"data"));
239+
240+
// Next recv should return None (channel closed)
241+
assert!(rx.recv().await.is_none());
242+
}
243+
244+
#[tokio::test]
245+
async fn test_blocking_recv() {
246+
let (tx, mut rx) = byte_bounded_channel::<Bytes>(100, 5);
247+
248+
// Test in spawn_blocking since blocking_recv is sync
249+
let handle = tokio::task::spawn_blocking(move || {
250+
// This should block until data is available
251+
rx.blocking_recv()
252+
});
253+
254+
// Give it a moment to start blocking
255+
tokio::time::sleep(Duration::from_millis(10)).await;
256+
257+
// Send data
258+
tx.send(Bytes::from_static(b"test")).await.unwrap();
259+
260+
// Should now unblock and return the data
261+
let result = handle.await.unwrap();
262+
assert_eq!(result.unwrap(), Bytes::from_static(b"test"));
263+
}
264+
265+
#[test]
266+
#[should_panic(expected = "max_bytes (4294967296) exceeds u32::MAX")]
267+
fn test_max_bytes_overflow_guard() {
268+
// Try to create a channel with max_bytes > u32::MAX
269+
let oversized = (u32::MAX as usize) + 1;
270+
let _ = byte_bounded_channel::<Bytes>(oversized, 100);
271+
}
272+
}

src/fls/from_url.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use tokio::sync::mpsc;
55
use tokio::task::JoinHandle;
66

77
use crate::fls::block_writer::AsyncBlockWriter;
8+
use crate::fls::byte_channel::byte_bounded_channel;
89
use crate::fls::decompress::{spawn_stderr_reader, start_decompressor_process};
910
use crate::fls::download_error::DownloadError;
1011
use crate::fls::error_handling::process_error_messages;
@@ -349,21 +350,18 @@ pub async fn flash_from_url(
349350

350351
use futures_util::StreamExt;
351352

352-
// Calculate buffer capacity (shared across all retry attempts)
353+
// Create byte-bounded download buffer (shared across all retry attempts)
353354
let buffer_size_mb = options.common.buffer_size_mb;
354-
// HTTP chunks from reqwest are typically 8-32 KB, not 64 KB
355-
// To ensure we get the full buffer size, use a conservative estimate
356-
let avg_chunk_size_kb = 16; // From common obvervation: 16kb
357-
let buffer_capacity = (buffer_size_mb * 1024) / avg_chunk_size_kb;
358-
let buffer_capacity = buffer_capacity.max(1000); // At least 1000 chunks
355+
let max_buffer_bytes = buffer_size_mb * 1024 * 1024;
359356

360357
println!(
361-
"Using download buffer: {} MB (capacity: {} chunks, ~{} KB per chunk)",
362-
buffer_size_mb, buffer_capacity, avg_chunk_size_kb
358+
"Using download buffer: {} MB (byte-bounded)",
359+
buffer_size_mb
363360
);
364361

365-
// Create persistent bounded channel for download buffering (lives across retries)
366-
let (buffer_tx, mut buffer_rx) = mpsc::channel::<bytes::Bytes>(buffer_capacity);
362+
// Create persistent byte-bounded channel for download buffering (lives across retries)
363+
// max_items=4096 prevents unbounded item queuing; byte budget is the real bound
364+
let (buffer_tx, mut buffer_rx) = byte_bounded_channel::<bytes::Bytes>(max_buffer_bytes, 4096);
367365

368366
// Channels for tracking bytes actually written to decompressor
369367
let (decompressor_written_progress_tx, mut decompressor_written_progress_rx) =

src/fls/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Module declarations
22
pub mod automotive;
33
mod block_writer;
4+
pub mod byte_channel;
45
pub(crate) mod compression;
56
mod decompress;
67
mod download_error;

0 commit comments

Comments
 (0)