Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod moment;
pub mod nan;
pub mod propagate_dictionary;
pub mod propagate_nulls;
pub mod prune_list_values_validity;
pub mod rolling;
pub mod size;
pub mod sum;
Expand Down
78 changes: 78 additions & 0 deletions crates/polars-compute/src/prune_list_values_validity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use arrow::array::{Array, FixedSizeListArray, ListArray};
use arrow::bitmap::bitmask::BitMask;
use arrow::types::Offset;

/// Removes the validity mask from list values if all bits that fall within
/// the offsets of non-null list entries are set.
///
/// Returns `None` when the validity cannot be removed, i.e. when at least one
/// unset values bit lies inside the range of a non-null list entry. Otherwise
/// returns a new `ListArray` sharing the same offsets and list validity, but
/// whose values array carries no validity.
pub fn prune_list_values_validity<O: Offset>(arr: &ListArray<O>) -> Option<ListArray<O>> {
    let values = arr.values();

    // Nothing to prune if the values have no validity mask at all.
    let values_validity = values.validity()?;

    let list_validity = arr.validity().map(BitMask::from_bitmap);
    let values_validity = BitMask::from_bitmap(values_validity);

    // If the values validity is all-set it can be dropped unconditionally;
    // only when it has unset bits do we need to inspect the offsets.
    if values_validity.unset_bits() > 0 {
        let offsets = arr.offsets();

        // Invariants backing the unchecked accesses in the loop below.
        assert!(list_validity.is_none_or(|x| x.len() == offsets.len_proxy()));
        assert_eq!(values_validity.len(), offsets.last().to_usize());

        for i in 0..offsets.len_proxy() {
            let (start, end) = offsets.start_end(i);

            // SAFETY: `i < offsets.len_proxy()`, and the asserts above
            // guarantee `i` is in bounds of the list validity and that
            // `start..end` is in bounds of the values validity.
            let entry_is_valid = list_validity.is_none_or(|x| unsafe { x.get_bit_unchecked(i) });

            if entry_is_valid
                && unsafe { values_validity.sliced_unchecked(start, end - start) }.unset_bits() > 0
            {
                // An unset values bit falls within a non-null list entry:
                // the validity carries information and cannot be removed.
                return None;
            }
        }
    }

    Some(ListArray::new(
        arr.dtype().clone(),
        arr.offsets().clone(),
        values.with_validity(None),
        arr.validity().cloned(),
    ))
}

/// Removes the validity mask from fixed-size-list values if all bits that
/// correspond to non-null list entries are set.
///
/// Returns `None` when the validity cannot be removed: either an unset values
/// bit lies inside a non-null list entry, or the values contain unset bits
/// while no list entry is null to mask them out. Otherwise returns a new
/// `FixedSizeListArray` whose values array carries no validity.
#[cfg(feature = "dtype-array")]
pub fn prune_fixed_size_list_values_validity(
    arr: &FixedSizeListArray,
) -> Option<FixedSizeListArray> {
    let values = arr.values();
    let width = arr.size();

    // Consistent with `prune_list_values_validity`: when the values validity
    // is absent or all-set it can be dropped unconditionally, so only inspect
    // entries when there is at least one unset bit.
    if width > 0
        && let Some(values_validity) = values.validity()
        && values_validity.unset_bits() > 0
    {
        // Every unset values bit must fall under a null list entry; if the
        // list validity masks nothing, pruning is impossible.
        let list_validity = arr.validity().filter(|x| x.unset_bits() > 0)?;

        let values_validity = BitMask::from_bitmap(values_validity);

        // Each list entry owns exactly `width` value slots, so the values
        // bitmap must have `len * width` bits (backs `sliced` below).
        assert_eq!(list_validity.len() * width, values_validity.len());

        for i in list_validity.true_idx_iter() {
            if values_validity.sliced(i * width, width).unset_bits() > 0 {
                // Unset values bit within a non-null entry: the validity
                // carries information and cannot be removed.
                return None;
            }
        }
    }

    Some(FixedSizeListArray::new(
        arr.dtype().clone(),
        arr.len(),
        values.with_validity(None),
        arr.validity().cloned(),
    ))
}
44 changes: 42 additions & 2 deletions crates/polars-core/src/series/into.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::borrow::Cow;

use arrow::offset::OffsetsBuffer;
use polars_compute::cast::cast_unchecked;
use polars_compute::prune_list_values_validity::prune_list_values_validity;

use crate::prelude::*;

