Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
- Populate gen_ai.response.model from gen_ai.request.model if not already set. ([#5654](https://github.com/getsentry/relay/pull/5654))
- Add support for Unix domain sockets for statsd metrics. ([#5668](https://github.com/getsentry/relay/pull/5668))

**Internal**

- Allow deferred lengths to the `/upload` endpoint when the sender is trusted. ([#5658](https://github.com/getsentry/relay/pull/5658))

## 26.2.1

**Bug Fixes**:
Expand Down
21 changes: 14 additions & 7 deletions relay-server/src/endpoints/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::services::projects::cache::Project;
use crate::services::upload::Error as ServiceError;
use crate::services::upstream::UpstreamRequestError;
use crate::utils::upload::SignedLocation;
use crate::utils::{ExactStream, find_error_source, tus, upload};
use crate::utils::{BoundedStream, find_error_source, tus, upload};

#[derive(Debug, thiserror::Error)]
enum Error {
Expand All @@ -48,6 +48,7 @@ enum Error {
impl IntoResponse for Error {
fn into_response(self) -> Response {
match self {
Error::Tus(tus::Error::DeferLengthNotAllowed) => StatusCode::FORBIDDEN,
Error::Tus(_) => StatusCode::BAD_REQUEST,
Error::Request(error) => return error.into_response(),
Error::Upload(error) => match error {
Expand Down Expand Up @@ -115,10 +116,11 @@ async fn handle(
headers: HeaderMap,
body: Body,
) -> axum::response::Result<impl IntoResponse> {
let upload_length = tus::validate_headers(&headers).map_err(Error::from)?;
let upload_length =
tus::validate_headers(&headers, meta.request_trust().is_trusted()).map_err(Error::from)?;
let config = state.config();

if upload_length > config.max_upload_size() {
if upload_length.is_some_and(|len| len > config.max_upload_size()) {
return Err(StatusCode::PAYLOAD_TOO_LARGE.into());
}

Expand All @@ -135,7 +137,12 @@ async fn handle(
.into_data_stream()
.map(|result| result.map_err(io::Error::other))
.boxed();
let stream = ExactStream::new(stream, upload_length);
let (lower_bound, upper_bound) = match upload_length {
None => (1, config.max_upload_size()),
Some(u) => (u, u),
};
let stream = BoundedStream::new(stream, lower_bound, upper_bound);
let byte_counter = stream.byte_counter();

let location = upload::Sink::new(&state)
.upload(config, upload::Stream { scoping, stream })
Expand All @@ -148,7 +155,7 @@ async fn handle(
let mut response = location.into_response();
response
.headers_mut()
.insert(tus::UPLOAD_OFFSET, upload_length.into());
.insert(tus::UPLOAD_OFFSET, byte_counter.get().into());

Ok(response)
}
Expand All @@ -160,14 +167,14 @@ async fn handle(
async fn check_request(
state: &ServiceState,
meta: RequestMeta,
upload_length: usize,
upload_length: Option<usize>,
project: Project<'_>,
) -> Result<Scoping, BadStoreRequest> {
let mut envelope = Envelope::from_request(None, meta);
envelope.require_feature(Feature::UploadEndpoint);
let mut item = Item::new(ItemType::Attachment);
item.set_payload(ContentType::AttachmentRef, vec![]);
item.set_attachment_length(upload_length as u64);
item.set_attachment_length(upload_length.unwrap_or(1) as u64);
envelope.add_item(item);
let mut envelope = Managed::from_envelope(envelope, state.outcome_aggregator().clone());
let rate_limits = project
Expand Down
63 changes: 43 additions & 20 deletions relay-server/src/utils/stream.rs
Original file line number Diff line number Diff line change
@@ -1,42 +1,64 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};

use bytes::Bytes;
use futures::Stream;
use sync_wrapper::SyncWrapper;

/// A streaming body that validates the total byte count against the announced length.
/// A streaming body that validates the total byte count against an expected length if provided.
///
/// Returns an error if the stream provides more bytes than `expected_length` (checked per chunk)
/// or fewer bytes than `expected_length` (checked when the stream ends).
///
/// This type is `Sync` via [`SyncWrapper`], allowing it to be sent across thread boundaries
/// as required by the upload service.
#[derive(Debug)]
pub struct ExactStream<S> {
pub struct BoundedStream<S> {
pub lower_bound: usize,
pub upper_bound: usize,
inner: Option<SyncWrapper<S>>,
expected_length: usize,
bytes_received: usize,
byte_counter: ByteCounter,
}

impl<S> ExactStream<S> {
/// Creates a new `ExactStream` wrapping the given stream with the expected total length.
pub fn new(stream: S, expected_length: usize) -> Self {
/// A shared counter that can be read after the [`BoundedStream`] has been moved.
#[derive(Clone, Debug)]
pub struct ByteCounter(Arc<AtomicUsize>);

impl ByteCounter {
fn new() -> Self {
Self(Arc::new(AtomicUsize::new(0)))
}

fn add(&self, n: usize) -> usize {
n + self.0.fetch_add(n, Ordering::Relaxed)
}

pub fn get(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}

impl<S> BoundedStream<S> {
/// Creates a new [`BoundedStream`] wrapping the given stream with the expected total length.
pub fn new(stream: S, lower_bound: usize, upper_bound: usize) -> Self {
Self {
inner: Some(SyncWrapper::new(stream)),
expected_length,
bytes_received: 0,
lower_bound,
upper_bound,
byte_counter: ByteCounter::new(),
}
}

/// Returns the expected total length of the stream.
pub fn expected_length(&self) -> usize {
self.expected_length
/// Returns a shared handle to read the byte count after the stream is consumed.
pub fn byte_counter(&self) -> ByteCounter {
self.byte_counter.clone()
}
}

impl<S, E> Stream for ExactStream<S>
impl<S, E> Stream for BoundedStream<S>
where
S: Stream<Item = Result<Bytes, E>> + Send + Unpin,
E: Into<io::Error>,
Expand All @@ -52,14 +74,14 @@ where

match inner.poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.bytes_received += bytes.len();
if this.bytes_received > this.expected_length {
let bytes_received = this.byte_counter.add(bytes.len());
if bytes_received > this.upper_bound {
this.inner = None;
Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::FileTooLarge,
format!(
"stream exceeded expected length: received {} > {}",
this.bytes_received, this.expected_length
"stream exceeded upper bound: received {} > {}",
bytes_received, this.upper_bound
),
))))
} else {
Expand All @@ -68,13 +90,14 @@ where
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
Poll::Ready(None) => {
if this.bytes_received < this.expected_length {
let bytes_received = this.byte_counter.get();
if bytes_received < this.lower_bound {
this.inner = None;
Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!(
"stream shorter than expected length: received {} < {}",
this.bytes_received, this.expected_length
"stream shorter than lower bound: received {} < {}",
bytes_received, this.lower_bound
),
))))
} else {
Expand Down
96 changes: 73 additions & 23 deletions relay-server/src/utils/tus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
//!
//! Reference: <https://tus.io/protocols/resumable-upload>

use std::str::FromStr;

use axum::http::HeaderMap;
use http::HeaderValue;

Expand All @@ -13,9 +15,12 @@ pub enum Error {
/// The TUS version is missing or does not match the server's version.
#[error("Version Mismatch")]
Version,
/// The `Upload-Length` header is missing or cannot be parsed.
/// The `Upload-Length` and `Upload-Defer-Length` headers are both missing, incorrect, or cannot be parsed.
#[error("Invalid Upload-Length")]
UploadLength,
/// The `Upload-Defer-Length` header is not allowed for external/untrusted requests.
#[error("Upload-Defer-Length not allowed")]
DeferLengthNotAllowed,
/// The `Content-Type` header is not what TUS expects.
#[error("Invalid Content-Type")]
ContentType,
Expand All @@ -41,6 +46,11 @@ pub const TUS_VERSION: HeaderValue = HeaderValue::from_static("1.0.0");
/// See <https://tus.io/protocols/resumable-upload#upload-length>.
pub const UPLOAD_LENGTH: &str = "Upload-Length";

/// TUS protocol header for the deferred upload length.
///
/// See <https://tus.io/protocols/resumable-upload#upload-defer-length>.
pub const UPLOAD_DEFER_LENGTH: &str = "Upload-Defer-Length";

/// TUS protocol header for the current upload offset.
///
/// See <https://tus.io/protocols/resumable-upload#upload-offset>.
Expand All @@ -51,7 +61,10 @@ pub const EXPECTED_CONTENT_TYPE: HeaderValue =
HeaderValue::from_static("application/offset+octet-stream");

/// Validates TUS protocol headers and returns the expected upload length.
pub fn validate_headers(headers: &HeaderMap) -> Result<usize, Error> {
pub fn validate_headers(
headers: &HeaderMap,
allow_defer_length: bool,
) -> Result<Option<usize>, Error> {
let tus_version = headers.get(TUS_RESUMABLE);
if tus_version != Some(&TUS_VERSION) {
return Err(Error::Version);
Expand All @@ -62,27 +75,30 @@ pub fn validate_headers(headers: &HeaderMap) -> Result<usize, Error> {
return Err(Error::ContentType);
}

let upload_length = headers
.get(UPLOAD_LENGTH)
.ok_or(Error::UploadLength)?
.to_str()
.map_err(|_| Error::UploadLength)?
.parse()
.map_err(|_| Error::UploadLength)?;

Ok(upload_length)
let upload_length: Option<usize> = parse_header(headers, UPLOAD_LENGTH);
let upload_defer_length: Option<usize> = parse_header(headers, UPLOAD_DEFER_LENGTH);

// Exactly one of Upload-Length and Upload-Defer-Length must be present.
// Upload-Defer-Length is only accepted if its value is 1 (as demanded by the TUS protocol)
// and `allow_defer_length` is true (i.e. the sender is trusted/internal).
match (upload_length, upload_defer_length, allow_defer_length) {
(Some(u), None, _) => Ok(Some(u)),
(None, Some(1), true) => Ok(None),
(None, Some(1), false) => Err(Error::DeferLengthNotAllowed),
_ => Err(Error::UploadLength),
}
}

/// Prepares the required TUS request headers for upstream requests.
pub fn request_headers(upload_length: usize) -> HeaderMap {
pub fn request_headers(upload_length: Option<usize>) -> HeaderMap {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, TUS_VERSION);
headers.insert(http::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
headers.insert(
UPLOAD_LENGTH,
HeaderValue::from_str(&upload_length.to_string())
.expect("string from usize should always be a valid header"),
);
if let Some(upload_length) = upload_length {
headers.insert(UPLOAD_LENGTH, HeaderValue::from(upload_length));
} else {
headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from(1));
}
headers
}

Expand All @@ -94,6 +110,10 @@ pub fn response_headers() -> HeaderMap {
headers
}

fn parse_header<T: FromStr>(headers: &HeaderMap, header_name: &str) -> Option<T> {
headers.get(header_name)?.to_str().ok()?.parse().ok()
}

#[cfg(test)]
mod tests {
use http::HeaderValue;
Expand All @@ -103,7 +123,7 @@ mod tests {
#[test]
fn test_validate_tus_headers_missing_version() {
let headers = HeaderMap::new();
let result = validate_headers(&headers);
let result = validate_headers(&headers, false);
assert!(matches!(result, Err(Error::Version)));
}

Expand All @@ -112,7 +132,7 @@ mod tests {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024"));
let result = validate_headers(&headers);
let result = validate_headers(&headers, false);
assert!(matches!(result, Err(Error::ContentType)));
}

Expand All @@ -121,7 +141,7 @@ mod tests {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
let result = validate_headers(&headers);
let result = validate_headers(&headers, false);
assert!(matches!(result, Err(Error::UploadLength)));
}

Expand All @@ -131,16 +151,46 @@ mod tests {
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024"));
let result = validate_headers(&headers);
assert_eq!(result.unwrap(), 1024);
let result = validate_headers(&headers, false);
assert_eq!(result.unwrap().unwrap(), 1024);
}

#[test]
fn test_validate_tus_headers_unsupported_version() {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("0.2.0"));
headers.insert(UPLOAD_LENGTH, HeaderValue::from_static("1024"));
let result = validate_headers(&headers);
let result = validate_headers(&headers, false);
assert!(matches!(result, Err(Error::Version)));
}

#[test]
fn test_validate_tus_headers_valid_deferred_length_from_trusted_source() {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("1"));
let result = validate_headers(&headers, true);
assert!(matches!(result, Ok(None)));
}

#[test]
fn test_validate_tus_headers_valid_deferred_length_from_untrusted_source() {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("1"));
let result = validate_headers(&headers, false);
assert!(matches!(result, Err(Error::DeferLengthNotAllowed)));
}

#[test]
fn test_validate_tus_headers_invalid_deferred_length() {
let mut headers = HeaderMap::new();
headers.insert(TUS_RESUMABLE, HeaderValue::from_static("1.0.0"));
headers.insert(hyper::header::CONTENT_TYPE, EXPECTED_CONTENT_TYPE);
headers.insert(UPLOAD_DEFER_LENGTH, HeaderValue::from_static("2"));
let result = validate_headers(&headers, true);
assert!(matches!(result, Err(Error::UploadLength)));
}
}
Loading
Loading