Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion native/core/src/execution/shuffle/spark_unsafe/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ impl SparkUnsafeObject for SparkUnsafeArray {
impl SparkUnsafeArray {
/// Creates a `SparkUnsafeArray` which points to the given address and size in bytes.
pub fn new(addr: i64) -> Self {
// Read the number of elements from the first 8 bytes.
// SAFETY: addr points to valid Spark UnsafeArray data from the JVM.
// The first 8 bytes contain the element count as a little-endian i64.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
let num_elements = i64::from_le_bytes(slice.try_into().unwrap());

Expand Down Expand Up @@ -83,6 +84,9 @@ impl SparkUnsafeArray {
/// Returns true if the null bit at the given index of the array is set.
#[inline]
pub(crate) fn is_null_at(&self, index: usize) -> bool {
// SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts
// at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures
// index < num_elements, so word_offset is within the bitset region.
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64;
Expand Down
3 changes: 2 additions & 1 deletion native/core/src/execution/shuffle/spark_unsafe/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub struct SparkUnsafeMap {
impl SparkUnsafeMap {
/// Creates a `SparkUnsafeMap` which points to the given address and size in bytes.
pub(crate) fn new(addr: i64, size: i32) -> Self {
// Read the number of bytes of key array from the first 8 bytes.
// SAFETY: addr points to valid Spark UnsafeMap data from the JVM.
// The first 8 bytes contain the key array size as a little-endian i64.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) };
let key_array_size = i64::from_le_bytes(slice.try_into().unwrap());

Expand Down
41 changes: 41 additions & 0 deletions native/core/src/execution/shuffle/spark_unsafe/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,19 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
/// A common trait for Spark Unsafe classes that can be used to access the underlying data,
/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to
/// access the underlying data with index.
///
/// # Safety
///
/// Implementations must ensure that:
/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
/// - `get_element_offset()` returns a valid pointer within the row/array data region
/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership)
///
/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are
/// safe to call as long as:
/// - The index is within bounds (caller's responsibility)
/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
pub trait SparkUnsafeObject {
/// Returns the address of the row.
fn get_row_addr(&self) -> i64;
Expand All @@ -77,47 +90,55 @@ pub trait SparkUnsafeObject {
/// Returns boolean value at the given index of the object.
fn get_boolean(&self, index: usize) -> bool {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region.
// The caller ensures index is within bounds.
unsafe { *addr != 0 }
}

/// Returns byte value at the given index of the object.
fn get_byte(&self, index: usize) -> i8 {
let addr = self.get_element_offset(index, 1);
// SAFETY: addr points to valid element data (1 byte) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) };
i8::from_le_bytes(slice.try_into().unwrap())
}

/// Returns short value at the given index of the object.
fn get_short(&self, index: usize) -> i16 {
let addr = self.get_element_offset(index, 2);
// SAFETY: addr points to valid element data (2 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) };
i16::from_le_bytes(slice.try_into().unwrap())
}

/// Returns integer value at the given index of the object.
fn get_int(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns long value at the given index of the object.
fn get_long(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
}

/// Returns float value at the given index of the object.
fn get_float(&self, index: usize) -> f32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
f32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns double value at the given index of the object.
fn get_double(&self, index: usize) -> f64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
f64::from_le_bytes(slice.try_into().unwrap())
}
Expand All @@ -126,6 +147,8 @@ pub trait SparkUnsafeObject {
fn get_string(&self, index: usize) -> &str {
let (offset, len) = self.get_offset_and_len(index);
let addr = self.get_row_addr() + offset as i64;
// SAFETY: addr points to valid UTF-8 string data within the variable-length region.
// Offset and length are read from the fixed-length portion of the row/array.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) };

from_utf8(slice).unwrap()
Expand All @@ -135,19 +158,23 @@ pub trait SparkUnsafeObject {
fn get_binary(&self, index: usize) -> &[u8] {
let (offset, len) = self.get_offset_and_len(index);
let addr = self.get_row_addr() + offset as i64;
// SAFETY: addr points to valid binary data within the variable-length region.
// Offset and length are read from the fixed-length portion of the row/array.
unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }
}

/// Returns date value at the given index of the object.
fn get_date(&self, index: usize) -> i32 {
let addr = self.get_element_offset(index, 4);
// SAFETY: addr points to valid element data (4 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
i32::from_le_bytes(slice.try_into().unwrap())
}

/// Returns timestamp value at the given index of the object.
fn get_timestamp(&self, index: usize) -> i64 {
let addr = self.get_element_offset(index, 8);
// SAFETY: addr points to valid element data (8 bytes) within the row/array region.
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
i64::from_le_bytes(slice.try_into().unwrap())
}
Expand Down Expand Up @@ -257,6 +284,9 @@ impl SparkUnsafeRow {
/// Returns true if the null bit at the given index of the row is set.
#[inline]
pub(crate) fn is_null_at(&self, index: usize) -> bool {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
Expand All @@ -267,6 +297,9 @@ impl SparkUnsafeRow {

/// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null).
pub fn set_not_null_at(&mut self, index: usize) {
// SAFETY: row_addr points to valid Spark UnsafeRow data with at least
// ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
// Writing is safe because we have mutable access and the memory is owned by the JVM.
unsafe {
let mask: i64 = 1i64 << (index & 0x3f);
let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
Expand Down Expand Up @@ -463,6 +496,8 @@ fn append_columns(
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down Expand Up @@ -593,6 +628,8 @@ fn append_columns(
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand All @@ -613,6 +650,8 @@ fn append_columns(
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down Expand Up @@ -640,6 +679,8 @@ fn append_columns(
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
// SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least
// row_end elements. i is in [row_start, row_end) so the offset is in bounds.
let row_addr = unsafe { *row_addresses_ptr.add(i) };
let row_size = unsafe { *row_sizes_ptr.add(i) };
row.point_to(row_addr, row_size);
Expand Down
Loading