diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ba6e7e214..9926f74f20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ - Populate gen_ai.response.model from gen_ai.request.model if not already set. ([#5654](https://github.com/getsentry/relay/pull/5654)) - Add support for Unix domain sockets for statsd metrics. ([#5668](https://github.com/getsentry/relay/pull/5668)) +**Internal** + +- Allow deferred lengths to the `/upload` endpoint when the sender is trusted. ([#5658](https://github.com/getsentry/relay/pull/5658)) + ## 26.2.1 **Bug Fixes**: diff --git a/relay-server/src/endpoints/upload.rs b/relay-server/src/endpoints/upload.rs index b1b6901b18..ec912d66fb 100644 --- a/relay-server/src/endpoints/upload.rs +++ b/relay-server/src/endpoints/upload.rs @@ -31,7 +31,7 @@ use crate::services::projects::cache::Project; use crate::services::upload::Error as ServiceError; use crate::services::upstream::UpstreamRequestError; use crate::utils::upload::SignedLocation; -use crate::utils::{ExactStream, find_error_source, tus, upload}; +use crate::utils::{BoundedStream, find_error_source, tus, upload}; #[derive(Debug, thiserror::Error)] enum Error { @@ -48,6 +48,7 @@ enum Error { impl IntoResponse for Error { fn into_response(self) -> Response { match self { + Error::Tus(tus::Error::DeferLengthNotAllowed) => StatusCode::FORBIDDEN, Error::Tus(_) => StatusCode::BAD_REQUEST, Error::Request(error) => return error.into_response(), Error::Upload(error) => match error { @@ -115,10 +116,11 @@ async fn handle( headers: HeaderMap, body: Body, ) -> axum::response::Result { - let upload_length = tus::validate_headers(&headers).map_err(Error::from)?; + let upload_length = + tus::validate_headers(&headers, meta.request_trust().is_trusted()).map_err(Error::from)?; let config = state.config(); - if upload_length > config.max_upload_size() { + if upload_length.is_some_and(|len| len > config.max_upload_size()) { return Err(StatusCode::PAYLOAD_TOO_LARGE.into()); } @@ -135,7 +137,12 @@ async fn handle( .into_data_stream() .map(|result| result.map_err(io::Error::other)) .boxed(); - let stream = ExactStream::new(stream, upload_length); + let (lower_bound, upper_bound) = match upload_length { + None => (1, config.max_upload_size()), + Some(u) => (u, u), + }; + let stream = BoundedStream::new(stream, lower_bound, upper_bound); + let byte_counter = stream.byte_counter(); let location = upload::Sink::new(&state) .upload(config, upload::Stream { scoping, stream }) @@ -148,7 +155,7 @@ async fn handle( let mut response = location.into_response(); response .headers_mut() - .insert(tus::UPLOAD_OFFSET, upload_length.into()); + .insert(tus::UPLOAD_OFFSET, byte_counter.get().into()); Ok(response) } @@ -160,14 +167,14 @@ async fn handle( async fn check_request( state: &ServiceState, meta: RequestMeta, - upload_length: usize, + upload_length: Option, project: Project<'_>, ) -> Result { let mut envelope = Envelope::from_request(None, meta); envelope.require_feature(Feature::UploadEndpoint); let mut item = Item::new(ItemType::Attachment); item.set_payload(ContentType::AttachmentRef, vec![]); - item.set_attachment_length(upload_length as u64); + item.set_attachment_length(upload_length.unwrap_or(1) as u64); envelope.add_item(item); let mut envelope = Managed::from_envelope(envelope, state.outcome_aggregator().clone()); let rate_limits = project diff --git a/relay-server/src/utils/stream.rs b/relay-server/src/utils/stream.rs index 7717e7be4e..642d97bfaf 100644 --- a/relay-server/src/utils/stream.rs +++ b/relay-server/src/utils/stream.rs @@ -1,12 +1,14 @@ use std::io; use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::{Context, Poll}; use bytes::Bytes; use futures::Stream; use sync_wrapper::SyncWrapper; -/// A streaming body that validates the total byte count against the announced length. +/// A streaming body that validates the total byte count against an expected length if provided. /// /// Returns an error if the stream provides more bytes than `expected_length` (checked per chunk) /// or fewer bytes than `expected_length` (checked when the stream ends). @@ -14,29 +16,49 @@ use sync_wrapper::SyncWrapper; /// This type is `Sync` via [`SyncWrapper`], allowing it to be sent across thread boundaries /// as required by the upload service. #[derive(Debug)] -pub struct ExactStream { +pub struct BoundedStream { + pub lower_bound: usize, + pub upper_bound: usize, inner: Option>, - expected_length: usize, - bytes_received: usize, + byte_counter: ByteCounter, } -impl ExactStream { - /// Creates a new `ExactStream` wrapping the given stream with the expected total length. - pub fn new(stream: S, expected_length: usize) -> Self { +/// A shared counter that can be read after the [`BoundedStream`] has been moved. +#[derive(Clone, Debug)] +pub struct ByteCounter(Arc); + +impl ByteCounter { + fn new() -> Self { + Self(Arc::new(AtomicUsize::new(0))) + } + + fn add(&self, n: usize) -> usize { + n + self.0.fetch_add(n, Ordering::Relaxed) + } + + pub fn get(&self) -> usize { + self.0.load(Ordering::Relaxed) + } +} + +impl BoundedStream { + /// Creates a new [`BoundedStream`] wrapping the given stream with the expected total length. + pub fn new(stream: S, lower_bound: usize, upper_bound: usize) -> Self { Self { inner: Some(SyncWrapper::new(stream)), - expected_length, - bytes_received: 0, + lower_bound, + upper_bound, + byte_counter: ByteCounter::new(), } } - /// Returns the expected total length of the stream. - pub fn expected_length(&self) -> usize { - self.expected_length + /// Returns a shared handle to read the byte count after the stream is consumed. + pub fn byte_counter(&self) -> ByteCounter { + self.byte_counter.clone() } } -impl Stream for ExactStream +impl Stream for BoundedStream where S: Stream> + Send + Unpin, E: Into, @@ -52,14 +74,14 @@ where match inner.poll_next(cx) { Poll::Ready(Some(Ok(bytes))) => { - this.bytes_received += bytes.len(); - if this.bytes_received > this.expected_length { + let bytes_received = this.byte_counter.add(bytes.len()); + if bytes_received > this.upper_bound { this.inner = None; Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::FileTooLarge, format!( - "stream exceeded expected length: received {} > {}", - this.bytes_received, this.expected_length + "stream exceeded upper bound: received {} > {}", + bytes_received, this.upper_bound ), )))) } else { @@ -68,13 +90,14 @@ where } Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), Poll::Ready(None) => { - if this.bytes_received < this.expected_length { + let bytes_received = this.byte_counter.get(); + if bytes_received < this.lower_bound { this.inner = None; Poll::Ready(Some(Err(io::Error::new( io::ErrorKind::UnexpectedEof, format!( - "stream shorter than expected length: received {} < {}", - this.bytes_received, this.expected_length + "stream shorter than lower bound: received {} < {}", + bytes_received, this.lower_bound ), )))) } else { diff --git a/relay-server/src/utils/tus.rs b/relay-server/src/utils/tus.rs index fec1ff755f..69211b9590 100644 --- a/relay-server/src/utils/tus.rs +++ b/relay-server/src/utils/tus.rs @@ -5,6 +5,8 @@ //! //! Reference: +use std::str::FromStr; + use axum::http::HeaderMap; use http::HeaderValue; @@ -13,9 +15,12 @@ pub enum Error { /// The TUS version is missing or does not match the server's version. #[error("Version Mismatch")] Version, - /// The `Upload-Length` header is missing or cannot be parsed. + /// The `Upload-Length` and `Upload-Defer-Length` headers are both missing, incorrect, or cannot be parsed. #[error("Invalid Upload-Length")] UploadLength, + /// The `Upload-Defer-Length` header is not allowed for external/untrusted requests. + #[error("Upload-Defer-Length not allowed")] + DeferLengthNotAllowed, /// The `Content-Type` header is not what TUS expects. #[error("Invalid Content-Type")] ContentType, @@ -41,6 +46,11 @@ pub const TUS_VERSION: HeaderValue = HeaderValue::from_static("1.0.0"); /// See . pub const UPLOAD_LENGTH: &str = "Upload-Length"; +/// TUS protocol header for the deferred upload length. +/// +/// See . +pub const UPLOAD_DEFER_LENGTH: &str = "Upload-Defer-Length"; + /// TUS protocol header for the current upload offset. /// /// See . @@ -51,7 +61,10 @@ pub const EXPECTED_CONTENT_TYPE: HeaderValue = HeaderValue::from_static("application/offset+octet-stream"); /// Validates TUS protocol headers and returns the expected upload length. -pub fn validate_headers(headers: &HeaderMap) -> Result { +pub fn validate_headers( + headers: &HeaderMap, + allow_defer_length: bool, +) -> Result, Error> { let tus_version = headers.get(TUS_RESUMABLE); if tus_version != Some(&TUS_VERSION) { return Err(Error::Version); @@ -62,27 +75,30 @@ pub fn validate_headers(headers: &HeaderMap) -> Result { return Err(Error::ContentType); } - let upload_length = headers - .get(UPLOAD_LENGTH) - .ok_or(Error::UploadLength)? - .to_str() - .map_err(|_| Error::UploadLength)? - .parse() - .map_err(|_| Error::UploadLength)?; - - Ok(upload_length) + let upload_length: Option = parse_header(headers, UPLOAD_LENGTH); + let upload_defer_length: Option = parse_header(headers, UPLOAD_DEFER_LENGTH); + + // Exactly one of Upload-Length and Upload-Defer-Length must be present. + // Upload-Defer-Length is only accepted if its value is 1 (as demanded by the TUS protocol) + // and `allow_defer_length` is true (i.e. the sender is trusted/internal). + match (upload_length, upload_defer_length, allow_defer_length) { + (Some(u), None, _) => Ok(Some(u)), + (None, Some(1), true) => Ok(None), + (None, Some(1), false) => Err(Error::DeferLengthNotAllowed), + _ => Err(Error::UploadLength), + } } /// Prepares the required TUS request headers for upstream requests. -pub fn request_headers(upload_length: usize) -> HeaderMap { +pub fn request_headers(upload_length: Option) -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert(TUS_RESUMABLE, TUS_VERSION); headers.insert(http::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); - headers.insert( - UPLOAD_LENGTH, - HeaderValue::from_str(&upload_length.to_string()) - .expect("string from usize should always be a valid header"), - ); + if let Some(upload_length) = upload_length { + headers.insert(UPLOAD_LENGTH, HeaderValue::from(upload_length)); + } else { + headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from(1)); + } headers } @@ -94,6 +110,10 @@ pub fn response_headers() -> HeaderMap { headers } +fn parse_header(headers: &HeaderMap, header_name: &str) -> Option { + headers.get(header_name)?.to_str().ok()?.parse().ok() +} + #[cfg(test)] mod tests { use http::HeaderValue; @@ -103,7 +123,7 @@ mod tests { #[test] fn test_validate_tus_headers_missing_version() { let headers = HeaderMap::new(); - let result = validate_headers(&headers); + let result = validate_headers(&headers, false); assert!(matches!(result, Err(Error::Version))); } @@ -112,7 +132,7 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024")); - let result = validate_headers(&headers); + let result = validate_headers(&headers, false); assert!(matches!(result, Err(Error::ContentType))); } @@ -121,7 +141,7 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); - let result = validate_headers(&headers); + let result = validate_headers(&headers, false); assert!(matches!(result, Err(Error::UploadLength))); } @@ -131,8 +151,8 @@ mod tests { headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024")); - let result = validate_headers(&headers); - assert_eq!(result.unwrap(), 1024); + let result = validate_headers(&headers, false); + assert_eq!(result.unwrap().unwrap(), 1024); } #[test] @@ -140,7 +160,37 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert(TUS_RESUMABLE, HeaderValue::from_static("0.2.0")); headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024")); - let result = validate_headers(&headers); + let result = validate_headers(&headers, false); assert!(matches!(result, Err(Error::Version))); } + + #[test] + fn test_validate_tus_headers_valid_deferred_length_from_trusted_source() { + let mut headers = HeaderMap::new(); + headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); + headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); + headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("1")); + let result = validate_headers(&headers, true); + assert!(matches!(result, Ok(None))); + } + + #[test] + fn test_validate_tus_headers_valid_deferred_length_from_untrusted_source() { + let mut headers = HeaderMap::new(); + headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); + headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); + headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("1")); + let result = validate_headers(&headers, false); + assert!(matches!(result, Err(Error::DeferLengthNotAllowed))); + } + + #[test] + fn test_validate_tus_headers_invalid_deferred_length() { + let mut headers = HeaderMap::new(); + headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0")); + headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE); + headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("2")); + let result = validate_headers(&headers, true); + assert!(matches!(result, Err(Error::UploadLength))); + } } diff --git a/relay-server/src/utils/upload.rs b/relay-server/src/utils/upload.rs index 2fd938125d..abbf419d48 100644 --- a/relay-server/src/utils/upload.rs +++ b/relay-server/src/utils/upload.rs @@ -24,7 +24,7 @@ use crate::services::upload::{Error as ServiceError, Upload}; use crate::services::upstream::{ SendRequest, UpstreamRelay, UpstreamRequest, UpstreamRequestError, }; -use crate::utils::{ExactStream, tus}; +use crate::utils::{BoundedStream, tus}; /// An error that occurs during upload. #[derive(Debug, thiserror::Error)] @@ -57,7 +57,7 @@ pub struct Stream { /// The organization and project the stream belongs to. pub scoping: Scoping, /// The body to be uploaded to objectstore, with length validation. - pub stream: ExactStream>>, + pub stream: BoundedStream>>, } /// A dispatcher for uploading large files. @@ -92,12 +92,13 @@ impl Sink { #[cfg(feature = "processing")] Sink::Upload(addr) => { let project_id = stream.scoping.project_id; - let length = stream.stream.expected_length(); + let byte_counter = stream.stream.byte_counter(); let key = addr .send(stream) .await .map_err(|_send_error| Error::ServiceUnavailable)?? .into_inner(); + let length = byte_counter.get(); Location { project_id, @@ -202,7 +203,7 @@ impl SignedLocation { /// An upstream request made to the `/upload` endpoint. struct UploadRequest { scoping: Scoping, - body: Option>>>, + body: Option>>>, sender: oneshot::Sender>, } @@ -278,7 +279,8 @@ impl UpstreamRequest for UploadRequest { let project_key = self.scoping.project_key; builder.header("X-Sentry-Auth", format!("Sentry sentry_key={project_key}")); - for (key, value) in tus::request_headers(body.expected_length()) { + let upload_length = (body.lower_bound == body.upper_bound).then_some(body.lower_bound); + for (key, value) in tus::request_headers(upload_length) { let Some(key) = key else { continue }; builder.header(key, value); } diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index 922d1e7d6d..dbe6361f2c 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -241,3 +241,27 @@ def test_upload_processing( assert PublicKey.parse(processing_relay.public_key).verify( unsigned_uri.encode(), signature ) + + +@pytest.mark.parametrize("defer_length_value", ["1", "2"]) +def test_upload_with_deferred_length( + mini_sentry, relay, relay_with_processing, project_config, defer_length_value +): + project_id = 42 + processing_relay = relay_with_processing(PROCESSING_OPTIONS) + relay = relay(processing_relay) + + data = b"hello world" + response = relay.post( + "/api/%s/upload/?sentry_key=%s" + % (project_id, mini_sentry.get_dsn_public_key(project_id)), + headers={ + "Tus-Resumable": "1.0.0", + "Upload-Defer-Length": defer_length_value, + "Content-Type": "application/offset+octet-stream", + }, + data=data, + ) + + expected_status_code = 403 if defer_length_value == "1" else 400 + assert response.status_code == expected_status_code