Expand Down Expand Up @@ -178,14 +180,46 @@ impl ToArrowConverter {
Box::new(arr)
},
(DataType::List(item_dtype), ArrowDataType::LargeList(_)) => {
let arr: &ListArray<i64> = array.as_any().downcast_ref().unwrap();
let mut arr: Cow<ListArray<i64>> =
Cow::Borrowed(array.as_any().downcast_ref().unwrap());

let mut arrow_dtype = to_owned_dtype(arrow_field);

let ArrowDataType::LargeList(arrow_item_field) = &mut arrow_dtype else {
unreachable!()
};

if arr.offsets().range() as usize != arr.values().len() {
arr = Cow::Owned({
let offsets = arr.offsets();

let first_idx = *arr.offsets().first();

let offsets = if first_idx == 0 {
offsets.clone()
} else {
let v: Vec<i64> = offsets.iter().map(|x| *x - first_idx).collect();
unsafe { OffsetsBuffer::<i64>::new_unchecked(v.into()) }
};

let values = arr
.values()
.clone()
.sliced(first_idx as usize, offsets.range() as usize);

ListArray::<i64>::new(
arr.dtype().clone(),
offsets,
values,
arr.validity().cloned(),
)
});
}

if !arrow_item_field.is_nullable {
arr = prune_list_values_validity(&arr).map_or(arr, Cow::Owned);
}

self.attach_pl_field_metadata(std::iter::once((
item_dtype.as_ref(),
arrow_item_field.as_mut(),
Expand All @@ -209,7 +243,8 @@ impl ToArrowConverter {
#[cfg(feature = "dtype-array")]
(DataType::Array(item_dtype, width), ArrowDataType::FixedSizeList(_, arrow_width)) => {
use arrow::array::FixedSizeListArray;
let arr: &FixedSizeListArray = array.as_any().downcast_ref().unwrap();
let mut arr: Cow<FixedSizeListArray> =
Cow::Borrowed(array.as_any().downcast_ref().unwrap());

polars_ensure!(
*arrow_width == *width,
Expand All @@ -224,6 +259,11 @@ impl ToArrowConverter {
unreachable!()
};

if !arrow_item_field.is_nullable {
use polars_compute::prune_list_values_validity::prune_fixed_size_list_values_validity;
arr = prune_fixed_size_list_values_validity(&arr).map_or(arr, Cow::Owned);
}

self.attach_pl_field_metadata(std::iter::once((
item_dtype.as_ref(),
arrow_item_field.as_mut(),
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-parquet/src/arrow/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ pub fn array_to_page_simple(
) -> PolarsResult<Page> {
let dtype = array.dtype();

if type_.field_info.repetition == Repetition::Required && array.null_count() > 0 {
if type_.field_info.repetition == Repetition::Required && array.has_nulls() {
polars_bail!(InvalidOperation: "writing a missing value to required parquet column '{}'", type_.field_info.name);
}

Expand Down Expand Up @@ -830,9 +830,7 @@ fn array_to_page_nested(
options: WriteOptions,
_encoding: Encoding,
) -> PolarsResult<Page> {
if type_.field_info.repetition == Repetition::Required
&& array.validity().is_some_and(|v| v.unset_bits() > 0)
{
if type_.field_info.repetition == Repetition::Required && array.has_nulls() {
polars_bail!(InvalidOperation: "writing a missing value to required parquet column '{}'", type_.field_info.name);
}

Expand Down
10 changes: 2 additions & 8 deletions py-polars/tests/unit/io/test_lazy_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,12 +1649,6 @@ def test_sink_parquet_arrow_schema_view_types() -> None:
assert_frame_equal(pl.scan_parquet(f).collect(), df)


@pytest.mark.xfail(
reason="""
unimplemented: NULLs in list values array corresponding to masked out rows.
ref https://github.com/pola-rs/polars/issues/26600.
""",
)
def test_sink_parquet_arrow_schema_sliced_non_nullable_list() -> None:
schema = {
"list": pl.List(pl.Int64),
Expand Down Expand Up @@ -1697,8 +1691,8 @@ def test_sink_parquet_arrow_schema_sliced_non_nullable_list() -> None:
pl.scan_parquet(f).collect(),
pl.DataFrame(
{
"list": [[1], [2]],
"fixed_size_list": [[1], [2]],
"list": [None, [1], [2]],
"fixed_size_list": [None, [1], [2]],
},
schema=schema,
),
Expand Down
Loading