Skip to content

Commit 16ea30e

Browse files
committed
Change CountingStream to BoundedStream
1 parent cd7e934 commit 16ea30e

File tree

3 files changed

+31
-29
lines changed

3 files changed

+31
-29
lines changed

relay-server/src/endpoints/upload.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use crate::services::projects::cache::Project;
3131
use crate::services::upload::Error as ServiceError;
3232
use crate::services::upstream::UpstreamRequestError;
3333
use crate::utils::upload::SignedLocation;
34-
use crate::utils::{CountingStream, find_error_source, tus, upload};
34+
use crate::utils::{BoundedStream, find_error_source, tus, upload};
3535

3636
#[derive(Debug, thiserror::Error)]
3737
enum Error {
@@ -137,7 +137,11 @@ async fn handle(
137137
.into_data_stream()
138138
.map(|result| result.map_err(io::Error::other))
139139
.boxed();
140-
let stream = CountingStream::new(stream, upload_length);
140+
let (lower_bound, upper_bound) = match upload_length {
141+
None => (1, config.max_upload_size()),
142+
Some(u) => (u, u),
143+
};
144+
let stream = BoundedStream::new(stream, lower_bound, upper_bound);
141145
let byte_counter = stream.byte_counter();
142146

143147
let location = upload::Sink::new(&state)

relay-server/src/utils/stream.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ use sync_wrapper::SyncWrapper;
1616
/// This type is `Sync` via [`SyncWrapper`], allowing it to be sent across thread boundaries
1717
/// as required by the upload service.
1818
#[derive(Debug)]
19-
pub struct CountingStream<S> {
19+
pub struct BoundedStream<S> {
20+
pub lower_bound: usize,
21+
pub upper_bound: usize,
2022
inner: Option<SyncWrapper<S>>,
21-
expected_length: Option<usize>,
2223
byte_counter: ByteCounter,
2324
}
2425

25-
/// A shared counter that can be read after the [`CountingStream`] has been moved.
26+
/// A shared counter that can be read after the [`BoundedStream`] has been moved.
2627
#[derive(Clone, Debug)]
2728
pub struct ByteCounter(Arc<AtomicUsize>);
2829

@@ -40,28 +41,24 @@ impl ByteCounter {
4041
}
4142
}
4243

43-
impl<S> CountingStream<S> {
44-
/// Creates a new [`CountingStream`] wrapping the given stream with the expected total length.
45-
pub fn new(stream: S, expected_length: Option<usize>) -> Self {
44+
impl<S> BoundedStream<S> {
45+
/// Creates a new [`BoundedStream`] wrapping the given stream with the expected total length.
46+
pub fn new(stream: S, lower_bound: usize, upper_bound: usize) -> Self {
4647
Self {
4748
inner: Some(SyncWrapper::new(stream)),
48-
expected_length,
49+
lower_bound,
50+
upper_bound,
4951
byte_counter: ByteCounter::new(),
5052
}
5153
}
5254

53-
/// Returns the expected total length of the stream.
54-
pub fn expected_length(&self) -> Option<usize> {
55-
self.expected_length
56-
}
57-
5855
/// Returns a shared handle to read the byte count after the stream is consumed.
5956
pub fn byte_counter(&self) -> ByteCounter {
6057
self.byte_counter.clone()
6158
}
6259
}
6360

64-
impl<S, E> Stream for CountingStream<S>
61+
impl<S, E> Stream for BoundedStream<S>
6562
where
6663
S: Stream<Item = Result<Bytes, E>> + Send + Unpin,
6764
E: Into<io::Error>,
@@ -78,15 +75,13 @@ where
7875
match inner.poll_next(cx) {
7976
Poll::Ready(Some(Ok(bytes))) => {
8077
let bytes_received = this.byte_counter.add(bytes.len());
81-
if let Some(expected_length) = this.expected_length
82-
&& bytes_received > expected_length
83-
{
78+
if bytes_received > this.upper_bound {
8479
this.inner = None;
8580
Poll::Ready(Some(Err(io::Error::new(
8681
io::ErrorKind::FileTooLarge,
8782
format!(
88-
"stream exceeded expected length: received {} > {}",
89-
bytes_received, expected_length
83+
"stream exceeded upper bound: received {} > {}",
84+
bytes_received, this.upper_bound
9085
),
9186
))))
9287
} else {
@@ -96,15 +91,13 @@ where
9691
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
9792
Poll::Ready(None) => {
9893
let bytes_received = this.byte_counter.get();
99-
if let Some(expected_length) = this.expected_length
100-
&& bytes_received < expected_length
101-
{
94+
if bytes_received < this.lower_bound {
10295
this.inner = None;
10396
Poll::Ready(Some(Err(io::Error::new(
10497
io::ErrorKind::UnexpectedEof,
10598
format!(
106-
"stream shorter than expected length: received {} < {}",
107-
bytes_received, expected_length
99+
"stream shorter than lower bound: received {} < {}",
100+
bytes_received, this.lower_bound
108101
),
109102
))))
110103
} else {

relay-server/src/utils/upload.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::services::upload::{Error as ServiceError, Upload};
2424
use crate::services::upstream::{
2525
SendRequest, UpstreamRelay, UpstreamRequest, UpstreamRequestError,
2626
};
27-
use crate::utils::{CountingStream, tus};
27+
use crate::utils::{BoundedStream, tus};
2828

2929
/// An error that occurs during upload.
3030
#[derive(Debug, thiserror::Error)]
@@ -57,7 +57,7 @@ pub struct Stream {
5757
/// The organization and project the stream belongs to.
5858
pub scoping: Scoping,
5959
/// The body to be uploaded to objectstore, with length validation.
60-
pub stream: CountingStream<BoxStream<'static, std::io::Result<Bytes>>>,
60+
pub stream: BoundedStream<BoxStream<'static, std::io::Result<Bytes>>>,
6161
}
6262

6363
/// A dispatcher for uploading large files.
@@ -204,7 +204,7 @@ impl SignedLocation {
204204
/// An upstream request made to the `/upload` endpoint.
205205
struct UploadRequest {
206206
scoping: Scoping,
207-
body: Option<CountingStream<BoxStream<'static, std::io::Result<Bytes>>>>,
207+
body: Option<BoundedStream<BoxStream<'static, std::io::Result<Bytes>>>>,
208208
sender: oneshot::Sender<Result<Response, UpstreamRequestError>>,
209209
}
210210

@@ -280,7 +280,12 @@ impl UpstreamRequest for UploadRequest {
280280

281281
let project_key = self.scoping.project_key;
282282
builder.header("X-Sentry-Auth", format!("Sentry sentry_key={project_key}"));
283-
for (key, value) in tus::request_headers(body.expected_length()) {
283+
let upload_length = if body.lower_bound == body.upper_bound {
284+
Some(body.lower_bound)
285+
} else {
286+
None
287+
};
288+
for (key, value) in tus::request_headers(upload_length) {
284289
let Some(key) = key else { continue };
285290
builder.header(key, value);
286291
}

0 commit comments

Comments
 (0)