Skip to content

Commit 9e0aac1

Browse files
authored
Add unified and configurable null handling (#1101)
* Add unified and configurable null handling
* Remove unused import of `Array` in I/O utilities
1 parent 2dc9a02 commit 9e0aac1

File tree

12 files changed

+287
-69
lines changed

12 files changed

+287
-69
lines changed

qdp/qdp-core/src/encoding/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ pub(crate) fn stream_encode<E: ChunkEncoder>(
141141
encoder: E,
142142
) -> Result<*mut DLManagedTensor> {
143143
// Initialize reader
144-
let mut reader_core = crate::io::ParquetBlockReader::new(path, None)?;
144+
let mut reader_core =
145+
crate::io::ParquetBlockReader::new(path, None, crate::reader::NullHandling::FillZero)?;
145146
let num_samples = reader_core.total_rows;
146147

147148
// Allocate output state vector

qdp/qdp-core/src/io.rs

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,45 +26,43 @@ use std::fs::File;
2626
use std::path::Path;
2727
use std::sync::Arc;
2828

29-
use arrow::array::{Array, ArrayRef, Float64Array, RecordBatch};
29+
use arrow::array::{ArrayRef, Float64Array, RecordBatch};
3030
use arrow::datatypes::{DataType, Field, Schema};
3131
use parquet::arrow::ArrowWriter;
3232
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
3333
use parquet::file::properties::WriterProperties;
3434

3535
use crate::error::{MahoutError, Result};
36+
use crate::reader::{NullHandling, handle_float64_nulls};
3637

3738
/// Converts an Arrow Float64Array to Vec<f64>.
38-
pub fn arrow_to_vec(array: &Float64Array) -> Vec<f64> {
39-
if array.null_count() == 0 {
40-
array.values().to_vec()
41-
} else {
42-
array.iter().map(|opt| opt.unwrap_or(0.0)).collect()
43-
}
39+
pub fn arrow_to_vec(array: &Float64Array, null_handling: NullHandling) -> Result<Vec<f64>> {
40+
let mut result = Vec::with_capacity(array.len());
41+
handle_float64_nulls(&mut result, array, null_handling)?;
42+
Ok(result)
4443
}
4544

4645
/// Flattens multiple Arrow Float64Arrays into a single Vec<f64>.
47-
pub fn arrow_to_vec_chunked(arrays: &[Float64Array]) -> Vec<f64> {
46+
pub fn arrow_to_vec_chunked(
47+
arrays: &[Float64Array],
48+
null_handling: NullHandling,
49+
) -> Result<Vec<f64>> {
4850
let total_len: usize = arrays.iter().map(|a| a.len()).sum();
4951
let mut result = Vec::with_capacity(total_len);
5052

5153
for array in arrays {
52-
if array.null_count() == 0 {
53-
result.extend_from_slice(array.values());
54-
} else {
55-
result.extend(array.iter().map(|opt| opt.unwrap_or(0.0)));
56-
}
54+
handle_float64_nulls(&mut result, array, null_handling)?;
5755
}
5856

59-
result
57+
Ok(result)
6058
}
6159

6260
/// Reads Float64 data from a Parquet file.
6361
///
6462
/// Expects a single Float64 column. For zero-copy access, use [`read_parquet_to_arrow`].
6563
pub fn read_parquet<P: AsRef<Path>>(path: P) -> Result<Vec<f64>> {
6664
let chunks = read_parquet_to_arrow(path)?;
67-
Ok(arrow_to_vec_chunked(&chunks))
65+
arrow_to_vec_chunked(&chunks, NullHandling::FillZero)
6866
}
6967

7068
/// Writes Float64 data to a Parquet file.
@@ -228,7 +226,7 @@ pub fn write_arrow_to_parquet<P: AsRef<Path>>(
228226
/// Add OOM protection for very large files
229227
pub fn read_parquet_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> {
    use crate::reader::DataReader;
    // Default policy: nulls are filled with 0.0 (backward-compatible behaviour).
    crate::readers::ParquetReader::new(path, None, NullHandling::FillZero)?.read_batch()
}
234232

@@ -244,7 +242,7 @@ pub fn read_parquet_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, u
244242
/// Add OOM protection for very large files
245243
pub fn read_arrow_ipc_batch<P: AsRef<Path>>(path: P) -> Result<(Vec<f64>, usize, usize)> {
    use crate::reader::DataReader;
    // Default policy: nulls are filled with 0.0 (backward-compatible behaviour).
    crate::readers::ArrowIPCReader::new(path, NullHandling::FillZero)?.read_batch()
}
250248

qdp/qdp-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod profiling;
3434

3535
pub use error::{MahoutError, Result, cuda_error_to_string};
3636
pub use gpu::memory::Precision;
37+
pub use reader::{NullHandling, handle_float64_nulls};
3738

3839
// Throughput/latency pipeline runner: single path using QdpEngine and encode_batch in Rust.
3940
#[cfg(target_os = "linux")]

qdp/qdp-core/src/pipeline_runner.rs

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::QdpEngine;
2626
use crate::dlpack::DLManagedTensor;
2727
use crate::error::{MahoutError, Result};
2828
use crate::io;
29-
use crate::reader::StreamingDataReader;
29+
use crate::reader::{NullHandling, StreamingDataReader};
3030
use crate::readers::ParquetStreamingReader;
3131

3232
/// Configuration for throughput/latency pipeline runs (Python run_throughput_pipeline_py).
@@ -39,6 +39,7 @@ pub struct PipelineConfig {
3939
pub encoding_method: String,
4040
pub seed: Option<u64>,
4141
pub warmup_batches: usize,
42+
pub null_handling: NullHandling,
4243
}
4344

4445
impl Default for PipelineConfig {
@@ -51,6 +52,7 @@ impl Default for PipelineConfig {
5152
encoding_method: "amplitude".to_string(),
5253
seed: None,
5354
warmup_batches: 0,
55+
null_handling: NullHandling::FillZero,
5456
}
5557
}
5658
}
@@ -154,12 +156,23 @@ fn path_extension_lower(path: &Path) -> Option<String> {
154156

155157
/// Dispatches by path extension to the appropriate io reader. Returns (data, num_samples, sample_size).
156158
/// Unsupported or missing extension returns Err with message listing supported formats.
157-
fn read_file_by_extension(path: &Path) -> Result<(Vec<f64>, usize, usize)> {
159+
fn read_file_by_extension(
160+
path: &Path,
161+
null_handling: NullHandling,
162+
) -> Result<(Vec<f64>, usize, usize)> {
158163
let ext_lower = path_extension_lower(path);
159164
let ext = ext_lower.as_deref();
160165
match ext {
161-
Some("parquet") => io::read_parquet_batch(path),
162-
Some("arrow") | Some("feather") | Some("ipc") => io::read_arrow_ipc_batch(path),
166+
Some("parquet") => {
167+
use crate::reader::DataReader;
168+
let mut reader = crate::readers::ParquetReader::new(path, None, null_handling)?;
169+
reader.read_batch()
170+
}
171+
Some("arrow") | Some("feather") | Some("ipc") => {
172+
use crate::reader::DataReader;
173+
let mut reader = crate::readers::ArrowIPCReader::new(path, null_handling)?;
174+
reader.read_batch()
175+
}
163176
Some("npy") => io::read_numpy_batch(path),
164177
Some("pt") | Some("pth") => io::read_torch_batch(path),
165178
Some("pb") => io::read_tensorflow_batch(path),
@@ -211,7 +224,7 @@ impl PipelineIterator {
211224
batch_limit: usize,
212225
) -> Result<Self> {
213226
let path = path.as_ref();
214-
let (data, num_samples, sample_size) = read_file_by_extension(path)?;
227+
let (data, num_samples, sample_size) = read_file_by_extension(path, config.null_handling)?;
215228
let vector_len = vector_len(config.num_qubits, &config.encoding_method);
216229

217230
// Dimension validation at construction.
@@ -263,7 +276,11 @@ impl PipelineIterator {
263276
)));
264277
}
265278

266-
let mut reader = ParquetStreamingReader::new(path, Some(DEFAULT_PARQUET_ROW_GROUP_SIZE))?;
279+
let mut reader = ParquetStreamingReader::new(
280+
path,
281+
Some(DEFAULT_PARQUET_ROW_GROUP_SIZE),
282+
config.null_handling,
283+
)?;
267284
let vector_len = vector_len(config.num_qubits, &config.encoding_method);
268285

269286
// Read first chunk to learn sample_size; reuse as initial buffer.

qdp/qdp-core/src/reader.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,48 @@
4545
//! }
4646
//! ```
4747
48+
use arrow::array::{Array, Float64Array};
49+
4850
use crate::error::Result;
4951

52+
/// Policy for handling null values in Float64 arrays.
///
/// Consumed by [`handle_float64_nulls`] and threaded through the crate's
/// Parquet / Arrow IPC readers and I/O helpers, so callers can choose between
/// silent substitution and strict rejection of nulls.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NullHandling {
    /// Replace nulls with 0.0 (backward-compatible default).
    #[default]
    FillZero,
    /// Return an error when a null is encountered.
    Reject,
}
61+
62+
/// Append values from a `Float64Array` into `output`, applying the given null policy.
63+
///
64+
/// When there are no nulls the fast path copies the underlying buffer directly.
65+
pub fn handle_float64_nulls(
66+
output: &mut Vec<f64>,
67+
float_array: &Float64Array,
68+
null_handling: NullHandling,
69+
) -> crate::error::Result<()> {
70+
if float_array.null_count() == 0 {
71+
output.extend_from_slice(float_array.values());
72+
} else {
73+
match null_handling {
74+
NullHandling::FillZero => {
75+
output.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
76+
}
77+
NullHandling::Reject => {
78+
return Err(crate::error::MahoutError::InvalidInput(
79+
"Null value encountered in Float64Array. \
80+
Use NullHandling::FillZero to replace nulls with 0.0, \
81+
or clean the data at the source."
82+
.to_string(),
83+
));
84+
}
85+
}
86+
}
87+
Ok(())
88+
}
89+
5090
/// Generic data reader interface for batch quantum data.
5191
///
5292
/// Implementations should read data in the format:

qdp/qdp-core/src/readers/arrow_ipc.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,22 @@ use arrow::datatypes::DataType;
2424
use arrow::ipc::reader::FileReader as ArrowFileReader;
2525

2626
use crate::error::{MahoutError, Result};
27-
use crate::reader::DataReader;
27+
use crate::reader::{DataReader, NullHandling, handle_float64_nulls};
2828

2929
/// Reader for Arrow IPC files containing FixedSizeList<Float64> or List<Float64> columns.
3030
pub struct ArrowIPCReader {
3131
path: std::path::PathBuf,
3232
read: bool,
33+
null_handling: NullHandling,
3334
}
3435

3536
impl ArrowIPCReader {
3637
/// Create a new Arrow IPC reader.
3738
///
3839
/// # Arguments
3940
/// * `path` - Path to the Arrow IPC file (.arrow or .feather)
40-
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
41+
/// * `null_handling` - Policy for null values (defaults to `FillZero`)
42+
pub fn new<P: AsRef<Path>>(path: P, null_handling: NullHandling) -> Result<Self> {
4143
let path = path.as_ref();
4244

4345
// Verify file exists
@@ -64,6 +66,7 @@ impl ArrowIPCReader {
6466
Ok(Self {
6567
path: path.to_path_buf(),
6668
read: false,
69+
null_handling,
6770
})
6871
}
6972
}
@@ -136,11 +139,7 @@ impl DataReader for ArrowIPCReader {
136139
.downcast_ref::<Float64Array>()
137140
.ok_or_else(|| MahoutError::Io("Values must be Float64".to_string()))?;
138141

139-
if float_array.null_count() == 0 {
140-
all_data.extend_from_slice(float_array.values());
141-
} else {
142-
all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
143-
}
142+
handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
144143

145144
num_samples += list_array.len();
146145
}
@@ -182,11 +181,7 @@ impl DataReader for ArrowIPCReader {
182181
all_data.reserve(new_capacity);
183182
}
184183

185-
if float_array.null_count() == 0 {
186-
all_data.extend_from_slice(float_array.values());
187-
} else {
188-
all_data.extend(float_array.iter().map(|opt| opt.unwrap_or(0.0)));
189-
}
184+
handle_float64_nulls(&mut all_data, float_array, self.null_handling)?;
190185

191186
num_samples += 1;
192187
}

0 commit comments

Comments
 (0)